Python Matplotlib Basics
Introduction
Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. When working with PyTorch for machine learning, data visualization becomes an essential skill for understanding your data and model performance. This tutorial will cover the fundamentals of Matplotlib, providing you with the tools to create effective visualizations for your PyTorch projects.
Why Matplotlib for PyTorch?
It pays to learn data visualization before diving into PyTorch because:
- It helps you understand the distribution and relationships in your training data
- It allows you to visualize model performance metrics (accuracy, loss curves)
- It enables you to interpret and communicate your results effectively
Getting Started with Matplotlib
Installation
First, make sure Matplotlib is installed in your environment:
pip install matplotlib
Importing Matplotlib
The most common way to import Matplotlib is as follows:
import matplotlib.pyplot as plt
import numpy as np # We'll use NumPy for data generation
Basic Plotting
Creating a Simple Line Plot
Let's create our first plot - a simple line graph:
# Generate data
x = np.linspace(0, 10, 100) # 100 points from 0 to 10
y = np.sin(x)
# Create a figure and plot the data
plt.figure(figsize=(8, 4)) # width, height in inches
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()
This example demonstrates the basic workflow:
- Generate or load your data
- Create a figure (optional but recommended to control size)
- Plot your data with the appropriate function
- Add labels and titles
- Display the plot
Multiple Lines on the Same Plot
You can add multiple lines to the same plot, which is useful for comparing different data:
plt.figure(figsize=(10, 5))
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.sin(x) * np.cos(x), label='sin(x)cos(x)')
plt.title('Trigonometric Functions')
plt.xlabel('x')
plt.ylabel('y')
plt.legend() # Add a legend
plt.grid(True)
plt.show()
Understanding Figure and Axes
For more complex visualizations, understanding the difference between figure and axes is important:
# Create a figure and a set of subplots
fig, ax = plt.subplots(figsize=(8, 4))
# Plot on the axes
ax.plot(x, np.sin(x))
ax.set_title('Sine Wave using Object-Oriented Approach')
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.grid(True)
plt.show()
This object-oriented approach gives you more control and is recommended for complex plots.
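To make the distinction concrete, the sketch below puts two Axes inside a single Figure. The variable names (`x_demo`, `ax_left`, `ax_right`) are illustrative choices, not part of the tutorial. Axes-level calls change one plot, while Figure-level calls change the whole canvas.

```python
import matplotlib.pyplot as plt
import numpy as np

x_demo = np.linspace(0, 10, 100)

# One Figure (the whole canvas) holding two Axes (the individual plots)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4), sharex=True)

# Axes-level calls affect a single plot
ax_left.plot(x_demo, np.sin(x_demo))
ax_left.set_title('sin(x)')
ax_right.plot(x_demo, np.cos(x_demo), color='tab:orange')
ax_right.set_title('cos(x)')

# Figure-level calls affect the whole canvas
fig.suptitle('One Figure, Two Axes')
fig.tight_layout()
plt.show()
```

Keeping explicit references to `fig` and each axes object is what makes it possible to style each panel independently later on.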
Common Plot Types for Data Analysis
Scatter Plots
Scatter plots are useful for showing relationships between two variables:
# Generate random data
np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50) # y = x + noise
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.7, s=100)
plt.title('Scatter Plot Example')
plt.xlabel('x-values')
plt.ylabel('y-values')
plt.grid(True)
plt.show()
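For classification data, a common variation is to draw one scatter call per class so each class gets its own color and legend entry. The following is a minimal sketch on synthetic data; the two clusters and the names `class_0` and `class_1` are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic two-class data (made up for illustration)
rng = np.random.default_rng(0)
class_0 = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(50, 2))
class_1 = rng.normal(loc=[2.0, 2.0], scale=0.5, size=(50, 2))

fig, ax = plt.subplots(figsize=(8, 6))
# One scatter call per class: separate default colors and legend entries
ax.scatter(class_0[:, 0], class_0[:, 1], alpha=0.7, label='Class 0')
ax.scatter(class_1[:, 0], class_1[:, 1], alpha=0.7, label='Class 1')
ax.set_title('Two Classes in Feature Space')
ax.set_xlabel('feature 1')
ax.set_ylabel('feature 2')
ax.legend()
ax.grid(True)
plt.show()
```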
Histograms
Histograms help visualize the distribution of a dataset:
# Generate normally distributed data
data = np.random.normal(0, 1, 1000)
plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
plt.title('Histogram of Normal Distribution')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()
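The `hist` call also returns the bin counts and bin edges, which can be handy when you want the numbers behind the picture. Here is a small sketch, assuming freshly generated samples (the names `samples`, `counts`, and `bin_edges` are illustrative) and adding a dashed line at the sample mean:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
samples = rng.normal(0, 1, 1000)

fig, ax = plt.subplots(figsize=(8, 6))
# hist returns the per-bin counts, the bin edges, and the drawn bar patches
counts, bin_edges, patches = ax.hist(samples, bins=30, alpha=0.7,
                                     color='skyblue', edgecolor='black')

# Mark the sample mean with a vertical dashed line
ax.axvline(samples.mean(), color='red', linestyle='--',
           label=f'mean = {samples.mean():.2f}')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend()
plt.show()
```

With 30 bins there are 31 bin edges, and the counts add up to the number of samples.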
Bar Charts
Bar charts are excellent for comparing discrete categories:
categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [25, 40, 30, 55]
plt.figure(figsize=(8, 6))
plt.bar(categories, values, color='salmon', width=0.6)
plt.title('Bar Chart Example')
plt.xlabel('Category')
plt.ylabel('Value')
plt.ylim(0, 60) # Set y-axis limits
plt.grid(True, axis='y') # Grid lines only on y-axis
plt.show()
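When exact values matter, the bars can be labeled directly. The sketch below uses `Axes.bar_label`, which is available in Matplotlib 3.4 and later (an assumption about your installed version); the data repeats the example above.

```python
import matplotlib.pyplot as plt

categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [25, 40, 30, 55]

fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(categories, values, color='salmon', width=0.6)

# Write each bar's value just above it (requires Matplotlib >= 3.4)
value_labels = ax.bar_label(bars, padding=3)

ax.set_ylim(0, 60)  # Leave headroom so the top label fits
ax.set_xlabel('Category')
ax.set_ylabel('Value')
plt.show()
```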
Subplots - Creating Multiple Plots
When analyzing data, you'll often want to compare multiple visualizations side by side:
# Create a figure with 2x2 subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
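# Plot 1: Line plot (on a fresh, evenly spaced grid, since x now holds the random scatter data)
x_line = np.linspace(0, 10, 100)
axes[0, 0].plot(x_line, np.sin(x_line))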
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('sin(x)')
axes[0, 0].grid(True)
# Plot 2: Scatter plot
axes[0, 1].scatter(x, y, alpha=0.7)
axes[0, 1].set_title('Scatter Plot')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('y')
axes[0, 1].grid(True)
# Plot 3: Histogram
axes[1, 0].hist(data, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[1, 0].set_title('Histogram')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].grid(True)
# Plot 4: Bar chart
axes[1, 1].bar(categories, values, color='salmon')
axes[1, 1].set_title('Bar Chart')
axes[1, 1].set_xlabel('Category')
axes[1, 1].set_ylabel('Value')
axes[1, 1].grid(True, axis='y')
# Adjust layout
plt.tight_layout()
plt.show()
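When every panel is styled the same way, looping over `axes.flat` avoids repeating the same calls for each subplot. This is a sketch with illustrative names (`x_grid`, `panels`), not the tutorial's own data:

```python
import matplotlib.pyplot as plt
import numpy as np

x_grid = np.linspace(0, 10, 100)
# (title, y-values) pairs, one per panel
panels = [
    ('sin(x)', np.sin(x_grid)),
    ('cos(x)', np.cos(x_grid)),
    ('sin(2x)', np.sin(2 * x_grid)),
    ('sin(x)cos(x)', np.sin(x_grid) * np.cos(x_grid)),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharex=True)

# axes is a 2x2 array; axes.flat walks through it row by row
for ax, (title, y_values) in zip(axes.flat, panels):
    ax.plot(x_grid, y_values)
    ax.set_title(title)
    ax.grid(True)

fig.tight_layout()
plt.show()
```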
Customizing Plots
Setting Colors, Line Styles, and Markers
Matplotlib offers extensive customization options:
x = np.linspace(0, 10, 100)  # Evenly spaced x values for the line plots
plt.figure(figsize=(10, 6))
# Different line styles, colors, and markers
plt.plot(x, np.sin(x), 'r--', label='sin(x) - red dashed')
plt.plot(x, np.cos(x), 'b-', linewidth=2, label='cos(x) - blue solid')
plt.plot(x, np.sin(x) * np.cos(x), 'g-.',
         marker='o', markersize=5, markevery=10,
         label='sin(x)cos(x) - green dash-dot with circles')
plt.title('Customized Line Styles and Colors')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()
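Format strings such as `'r--'` are shorthand for keyword arguments. The sketch below draws one line each way to show that the resulting style is the same; the names `x_vals`, `line_fmt`, and `line_kw` are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

x_vals = np.linspace(0, 10, 100)

fig, ax = plt.subplots(figsize=(10, 6))
# Shorthand: 'r--' means red color, dashed line
(line_fmt,) = ax.plot(x_vals, np.sin(x_vals), 'r--',
                      label="format string 'r--'")
# Explicit keywords: same style, easier to read and extend
(line_kw,) = ax.plot(x_vals, np.cos(x_vals), color='red', linestyle='--',
                     label="color='red', linestyle='--'")
ax.legend()
ax.grid(True)
plt.show()
```

The keyword form is more verbose, but it accepts any color Matplotlib understands (for example hex codes), which the one-letter shorthand does not.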
Adding Text and Annotations
You can add text and annotations to highlight specific points:
x = np.linspace(0, 10, 100)  # Evenly spaced x values for the line plot
plt.figure(figsize=(10, 5))
plt.plot(x, np.sin(x))
# Add text just below the first local maximum at x = pi/2
plt.text(np.pi / 2, 0.7, 'Local Maximum', fontsize=12, ha='center')
# Add an arrow pointing to the local minimum at x = 3*pi/2
plt.annotate('Local Minimum',
             xy=(3 * np.pi / 2, -1),  # Point to annotate
             xytext=(6.5, -0.5),      # Position of the text
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
plt.title('Sine Wave with Annotations')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()
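Rather than hard-coding coordinates, the points to annotate can be located from the data itself. Here is a minimal sketch that uses `np.argmax` and `np.argmin` to find the highest and lowest sampled points; the finer grid and the names `x_fine`, `y_fine`, `i_max`, and `i_min` are assumptions made for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x_fine = np.linspace(0, 10, 1000)
y_fine = np.sin(x_fine)

# Indices of the largest and smallest sampled values
i_max = np.argmax(y_fine)
i_min = np.argmin(y_fine)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x_fine, y_fine)

# Place each label inside the curve and point the arrow at the extremum
ax.annotate('Maximum', xy=(x_fine[i_max], y_fine[i_max]),
            xytext=(x_fine[i_max], 0.5), ha='center',
            arrowprops=dict(arrowstyle='->'))
ax.annotate('Minimum', xy=(x_fine[i_min], y_fine[i_min]),
            xytext=(x_fine[i_min], -0.5), ha='center',
            arrowprops=dict(arrowstyle='->'))
ax.grid(True)
plt.show()
```

Note that `argmax` and `argmin` return only the first position of the global extreme among the samples, so other peaks of equal height are not marked.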
Real-World Example: Visualizing PyTorch Training Progress
Here's an example of how Matplotlib would be used in a PyTorch project to visualize training and validation loss:
# Sample data (this would come from your PyTorch training loop)
epochs = list(range(1, 21))
train_loss = [0.8, 0.65, 0.55, 0.48, 0.42, 0.38, 0.35, 0.33, 0.31, 0.29,
              0.28, 0.27, 0.26, 0.25, 0.24, 0.24, 0.23, 0.23, 0.22, 0.22]
val_loss = [0.9, 0.7, 0.62, 0.58, 0.54, 0.51, 0.5, 0.49, 0.48, 0.47,
            0.47, 0.46, 0.46, 0.45, 0.45, 0.45, 0.44, 0.44, 0.44, 0.44]
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Adding an annotation to indicate potential overfitting
plt.annotate('Potential Overfitting',
             xy=(15, val_loss[14]),
             xytext=(16, val_loss[14] + 0.1),
             arrowprops=dict(facecolor='black', shrink=0.05))
plt.show()
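In practice the loss lists are filled in while training runs. The sketch below shows one way to collect them in a dictionary and mark the epoch with the lowest validation loss. The function `simulated_epoch` is a hypothetical stand-in that returns made-up numbers; in a real PyTorch loop you would append your own per-epoch loss values (for example from `loss.item()`) instead.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def simulated_epoch(epoch):
    """Hypothetical stand-in for one training epoch; returns made-up losses."""
    train = 0.2 + 0.8 * math.exp(-0.2 * epoch)
    val = 0.4 + 0.8 * math.exp(-0.2 * epoch) + 0.005 * epoch
    return train, val


history = {'train_loss': [], 'val_loss': []}
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    train_value, val_value = simulated_epoch(epoch)
    history['train_loss'].append(train_value)
    history['val_loss'].append(val_value)

epoch_axis = np.arange(1, num_epochs + 1)
# Epoch (1-based) at which the validation loss is lowest
best_epoch = int(epoch_axis[np.argmin(history['val_loss'])])

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_axis, history['train_loss'], 'b-', label='Training Loss')
ax.plot(epoch_axis, history['val_loss'], 'r-', label='Validation Loss')
ax.axvline(best_epoch, color='gray', linestyle=':',
           label=f'Lowest validation loss (epoch {best_epoch})')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()
```

A validation curve that starts rising while the training curve keeps falling is the usual visual cue for overfitting.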
Saving Figures
It's often necessary to save your visualizations for reports or presentations:
x = np.linspace(0, 10, 100)  # Evenly spaced x values for the line plot
plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), label='sin(x)')
plt.title('Sine Wave to be Saved')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.legend()
plt.grid(True)
# Save the figure in different formats
plt.savefig('sine_wave.png', dpi=300, bbox_inches='tight') # PNG format
plt.savefig('sine_wave.pdf', bbox_inches='tight') # PDF format
plt.show()
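The bbox_inches='tight' parameter crops the saved file to a tight bounding box around the figure's contents, so labels and titles are not cut off and excess whitespace is trimmed. Call plt.savefig() before plt.show(), since the figure may be cleared once the plot window is closed.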
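When saving from a script or inside a loop, the object-oriented calls `fig.savefig` and `plt.close(fig)` keep things explicit and release the figure's memory afterwards. This is a minimal sketch; writing to a temporary directory and the names `out_dir` and `out_path` are choices made for the example.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Temporary output directory so the example does not clutter your project
out_dir = Path(tempfile.mkdtemp())
out_path = out_dir / 'sine_wave.png'

x_save = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x_save, np.sin(x_save), label='sin(x)')
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.legend()

# Save through the Figure object, then close it to free memory
fig.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close(fig)

print(f'Saved figure to {out_path}')
```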
Summary
In this tutorial, we've covered the basics of Matplotlib for data visualization, which is crucial for working with PyTorch and machine learning:
- Creating basic plots (line, scatter, histogram, bar charts)
- Working with figures and axes
- Creating multiple subplots
- Customizing plots with colors, styles, and annotations
- Real-world application for visualizing training progress
- Saving figures for reports and presentations
Matplotlib provides a versatile foundation for visualizing your data and model results, making it an essential tool in your PyTorch journey.
Additional Resources and Exercises
Resources
Exercises
- Basic Plot: Create a line plot showing a quadratic function f(x) = x² for x values from -10 to 10.
- Data Comparison: Generate two normal distributions with different means and standard deviations. Plot their histograms on the same figure to compare them.
- PyTorch Application: Create a visualization that shows training accuracy and validation accuracy across 30 epochs (you can make up sample data).
- Advanced Challenge: Create a 2x2 subplot showing different aspects of a sine wave: the wave itself, its derivative, a scatter plot of select points, and a histogram of values.
By mastering these Matplotlib basics, you'll be well-equipped to visualize and understand your PyTorch models and data.