
Python Matplotlib

Introduction

Data visualization is a crucial skill in data science that allows you to communicate insights effectively. Matplotlib is one of the most popular and powerful visualization libraries in Python. It provides a comprehensive set of tools for creating static, animated, and interactive visualizations in Python.

In this tutorial, you'll learn how to use Matplotlib to create various types of plots and customize them to effectively communicate your data. Whether you're analyzing trends, comparing categories, or exploring relationships, Matplotlib gives you the flexibility to create visualizations tailored to your needs.

Getting Started with Matplotlib

Installation

Before we begin, make sure you have Matplotlib installed. You can install it using pip:

bash
pip install matplotlib
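To check that the installation worked, you can import the package and print its version. The exact number depends on your environment. It is worth knowing, because a few style names used later in this tutorial differ between Matplotlib releases.

```python
import matplotlib

# Prints the installed version string, for example something like "3.8.2"
print(matplotlib.__version__)
```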

Basic Structure of Matplotlib

Matplotlib is built on a hierarchy of objects. The key components are:

  • Figure: The overall window or page where everything is drawn; one Figure can contain one or more Axes
  • Axes: An individual plot inside the Figure, holding the data along with its title, axis labels, ticks, etc.
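The Figure/Axes relationship can be seen directly in code. The sketch below creates an empty Figure, adds one Axes to it with fig.add_subplot(), and checks the links between the two objects (the printed values are noted in the comments).

```python
import matplotlib.pyplot as plt

# A Figure is the top-level container; it starts out with no Axes
fig = plt.figure(figsize=(6, 3))
print(len(fig.axes))  # 0

# Adding a subplot creates an Axes that lives inside the Figure
ax = fig.add_subplot(1, 1, 1)
ax.plot([0, 1, 2], [0, 1, 4])
ax.set_title('One Axes inside one Figure')

# The Figure now lists the Axes, and the Axes points back to its Figure
print(len(fig.axes))     # 1
print(ax.figure is fig)  # True
```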

Let's create your first plot:

python
import matplotlib.pyplot as plt
import numpy as np

# Create some data
x = np.linspace(0, 10, 100) # 100 points from 0 to 10
y = np.sin(x)

# Create a figure 8 inches wide and 4 inches tall;
# plt.plot() below adds an Axes to it automatically
plt.figure(figsize=(8, 4))
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True)
plt.show()

This code produces:

Sine Wave Plot

Two Ways to Use Matplotlib

Matplotlib provides two main approaches for creating visualizations:

1. MATLAB-style Interface (Pyplot)

This approach is simpler and more convenient for basic plots:

python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
plt.title('Sine and Cosine Waves')
plt.legend(['sin(x)', 'cos(x)'])
plt.show()

Sine and Cosine Plot
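Behind the scenes, the pyplot functions operate on a "current" Figure and Axes that pyplot keeps track of for you. The sketch below makes that hidden state visible: plt.gca() ("get current Axes") returns the Axes that both plt.plot() calls drew on.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
plt.figure()  # Start a fresh current figure
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))

# Both lines went onto the same implicit "current" Axes
ax = plt.gca()
print(len(ax.lines))  # 2

# That Axes object can also be customized directly
ax.set_title('Sine and Cosine Waves')
plt.show()
```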

2. Object-Oriented Interface

This approach gives you more control and is better for complex plots:

python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Create figure and axis objects
fig, ax = plt.subplots(figsize=(8, 4))

# Use the axis object to plot data
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')

# Customize using the axis object
ax.set_title('Sine and Cosine Waves')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.legend()
ax.grid(True)

plt.show()

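The result looks like the previous plot, with axis labels and a grid added. Because every call goes through the fig and ax objects, it is always clear which figure and which Axes you are modifying, which makes this approach more flexible for customization.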
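A practical benefit of the object-oriented interface is that plotting code can live in functions that receive the Axes to draw on. The plot_wave helper below is hypothetical (it is not part of Matplotlib) and only illustrates the pattern: since it never relies on pyplot's current Axes, the same function can fill any panel of a larger figure.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_wave(ax, x, func, label):
    """Hypothetical helper: draw func(x) on the given Axes and return the line."""
    line, = ax.plot(x, func(x), label=label)
    ax.set_xlabel('X-axis')
    ax.grid(True)
    ax.legend()
    return line

x = np.linspace(0, 10, 100)

# Two Axes side by side, each filled by the same helper
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
plot_wave(ax1, x, np.sin, 'sin(x)')
plot_wave(ax2, x, np.cos, 'cos(x)')
plt.show()
```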

Common Plot Types

Line Plot

We've already seen line plots in the examples above. They're great for showing trends over time.
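For trends over time, the x values are often dates. Matplotlib accepts NumPy datetime64 values directly. In this sketch the visitor counts are made up for illustration, and fig.autofmt_xdate() rotates the date labels so they don't overlap.

```python
import matplotlib.pyplot as plt
import numpy as np

# Six months of made-up data, January to June 2024
months = np.arange('2024-01', '2024-07', dtype='datetime64[M]')
visitors = [120, 135, 150, 145, 170, 190]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(months, visitors, marker='o')
ax.set_title('Monthly Visitors')
ax.set_ylabel('Visitors')
fig.autofmt_xdate()  # Rotate and right-align the date tick labels
plt.show()
```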

Bar Plot

Bar plots are useful for comparing quantities across categories:

python
import matplotlib.pyplot as plt
import numpy as np

# Data
categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [15, 34, 23, 17]

# Creating bar plot
fig, ax = plt.subplots(figsize=(8, 4))
bars = ax.bar(categories, values, color='skyblue')

# Adding value labels on top of bars
for bar in bars:
    height = bar.get_height()
    # Center the label horizontally over the bar, just above its top
    ax.text(bar.get_x() + bar.get_width() / 2, height + 0.5,
            f'{height}', ha='center', va='bottom')

ax.set_title('Bar Plot Example')
ax.set_xlabel('Categories')
ax.set_ylabel('Values')
ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.show()

Bar Plot Example
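Matplotlib 3.4 and later also provide Axes.bar_label, which adds the same value labels without the manual loop. A sketch of the chart above using it (padding is the gap in points between bar and label; 3 is an arbitrary choice):

```python
import matplotlib.pyplot as plt

categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [15, 34, 23, 17]

fig, ax = plt.subplots(figsize=(8, 4))
bars = ax.bar(categories, values, color='skyblue')

# Label each bar with its height, 3 points above the top of the bar
labels = ax.bar_label(bars, padding=3)

ax.set_title('Bar Plot Example')
ax.set_ylabel('Values')
plt.show()
```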

Scatter Plot

Scatter plots help visualize relationships between two variables:

python
import matplotlib.pyplot as plt
import numpy as np

# Generate random data
np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50) # y is correlated with x plus some noise

# Create scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(x, y, c='purple', alpha=0.6, edgecolors='black')
plt.title('Scatter Plot Example')
plt.xlabel('X Variable')
plt.ylabel('Y Variable')
plt.grid(True, alpha=0.3)

# Add a least-squares trend line; sort x so the line is drawn left to right
z = np.polyfit(x, y, 1)
p = np.poly1d(z)
x_sorted = np.sort(x)
plt.plot(x_sorted, p(x_sorted), 'r--', lw=2)

plt.tight_layout()
plt.show()

Scatter Plot Example
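np.polyfit fits the trend line by least squares. To put a number on how strong the relationship is, you can also compute Pearson's correlation coefficient with np.corrcoef and print both results on the plot. A minimal sketch reusing the same random data (the text position and formatting are arbitrary choices):

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50)

# Slope and intercept of the least-squares line y = slope * x + intercept
slope, intercept = np.polyfit(x, y, 1)
# Pearson correlation: off-diagonal entry of the 2x2 correlation matrix
r = np.corrcoef(x, y)[0, 1]

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, y, alpha=0.6)
x_line = np.array([x.min(), x.max()])
ax.plot(x_line, slope * x_line + intercept, 'r--')

# Position the text in axes coordinates (0 to 1), near the top left
ax.text(0.05, 0.95, f'y = {slope:.2f}x + {intercept:.2f}\nr = {r:.2f}',
        transform=ax.transAxes, va='top')
plt.show()
```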

Histogram

Histograms show the distribution of a dataset:

python
import matplotlib.pyplot as plt
import numpy as np

# Generate data with normal distribution
data = np.random.normal(100, 15, 1000) # mean=100, std=15, size=1000

# Create histogram
plt.figure(figsize=(10, 6))
plt.hist(data, bins=30, color='lightblue', edgecolor='black')
plt.axvline(data.mean(), color='red', linestyle='dashed', linewidth=1)
plt.title('Histogram of Normal Distribution')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.text(data.mean() + 5, plt.ylim()[1] * 0.9, f'Mean: {data.mean():.1f}',
         color='red')
plt.grid(alpha=0.3)
plt.show()

Histogram Example
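plt.hist also returns the numbers behind the bars: the count in each bin, the bin edges, and the drawn rectangles. Inspecting them is a quick way to check what a histogram is showing. A sketch (a random seed is added here, unlike the example above, so the results are reproducible):

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
data = np.random.normal(100, 15, 1000)

counts, bin_edges, patches = plt.hist(data, bins=30, color='lightblue', edgecolor='black')

# 30 bins are delimited by 31 edges, and every sample lands in exactly one bin
print(len(counts), len(bin_edges))  # 30 31
print(int(counts.sum()))            # 1000

# The bin holding the most samples
peak = np.argmax(counts)
print(f'Most common range: {bin_edges[peak]:.1f} to {bin_edges[peak + 1]:.1f}')
plt.show()
```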

Pie Chart

Pie charts show how a whole is divided into parts, with each slice's share given as a percentage:

python
import matplotlib.pyplot as plt

# Data
labels = ['Product A', 'Product B', 'Product C', 'Product D']
sizes = [15, 30, 45, 10]
explode = (0, 0.1, 0, 0) # explode the 2nd slice

# Create pie chart
plt.figure(figsize=(8, 8))
plt.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%',
        shadow=True, startangle=90)
plt.axis('equal') # Equal aspect ratio ensures pie is drawn as a circle
plt.title('Market Share')
plt.show()

Pie Chart Example
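autopct also accepts a function, which Matplotlib calls with each slice's percentage; the returned string becomes that slice's label. The sketch below uses this to show the underlying value next to the percentage. pct_and_value is just a name chosen for this example.

```python
import matplotlib.pyplot as plt

labels = ['Product A', 'Product B', 'Product C', 'Product D']
sizes = [15, 30, 45, 10]
total = sum(sizes)

def pct_and_value(pct):
    # pct is the slice's share in percent; convert it back to the raw value
    value = pct * total / 100
    return f'{pct:.1f}%\n({value:.0f})'

fig, ax = plt.subplots(figsize=(8, 8))
wedges, texts, autotexts = ax.pie(sizes, labels=labels, autopct=pct_and_value,
                                  startangle=90)
ax.set_title('Market Share')
plt.show()
```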

Subplots

Subplots allow you to create multiple plots in one figure:

python
import matplotlib.pyplot as plt
import numpy as np

# Create some data
x = np.linspace(0, 5, 100)

# Create a figure with 2x2 grid of subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 8))

# Plot on each subplot
axs[0, 0].plot(x, np.sin(x), 'r-')
axs[0, 0].set_title('Sine Function')

axs[0, 1].plot(x, np.cos(x), 'g-')
axs[0, 1].set_title('Cosine Function')

axs[1, 0].plot(x, np.sin(x) * np.cos(x), 'b-')
axs[1, 0].set_title('Sin × Cos')

axs[1, 1].plot(x, np.sin(x) + np.cos(x), 'm-')
axs[1, 1].set_title('Sin + Cos')

# Add a main title
fig.suptitle('Different Trigonometric Functions', fontsize=16)

# Adjust spacing so labels don't overlap; rect reserves the top 5% for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

Subplots Example
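When every panel is drawn the same way, you can loop over the subplots instead of indexing each one. axs.flat iterates over the 2×2 array of Axes row by row, and sharex/sharey give all panels common axis limits so they are easier to compare. A sketch of the figure above written that way:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 100)

# Map each subplot title to the data it should show
curves = {
    'Sine Function': np.sin(x),
    'Cosine Function': np.cos(x),
    'Sin × Cos': np.sin(x) * np.cos(x),
    'Sin + Cos': np.sin(x) + np.cos(x),
}

# sharex/sharey make all four panels use the same axis limits
fig, axs = plt.subplots(2, 2, figsize=(10, 8), sharex=True, sharey=True)

# axs.flat visits the Axes row by row: top left, top right, bottom left, bottom right
for ax, (title, y) in zip(axs.flat, curves.items()):
    ax.plot(x, y)
    ax.set_title(title)

fig.suptitle('Different Trigonometric Functions', fontsize=16)
fig.tight_layout()
plt.show()
```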

Customizing Your Plots

Colors and Styles

Matplotlib offers various color specifications and style options:

python
import matplotlib.pyplot as plt
import numpy as np

# Data
x = np.linspace(0, 10, 100)

# Plot with different colors and styles
plt.figure(figsize=(10, 6))

# Different line styles and colors
plt.plot(x, np.sin(x), 'r-', label='red solid')
plt.plot(x, np.sin(x+1), 'b--', label='blue dashed')
plt.plot(x, np.sin(x+2), 'g-.', label='green dashdot')
plt.plot(x, np.sin(x+3), 'mo:', label='magenta dotted with circles')

# Add grid, legend and labels
plt.grid(True, alpha=0.3)
plt.legend(title='Line Styles')
plt.title('Different Line Colors and Styles')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')

plt.tight_layout()
plt.show()

Colors and Styles Example
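The format strings above are shorthand for keyword arguments. The sketch below draws one line with 'mo:' and another with the equivalent color, marker, and linestyle keywords, then checks that both ended up with the same style. The keyword form is longer but easier to read, and it can set properties that have no format-string shorthand, such as linewidth or alpha.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 20)
fig, ax = plt.subplots()

# Format string: color 'm' (magenta), marker 'o' (circle), line ':' (dotted)
compact, = ax.plot(x, np.sin(x), 'mo:')
# The same style spelled out with keyword arguments
explicit, = ax.plot(x, np.cos(x), color='m', marker='o', linestyle=':')

print(compact.get_linestyle(), explicit.get_linestyle())  # : :
print(compact.get_marker(), explicit.get_marker())        # o o
plt.show()
```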

Adding Annotations

Annotations help highlight important aspects of your data:

python
import matplotlib.pyplot as plt
import numpy as np

# Data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)

# Highlight a specific point
max_x = np.pi/2
max_y = np.sin(max_x)
plt.plot(max_x, max_y, 'ro', markersize=10)

# Add text annotation with arrow
plt.annotate('Maximum value (π/2, 1)',
             xy=(max_x, max_y),
             xytext=(max_x + 1, max_y - 0.3),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12)

# Add other annotations
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=max_x, color='r', linestyle='--', alpha=0.3)

plt.title('Sine Function with Annotations')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Annotations Example
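In the example above the peak's position is known in advance (π/2). With real data you would usually locate it from the array instead. The sketch below does that with np.argmax on a damped sine wave, chosen only because it has a single highest point. arrowstyle='->' is one of Matplotlib's built-in arrow styles, used here as an alternative to the facecolor/shrink settings above.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
# Damped sine wave: illustrative data with one highest point
y = np.sin(x) * np.exp(-x / 5)

# Index of the largest sampled y value, then its coordinates
i = np.argmax(y)
peak_x, peak_y = x[i], y[i]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, 'b-', linewidth=2)
ax.plot(peak_x, peak_y, 'ro', markersize=10)
ann = ax.annotate(f'Maximum ({peak_x:.2f}, {peak_y:.2f})',
                  xy=(peak_x, peak_y),
                  xytext=(peak_x + 1, peak_y - 0.3),
                  arrowprops=dict(arrowstyle='->'))
ax.set_title('Damped Sine Wave with Its Peak Annotated')
plt.show()
```

Because the position comes from the data, the same lines work unchanged if y is replaced by measured values. np.argmax returns the first index of the largest value, so a tie between equal peaks resolves to the leftmost one.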

Real-World Data Science Example

Let's combine these pieces into a more complex, realistic visualization: monthly sales of two products over one year. The numbers below are made-up sample data:

python
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import pandas as pd

# Create sample sales data: one date per month ('MS' = month start;
# the month-end alias 'M' is deprecated in recent pandas versions)
months = pd.date_range(start='2023-01-01', periods=12, freq='MS')
product_a = [12000, 13500, 16000, 15000, 17000, 19500, 21000, 22000, 20500, 18000, 16500, 22500]
product_b = [8000, 8800, 9500, 11000, 12500, 14000, 13800, 13000, 14500, 15500, 16000, 18000]

# Create a DataFrame
sales_data = pd.DataFrame({
    'Month': months,
    'Product A': product_a,
    'Product B': product_b
})

# Calculate rolling average (3-month); the first two rows are NaN and are not drawn
sales_data['Product A (3MA)'] = sales_data['Product A'].rolling(3).mean()
sales_data['Product B (3MA)'] = sales_data['Product B'].rolling(3).mean()

# Create figure and axis
fig, ax = plt.subplots(figsize=(12, 6))

# Plot actual sales with markers
ax.plot(sales_data['Month'], sales_data['Product A'], 'o-', label='Product A', color='blue', alpha=0.7)
ax.plot(sales_data['Month'], sales_data['Product B'], 's-', label='Product B', color='green', alpha=0.7)

# Plot moving averages
ax.plot(sales_data['Month'], sales_data['Product A (3MA)'], '--', label='Product A (3-Month Avg)', color='darkblue')
ax.plot(sales_data['Month'], sales_data['Product B (3MA)'], '--', label='Product B (3-Month Avg)', color='darkgreen')

# Highlight best sales month for Product A
best_month_idx = sales_data['Product A'].idxmax()
best_month = sales_data.loc[best_month_idx]
ax.annotate(f'Peak Sales: ${best_month["Product A"]:,.0f}',
            xy=(best_month['Month'], best_month['Product A']),
            xytext=(best_month['Month'], best_month['Product A'] + 2000),
            arrowprops=dict(facecolor='black', shrink=0.05),
            fontsize=10)

# Set title and labels
ax.set_title('Monthly Sales Comparison (2023)', fontsize=16)
ax.set_xlabel('Month')
ax.set_ylabel('Sales (USD)')

# Format x-axis to show abbreviated month names (Jan, Feb, ...)
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))

# Add grid, legend and thousand separator for y-axis
ax.grid(True, alpha=0.3)
ax.legend()
ax.yaxis.set_major_formatter(mticker.StrMethodFormatter('${x:,.0f}'))

# Add total annual sales as text (position given in axes coordinates, 0 to 1)
total_a = sales_data['Product A'].sum()
total_b = sales_data['Product B'].sum()
ax.text(0.02, 0.95, f'Annual Sales:\nProduct A: ${total_a:,.0f}\nProduct B: ${total_b:,.0f}',
        transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

Real-world Sales Analysis Example
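If your data already lives in a pandas DataFrame, DataFrame.plot can draw several columns in one call onto an Axes you supply, and you can keep customizing that Axes with the methods used above. A condensed sketch with the same sample sales figures (pandas draws with Matplotlib under the hood):

```python
import matplotlib.pyplot as plt
import pandas as pd

sales_data = pd.DataFrame({
    'Month': pd.date_range(start='2023-01-01', periods=12, freq='MS'),
    'Product A': [12000, 13500, 16000, 15000, 17000, 19500, 21000, 22000, 20500, 18000, 16500, 22500],
    'Product B': [8000, 8800, 9500, 11000, 12500, 14000, 13800, 13000, 14500, 15500, 16000, 18000],
})

fig, ax = plt.subplots(figsize=(12, 6))
# One line per listed column, with the Month column on the x-axis
sales_data.plot(x='Month', y=['Product A', 'Product B'], marker='o', ax=ax)

# The returned Axes is an ordinary Matplotlib Axes
ax.set_title('Monthly Sales Comparison (2023)')
ax.set_ylabel('Sales (USD)')
plt.show()
```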

Saving Your Plots

To save a visualization to a file, use plt.savefig(); the file format is chosen from the file extension:

python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(8, 4))
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True)

# Save in different formats
plt.savefig('sine_wave.png', dpi=300) # PNG with high resolution
plt.savefig('sine_wave.pdf') # PDF format (vector graphics)
plt.savefig('sine_wave.svg') # SVG format (vector graphics)
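Two savefig details come up often. First, call it before plt.show(): once the window is closed, pyplot no longer holds that figure, so a later plt.savefig() may write a blank image. Second, bbox_inches='tight' trims surplus whitespace around the plot. The sketch below also saves into an in-memory buffer (io.BytesIO) rather than a file, which can be handy in web apps or tests. format='png' is passed explicitly because there is no file extension to infer it from (otherwise the savefig.format setting, PNG by default, is used).

```python
import io

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, np.sin(x))
ax.set_title('Sine Wave')

# Save before plt.show(); 'tight' trims surplus whitespace around the plot
buffer = io.BytesIO()
fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight')

png_bytes = buffer.getvalue()
print(png_bytes[:8])  # b'\x89PNG\r\n\x1a\n' (the PNG file signature)
plt.show()
```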

Customizing with Matplotlib Styles

Matplotlib comes with predefined styles that can quickly enhance the appearance of your plots:

python
import matplotlib.pyplot as plt
import numpy as np

# List the names of all available styles
print(plt.style.available)

# Use a specific style (Matplotlib 3.6+ name; before 3.6 it was called 'seaborn-darkgrid')
plt.style.use('seaborn-v0_8-darkgrid')

# Create plot with the selected style
x = np.linspace(0, 10, 100)
plt.figure(figsize=(8, 4))
plt.plot(x, np.sin(x), label='Sine')
plt.plot(x, np.cos(x), label='Cosine')
plt.title('Styled Plot')
plt.legend()
plt.show()

# Reset to default style
plt.style.use('default')
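Instead of switching styles globally and resetting afterwards with plt.style.use('default'), plt.style.context() applies a style only inside a with block. A sketch using the built-in 'ggplot' style:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
before = plt.rcParams['axes.facecolor']

# The style applies only to figures created inside this block
with plt.style.context('ggplot'):
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(x, np.sin(x), label='Sine')
    ax.plot(x, np.cos(x), label='Cosine')
    ax.set_title('Styled Plot')
    ax.legend()

# Global settings are restored automatically when the block ends
print(plt.rcParams['axes.facecolor'] == before)  # True
plt.show()
```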

Summary

Matplotlib is a versatile and powerful library for data visualization in Python. In this tutorial, you've learned:

  • How to create basic plots using both the pyplot interface and the object-oriented approach
  • Various types of plots including line, bar, scatter, histogram, and pie charts
  • How to create and customize subplots
  • Ways to style and annotate your visualizations
  • Working with real-world data to create informative visualizations
  • Saving plots in different formats

With these skills, you can now effectively communicate insights from your data science projects through compelling visualizations.

Additional Resources

To further enhance your Matplotlib skills:

  1. Official Documentation: Visit the Matplotlib Documentation for a complete reference
  2. Matplotlib Gallery: Explore the Matplotlib Gallery for inspiration and examples
  3. Matplotlib Cheat Sheets: Download the Matplotlib Cheat Sheet for quick reference

Exercises

Practice your Matplotlib skills with these exercises:

  1. Create a scatter plot of randomly generated data with a color gradient based on a third variable
  2. Make a stacked bar chart showing quarterly sales for different product categories
  3. Create a visualization that shows the correlation matrix of a dataset of your choice
  4. Create a custom visualization that combines multiple plot types (e.g., line and bar)
  5. Recreate a visualization from a newspaper or magazine article using Matplotlib

Remember that visualization is both an art and a science. The best way to improve is through practice and experimentation!


