Python Matplotlib
Introduction
Data visualization is a crucial skill in data science because it lets you communicate insights effectively. Matplotlib is one of the most popular and powerful visualization libraries in Python. It provides a comprehensive set of tools for creating static, animated, and interactive visualizations.
In this tutorial, you'll learn how to use Matplotlib to create various types of plots and customize them to effectively communicate your data. Whether you're analyzing trends, comparing categories, or exploring relationships, Matplotlib gives you the flexibility to create visualizations tailored to your needs.
Getting Started with Matplotlib
Installation
Before we begin, make sure you have Matplotlib installed. You can install it using pip:
pip install matplotlib
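To confirm the installation worked, a quick check is to import the package and print its version. This is only a minimal sketch; the version string you see depends on your environment.

```python
import matplotlib

# Importing without an error means the package is installed;
# the version string tells you which release you have
version = matplotlib.__version__
print(version)
```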
Basic Structure of Matplotlib
Matplotlib is built on a hierarchy of objects. The key components are:
- Figure: The overall window or page where everything is drawn
- Axes: An individual plot area inside the figure, holding the data, title, axis labels, etc. A figure can contain one or more Axes.
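The relationship between the two objects can be inspected directly. The following is a small sketch, assuming a figure with two side-by-side plots (the layout is chosen purely for illustration): the Figure keeps a list of its Axes, and each Axes points back to its Figure.

```python
import matplotlib.pyplot as plt

# One Figure containing two Axes, side by side
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

# The Figure keeps a list of the Axes it contains
num_axes = len(fig.axes)
print(num_axes)

# Each Axes knows which Figure it belongs to
same_figure = all(ax.figure is fig for ax in axs)
print(same_figure)

plt.close(fig)  # nothing is displayed here, so release the figure
```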
Let's create your first plot:
import matplotlib.pyplot as plt
import numpy as np
# Create some data
x = np.linspace(0, 10, 100) # 100 points from 0 to 10
y = np.sin(x)
# Create a figure and axis
plt.figure(figsize=(8, 4)) # Width=8, height=4 inches
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True)
plt.show()
This code produces a single sine curve over the range 0 to 10, with a title, axis labels, and a grid.
Two Ways to Use Matplotlib
Matplotlib provides two main approaches for creating visualizations:
1. MATLAB-style Interface (Pyplot)
This approach is simpler and more convenient for basic plots:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
plt.title('Sine and Cosine Waves')
plt.legend(['sin(x)', 'cos(x)'])
plt.show()
2. Object-Oriented Interface
This approach gives you more control and is better for complex plots:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Create figure and axis objects
fig, ax = plt.subplots(figsize=(8, 4))
# Use the axis object to plot data
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
# Customize using the axis object
ax.set_title('Sine and Cosine Waves')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.legend()
ax.grid(True)
plt.show()
The output is similar to the previous example, but the object-oriented interface gives you more flexibility for customization.
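Both interfaces operate on the same underlying objects. As a rough illustration, assuming a freshly created figure, the sketch below shows that pyplot functions act on the "current" Axes, which `plt.gca()` returns, so a call such as `plt.title(...)` changes the same object that `ax.set_title(...)` would.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot tracks a "current" Figure and Axes; right after plt.subplots()
# these are the objects that were just created
current_is_same = plt.gca() is ax and plt.gcf() is fig

# A pyplot call therefore modifies the same Axes object
plt.title('Set through pyplot')
title_seen_by_ax = ax.get_title()
print(current_is_same, title_seen_by_ax)

plt.close(fig)
```

This is why mixing the two styles works, although sticking to one style per script tends to be easier to read.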
Common Plot Types
Line Plot
We've already seen line plots in the examples above. They're great for showing trends over time.
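For a trend over time specifically, a line plot might look like the sketch below. The yearly values are made up for illustration and do not come from any real dataset.

```python
import matplotlib.pyplot as plt

# Hypothetical yearly measurements, invented for this example
years = [2018, 2019, 2020, 2021, 2022, 2023]
values = [3.1, 3.4, 3.0, 3.8, 4.2, 4.6]

fig, ax = plt.subplots(figsize=(8, 4))
(line,) = ax.plot(years, values, marker='o')  # markers show the individual data points
ax.set_title('Trend Over Time (illustrative data)')
ax.set_xlabel('Year')
ax.set_ylabel('Value')
ax.set_xticks(years)  # one tick per year instead of fractional ticks
ax.grid(True, alpha=0.3)
# plt.show()  # uncomment to display the figure
```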
Bar Plot
Bar plots are useful for comparing quantities across categories:
import matplotlib.pyplot as plt
import numpy as np
# Data
categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [15, 34, 23, 17]
# Creating bar plot
fig, ax = plt.subplots(figsize=(8, 4))
bars = ax.bar(categories, values, color='skyblue')
# Adding value labels on top of bars
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
            f'{height}', ha='center', va='bottom')
ax.set_title('Bar Plot Example')
ax.set_xlabel('Categories')
ax.set_ylabel('Values')
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()
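As an alternative to positioning each label by hand, Matplotlib 3.4 and later provide `Axes.bar_label`, which labels every bar in a container at once. The sketch below reuses the same categories and values and assumes a recent enough Matplotlib version.

```python
import matplotlib.pyplot as plt

categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [15, 34, 23, 17]

fig, ax = plt.subplots(figsize=(8, 4))
bars = ax.bar(categories, values, color='skyblue')

# bar_label adds one text label per bar; padding is the gap in points
labels = ax.bar_label(bars, padding=3)
label_texts = [label.get_text() for label in labels]
print(label_texts)

ax.set_title('Bar Plot Example')
# plt.show()  # uncomment to display the figure
```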
Scatter Plot
Scatter plots help visualize relationships between two variables:
import matplotlib.pyplot as plt
import numpy as np
# Generate random data
np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50) # y is correlated with x plus some noise
# Create scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(x, y, c='purple', alpha=0.6, edgecolors='black')
plt.title('Scatter Plot Example')
plt.xlabel('X Variable')
plt.ylabel('Y Variable')
plt.grid(True, alpha=0.3)
# Add a trend line (sort x so the dashed line is drawn once, left to right)
z = np.polyfit(x, y, 1)
p = np.poly1d(z)
x_sorted = np.sort(x)
plt.plot(x_sorted, p(x_sorted), 'r--', lw=2)
plt.tight_layout()
plt.show()
Histogram
Histograms show the distribution of a dataset:
import matplotlib.pyplot as plt
import numpy as np
# Generate data with normal distribution
data = np.random.normal(100, 15, 1000) # mean=100, std=15, size=1000
# Create histogram
plt.figure(figsize=(10, 6))
plt.hist(data, bins=30, color='lightblue', edgecolor='black')
plt.axvline(data.mean(), color='red', linestyle='dashed', linewidth=1)
plt.title('Histogram of Normal Distribution')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.text(data.mean()+5, plt.ylim()[1]*0.9, f'Mean: {data.mean():.1f}',
         color='red')
plt.grid(alpha=0.3)
plt.show()
Pie Chart
Pie charts show each category's share of a whole:
import matplotlib.pyplot as plt
# Data
labels = ['Product A', 'Product B', 'Product C', 'Product D']
sizes = [15, 30, 45, 10]
explode = (0, 0.1, 0, 0) # explode the 2nd slice
# Create pie chart
plt.figure(figsize=(8, 8))
plt.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%',
        shadow=True, startangle=90)
plt.axis('equal') # Equal aspect ratio ensures pie is drawn as a circle
plt.title('Market Share')
plt.show()
Subplots
Subplots allow you to create multiple plots in one figure:
import matplotlib.pyplot as plt
import numpy as np
# Create some data
x = np.linspace(0, 5, 100)
# Create a figure with 2x2 grid of subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
# Plot on each subplot
axs[0, 0].plot(x, np.sin(x), 'r-')
axs[0, 0].set_title('Sine Function')
axs[0, 1].plot(x, np.cos(x), 'g-')
axs[0, 1].set_title('Cosine Function')
axs[1, 0].plot(x, np.sin(x) * np.cos(x), 'b-')
axs[1, 0].set_title('Sin × Cos')
axs[1, 1].plot(x, np.sin(x) + np.cos(x), 'm-')
axs[1, 1].set_title('Sin + Cos')
# Add a main title
fig.suptitle('Different Trigonometric Functions', fontsize=16)
# Add space between subplots
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
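When every panel is built the same way, looping over the Axes array can replace the repeated indexing. The sketch below is one possible variation on the grid above; using `sharex` and `sharey` is an added assumption that suits these curves because they cover the same ranges.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 100)

# Panel title -> y-values, in the order the panels should appear
curves = {
    'Sine Function': np.sin(x),
    'Cosine Function': np.cos(x),
    'Sin × Cos': np.sin(x) * np.cos(x),
    'Sin + Cos': np.sin(x) + np.cos(x),
}

# sharex/sharey give every panel the same axis limits
fig, axs = plt.subplots(2, 2, figsize=(10, 8), sharex=True, sharey=True)

# axs is a 2x2 NumPy array; axs.flat walks through it row by row
for ax, (title, y) in zip(axs.flat, curves.items()):
    ax.plot(x, y)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

fig.suptitle('Different Trigonometric Functions', fontsize=16)
fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the main title
panel_titles = [ax.get_title() for ax in axs.flat]
# plt.show()  # uncomment to display the figure
```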
Customizing Your Plots
Colors and Styles
Matplotlib offers various color specifications and style options:
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
# Plot with different colors and styles
plt.figure(figsize=(10, 6))
# Different line styles and colors
plt.plot(x, np.sin(x), 'r-', label='red solid')
plt.plot(x, np.sin(x+1), 'b--', label='blue dashed')
plt.plot(x, np.sin(x+2), 'g-.', label='green dashdot')
plt.plot(x, np.sin(x+3), 'mo:', label='magenta dotted with circles')
# Add grid, legend and labels
plt.grid(True, alpha=0.3)
plt.legend(title='Line Styles')
plt.title('Different Line Colors and Styles')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.tight_layout()
plt.show()
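The format strings above use single-letter color codes, but the same color can be written in several ways. As a brief sketch, the snippet below uses `matplotlib.colors.to_rgba` to show that a color name, a single-letter code, a hex string, and an RGB tuple can all describe the same red.

```python
from matplotlib import colors

# Four spellings of the same color: name, single-letter code,
# hex string, and an RGB tuple with components between 0 and 1
specs = ['red', 'r', '#ff0000', (1.0, 0.0, 0.0)]

# to_rgba converts any valid specification to an (R, G, B, A) tuple
rgba_values = [colors.to_rgba(spec) for spec in specs]
print(rgba_values)

# The optional alpha argument sets the opacity
half_transparent = colors.to_rgba('red', alpha=0.5)
print(half_transparent)
```

Any of these forms can be passed to the `color` argument of plotting functions such as `plt.plot`.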
Adding Annotations
Annotations help highlight important aspects of your data:
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
# Highlight a specific point
max_x = np.pi/2
max_y = np.sin(max_x)
plt.plot(max_x, max_y, 'ro', markersize=10)
# Add text annotation with arrow
plt.annotate('Maximum value (π/2, 1)',
             xy=(max_x, max_y),
             xytext=(max_x+1, max_y-0.3),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12)
# Add other annotations
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=max_x, color='r', linestyle='--', alpha=0.3)
plt.title('Sine Function with Annotations')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Real-World Data Science Example
Let's create a more complex visualization of the kind you might build in a real project. We'll use a small sample dataset of monthly sales:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Create sample sales data (month-start dates; 'MS' avoids the 'M' alias deprecated in recent pandas)
months = pd.date_range(start='2023-01-01', periods=12, freq='MS')
product_a = [12000, 13500, 16000, 15000, 17000, 19500, 21000, 22000, 20500, 18000, 16500, 22500]
product_b = [8000, 8800, 9500, 11000, 12500, 14000, 13800, 13000, 14500, 15500, 16000, 18000]
# Create a DataFrame
sales_data = pd.DataFrame({
    'Month': months,
    'Product A': product_a,
    'Product B': product_b
})
# Calculate rolling average (3-month)
sales_data['Product A (3MA)'] = sales_data['Product A'].rolling(3).mean()
sales_data['Product B (3MA)'] = sales_data['Product B'].rolling(3).mean()
# Create figure and axis
fig, ax = plt.subplots(figsize=(12, 6))
# Plot actual sales with markers
ax.plot(sales_data['Month'], sales_data['Product A'], 'o-', label='Product A', color='blue', alpha=0.7)
ax.plot(sales_data['Month'], sales_data['Product B'], 's-', label='Product B', color='green', alpha=0.7)
# Plot moving averages
ax.plot(sales_data['Month'], sales_data['Product A (3MA)'], '--', label='Product A (3-Month Avg)', color='darkblue')
ax.plot(sales_data['Month'], sales_data['Product B (3MA)'], '--', label='Product B (3-Month Avg)', color='darkgreen')
# Highlight best sales month for Product A
best_month_idx = sales_data['Product A'].idxmax()
best_month = sales_data.loc[best_month_idx]
ax.annotate(f'Peak Sales: ${best_month["Product A"]:,.0f}',
            xy=(best_month['Month'], best_month['Product A']),
            xytext=(best_month['Month'], best_month['Product A'] + 2000),
            arrowprops=dict(facecolor='black', shrink=0.05),
            fontsize=10)
# Set title and labels
ax.set_title('Monthly Sales Comparison (2023)', fontsize=16)
ax.set_xlabel('Month')
ax.set_ylabel('Sales (USD)')
# Format x-axis to show month names
ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%b'))
# Add grid, legend and thousand separator for y-axis
ax.grid(True, alpha=0.3)
ax.legend()
ax.get_yaxis().set_major_formatter(plt.matplotlib.ticker.StrMethodFormatter('${x:,.0f}'))
# Add total annual sales as text
total_a = sales_data['Product A'].sum()
total_b = sales_data['Product B'].sum()
ax.text(0.02, 0.95, f'Annual Sales:\nProduct A: ${total_a:,.0f}\nProduct B: ${total_b:,.0f}',
        transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.8))
plt.tight_layout()
plt.show()
Saving Your Plots
To save your visualizations as files:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.figure(figsize=(8, 4))
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True)
# Save in different formats
plt.savefig('sine_wave.png', dpi=300) # PNG with high resolution
plt.savefig('sine_wave.pdf') # PDF format (vector graphics)
plt.savefig('sine_wave.svg') # SVG format (vector graphics)
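With the object-oriented interface, the equivalent call is `fig.savefig`. The sketch below is illustrative rather than prescriptive: it writes to a temporary directory (an assumption made so the example leaves no files behind) and uses `bbox_inches='tight'` to trim surplus whitespace around the plot.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, np.sin(x))
ax.set_title('Sine Wave')

# Write into a temporary directory so the sketch leaves no files behind
with tempfile.TemporaryDirectory() as tmp_dir:
    out_path = os.path.join(tmp_dir, 'sine_wave.png')
    # bbox_inches='tight' crops the image to the drawn content
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    file_size = os.path.getsize(out_path)

plt.close(fig)  # release the figure's memory once it has been saved
print(file_size > 0)
```

When using `plt.savefig`, call it before `plt.show()`, since pyplot may no longer have a current figure to save once the window has been closed.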
Customizing with Matplotlib Styles
Matplotlib comes with predefined styles that can quickly enhance the appearance of your plots:
import matplotlib.pyplot as plt
import numpy as np
# Available styles
print(plt.style.available)
# Use a specific style (Matplotlib 3.6 renamed the seaborn styles; older versions call this 'seaborn-darkgrid')
plt.style.use('seaborn-v0_8-darkgrid')
# Create plot with the selected style
x = np.linspace(0, 10, 100)
plt.figure(figsize=(8, 4))
plt.plot(x, np.sin(x), label='Sine')
plt.plot(x, np.cos(x), label='Cosine')
plt.title('Styled Plot')
plt.legend()
plt.show()
# Reset to default style
plt.style.use('default')
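If you only want a style for a single plot, `plt.style.context` applies it temporarily, so no manual reset is needed. The sketch below assumes the built-in `'ggplot'` style and checks one setting, `axes.grid`, before, inside, and after the block to show that the change does not persist.

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

grid_before = mpl.rcParams['axes.grid']

# The style applies only inside the with-block
with plt.style.context('ggplot'):
    grid_inside = mpl.rcParams['axes.grid']  # the ggplot style turns the grid on
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(x, np.sin(x), label='Sine')
    ax.plot(x, np.cos(x), label='Cosine')
    ax.set_title('Temporarily Styled Plot')
    ax.legend()

# Back outside, the previous settings are restored automatically
grid_after = mpl.rcParams['axes.grid']
print(grid_inside, grid_after == grid_before)
# plt.show()  # uncomment to display the figure
```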
Summary
Matplotlib is a versatile and powerful library for data visualization in Python. In this tutorial, you've learned:
- How to create basic plots using both the pyplot interface and the object-oriented approach
- Various types of plots including line, bar, scatter, histogram, and pie charts
- How to create and customize subplots
- Ways to style and annotate your visualizations
- Working with real-world data to create informative visualizations
- Saving plots in different formats
With these skills, you can now effectively communicate insights from your data science projects through compelling visualizations.
Additional Resources
To further enhance your Matplotlib skills:
- Official Documentation: Visit the Matplotlib Documentation for a complete reference
- Matplotlib Gallery: Explore the Matplotlib Gallery for inspiration and examples
- Matplotlib Cheat Sheets: Download the Matplotlib Cheat Sheet for quick reference
Exercises
Practice your Matplotlib skills with these exercises:
- Create a scatter plot of randomly generated data with a color gradient based on a third variable
- Make a stacked bar chart showing quarterly sales for different product categories
- Create a visualization that shows the correlation matrix of a dataset of your choice
- Create a custom visualization that combines multiple plot types (e.g., line and bar)
- Recreate a visualization from a newspaper or magazine article using Matplotlib
Remember that visualization is both an art and a science. The best way to improve is through practice and experimentation!