
Python Matplotlib Basics

Introduction

Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. When working with PyTorch for machine learning, data visualization becomes an essential skill for understanding your data and model performance. This tutorial will cover the fundamentals of Matplotlib, providing you with the tools to create effective visualizations for your PyTorch projects.

Why Matplotlib for PyTorch?

Before diving into PyTorch, understanding how to visualize your data is crucial because:

  • It helps you understand the distribution and relationships in your training data
  • It allows you to visualize model performance metrics (accuracy, loss curves)
  • It enables you to interpret and communicate your results effectively

Getting Started with Matplotlib

Installation

First, make sure Matplotlib is installed in your environment:

bash
pip install matplotlib
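
To confirm the installation worked, a quick check like the following can be run in the same environment. It only prints the installed version and the backend Matplotlib selected; the exact values will differ from machine to machine.

```python
# Quick sanity check that Matplotlib is importable in this environment
import matplotlib

print("Matplotlib version:", matplotlib.__version__)
print("Active backend:", matplotlib.get_backend())
```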

Importing Matplotlib

The most common way to import Matplotlib is as follows:

python
import matplotlib.pyplot as plt
import numpy as np # We'll use NumPy for data generation

Basic Plotting

Creating a Simple Line Plot

Let's create our first plot - a simple line graph:

python
# Generate data
x = np.linspace(0, 10, 100) # 100 points from 0 to 10
y = np.sin(x)

# Create a figure, then plot on it
plt.figure(figsize=(8, 4)) # width, height in inches
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()

Output:
Sine Wave Plot

This example demonstrates the basic workflow:

  1. Generate or load your data
  2. Create a figure (optional but recommended to control size)
  3. Plot your data with the appropriate function
  4. Add labels and titles
  5. Display the plot

Multiple Lines on the Same Plot

You can add multiple lines to the same plot, which is useful for comparing different data:

python
plt.figure(figsize=(10, 5))
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.sin(x) * np.cos(x), label='sin(x)cos(x)')
plt.title('Trigonometric Functions')
plt.xlabel('x')
plt.ylabel('y')
plt.legend() # Add a legend
plt.grid(True)
plt.show()

Output:
Multiple Line Plot
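
When the number of curves grows, repeating plt.plot() by hand gets tedious. A minimal sketch of the same comparison driven by a loop is shown here; the `curves` dictionary is a hypothetical helper introduced only for this illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Hypothetical mapping of legend label -> y-values; extend it as needed
curves = {
    'sin(x)': np.sin(x),
    'cos(x)': np.cos(x),
    'sin(x)cos(x)': np.sin(x) * np.cos(x),
}

plt.figure(figsize=(10, 5))
for label, y_values in curves.items():
    plt.plot(x, y_values, label=label)  # each call adds one line to the same plot

plt.title('Trigonometric Functions')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()
```

Matplotlib cycles through its default colors automatically, so each line stays distinguishable without extra arguments.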

Understanding Figure and Axes

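For more complex visualizations, it helps to distinguish the two objects Matplotlib works with. A Figure is the whole canvas (the window or image file), while an Axes is a single plotting area inside it, with its own data, title, and axis labels. plt.subplots() returns both so you can call methods on them directly:

python
# Create a figure containing a single Axes
fig, ax = plt.subplots(figsize=(8, 4))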

# Plot on the axes
ax.plot(x, np.sin(x))
ax.set_title('Sine Wave using Object-Oriented Approach')
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.grid(True)

plt.show()

This object-oriented approach is recommended for complex plots because every call names the Axes it acts on, instead of relying on whichever plot happens to be "current".
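
The distinction becomes visible once a figure holds more than one plotting area. The following sketch places two Axes in one Figure; the variable names `ax_left` and `ax_right` are chosen here for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# One Figure (the whole canvas) holding two Axes (individual plotting areas)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4))

# Each Axes keeps its own data, title, and labels
ax_left.plot(x, np.sin(x))
ax_left.set_title('sin(x)')

ax_right.plot(x, np.cos(x), color='tab:orange')
ax_right.set_title('cos(x)')

# Figure-level settings apply to the canvas as a whole
fig.suptitle('One Figure, Two Axes')
fig.tight_layout()
plt.show()
```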

Common Plot Types for Data Analysis

Scatter Plots

Scatter plots are useful for showing relationships between two variables:

python
# Generate random data
np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50) # y = x + noise

plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.7, s=100)
plt.title('Scatter Plot Example')
plt.xlabel('x-values')
plt.ylabel('y-values')
plt.grid(True)
plt.show()

Output:
Scatter Plot
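
A scatter plot shows a relationship visually; a number and a trend line can make it easier to judge. The sketch below, which is one possible approach rather than the only one, computes the Pearson correlation with np.corrcoef and overlays a least-squares line from np.polyfit on the same data.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(50)
y = x + np.random.normal(0, 0.2, 50)  # y = x + noise

# Pearson correlation coefficient between x and y
r = np.corrcoef(x, y)[0, 1]

# Least-squares straight line through the points (degree-1 polynomial)
slope, intercept = np.polyfit(x, y, 1)
x_fit = np.array([x.min(), x.max()])

plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.7, s=100)
plt.plot(x_fit, slope * x_fit + intercept, 'r-', label='Least-squares fit')
plt.title(f'Scatter Plot with Fit (r = {r:.2f})')
plt.xlabel('x-values')
plt.ylabel('y-values')
plt.legend()
plt.grid(True)
plt.show()
```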

Histograms

Histograms help visualize the distribution of a dataset:

python
# Generate normally distributed data
data = np.random.normal(0, 1, 1000)

plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
plt.title('Histogram of Normal Distribution')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

Output:
Histogram
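
To see what a histogram actually computes, it can help to inspect the values plt.hist() returns: the count in each bin, the bin edges, and the drawn bars. This is a small sketch with an arbitrarily chosen seed.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)  # seed chosen arbitrarily for reproducibility
data = np.random.normal(0, 1, 1000)

plt.figure(figsize=(8, 6))
counts, bin_edges, patches = plt.hist(data, bins=30, color='skyblue',
                                      edgecolor='black')

print(len(counts))        # one count per bin
print(len(bin_edges))     # one more edge than there are bins
print(int(counts.sum()))  # every sample falls into exactly one bin
plt.show()
```

With 30 bins there are 31 edges, and the counts add up to the 1000 samples that were generated.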

Bar Charts

Bar charts are excellent for comparing discrete categories:

python
categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [25, 40, 30, 55]

plt.figure(figsize=(8, 6))
plt.bar(categories, values, color='salmon', width=0.6)
plt.title('Bar Chart Example')
plt.xlabel('Category')
plt.ylabel('Value')
plt.ylim(0, 60) # Set y-axis limits
plt.grid(True, axis='y') # Grid lines only on y-axis
plt.show()

Output:
Bar Chart
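
Bar heights are easier to read when each bar carries its value. One way to do this, assuming Matplotlib 3.4 or newer, is Axes.bar_label; the sketch below reuses the categories and values from the example above.

```python
import matplotlib.pyplot as plt

categories = ['Category A', 'Category B', 'Category C', 'Category D']
values = [25, 40, 30, 55]

fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(categories, values, color='salmon', width=0.6)

# Write each bar's height just above it (requires Matplotlib >= 3.4)
labels = ax.bar_label(bars, padding=3)

ax.set_title('Bar Chart with Value Labels')
ax.set_xlabel('Category')
ax.set_ylabel('Value')
ax.set_ylim(0, 60)
ax.grid(True, axis='y')
plt.show()
```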

Subplots - Creating Multiple Plots

When analyzing data, you'll often want to compare multiple visualizations side by side:

python
# Create a figure with 2x2 subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

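# Plot 1: Line plot (uses fresh, evenly spaced values because x now holds the random scatter data)
x_line = np.linspace(0, 10, 100)
axes[0, 0].plot(x_line, np.sin(x_line))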
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('sin(x)')
axes[0, 0].grid(True)

# Plot 2: Scatter plot
axes[0, 1].scatter(x, y, alpha=0.7)
axes[0, 1].set_title('Scatter Plot')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('y')
axes[0, 1].grid(True)

# Plot 3: Histogram
axes[1, 0].hist(data, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[1, 0].set_title('Histogram')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].grid(True)

# Plot 4: Bar chart
axes[1, 1].bar(categories, values, color='salmon')
axes[1, 1].set_title('Bar Chart')
axes[1, 1].set_xlabel('Category')
axes[1, 1].set_ylabel('Value')
axes[1, 1].grid(True, axis='y')

# Adjust layout
plt.tight_layout()
plt.show()

Output:
Subplots
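
When every panel gets the same treatment, indexing axes[row, col] repeatedly can be replaced by a loop. The sketch below iterates over `axes.flat`; the `panels` dictionary is a hypothetical helper for this illustration, and sharing the x and y axes is an optional choice.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Hypothetical panel contents: title -> y-values
panels = {
    'sin(x)': np.sin(x),
    'cos(x)': np.cos(x),
    'sin(2x)': np.sin(2 * x),
    'cos(2x)': np.cos(2 * x),
}

fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharex=True, sharey=True)

# axes is a 2x2 NumPy array; .flat iterates over it row by row
for ax, (title, y_values) in zip(axes.flat, panels.items()):
    ax.plot(x, y_values)
    ax.set_title(title)
    ax.grid(True)

fig.tight_layout()
plt.show()
```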

Customizing Plots

Setting Colors, Line Styles, and Markers

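Matplotlib offers extensive customization options. A short format string such as 'r--' sets the color and line style in one go, while keyword arguments such as linewidth, marker, and markevery give finer control: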

python
# Recreate evenly spaced x values (x was overwritten with random data in the scatter example)
x = np.linspace(0, 10, 100)

plt.figure(figsize=(10, 6))

# Different line styles, colors, and markers
plt.plot(x, np.sin(x), 'r--', label='sin(x) - red dashed')
plt.plot(x, np.cos(x), 'b-', linewidth=2, label='cos(x) - blue solid')
plt.plot(x, np.sin(x) * np.cos(x), 'g-.',
         marker='o', markersize=5, markevery=10,
         label='sin(x)cos(x) - green dash-dot with circles')

plt.title('Customized Line Styles and Colors')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()

Output:
Customized Plot
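
A format string is shorthand for separate keyword arguments. The sketch below draws one line each way and checks that both end up with the same style; the names `short_line` and `long_line` are illustrative.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(10, 6))

# Format string: 'r' = red, '--' = dashed
(short_line,) = ax.plot(x, np.sin(x), 'r--', label="'r--'")

# The same style spelled out with keyword arguments
(long_line,) = ax.plot(x, np.cos(x), color='red', linestyle='--',
                       label="color='red', linestyle='--'")

same_style = short_line.get_linestyle() == long_line.get_linestyle()
same_color = to_rgba(short_line.get_color()) == to_rgba(long_line.get_color())
print(same_style, same_color)

ax.legend()
plt.show()
```

Keyword arguments are longer to type but accept any color Matplotlib understands, whereas format strings are limited to the single-letter color codes.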

Adding Text and Annotations

You can add text and annotations to highlight specific points:

python
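# Recreate evenly spaced x values (x was overwritten with random data in the scatter example)
x = np.linspace(0, 10, 100)

plt.figure(figsize=(10, 5))
plt.plot(x, np.sin(x))

# Add text below a local maximum of sin(x), which occurs at x = pi/2
plt.text(np.pi / 2, 0.5, 'Local Maximum', fontsize=12, ha='center')

# Add an arrow pointing to a local minimum, which occurs at x = 3*pi/2
plt.annotate('Local Minimum',
             xy=(3 * np.pi / 2, -1), # Point to annotate
             xytext=(6, -0.5), # Position of the text
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))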

plt.title('Sine Wave with Annotations')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()

Output:
Annotated Plot
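
Annotation coordinates do not have to be typed in by hand. As a sketch, the highest sampled point can be located with np.argmax and annotated from the data itself; the text offset used here is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Index of the largest sampled value, and its coordinates
peak_index = np.argmax(y)
x_peak, y_peak = x[peak_index], y[peak_index]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x, y)
ax.annotate(f'Max: ({x_peak:.2f}, {y_peak:.2f})',
            xy=(x_peak, y_peak),                   # point to annotate
            xytext=(x_peak - 3, y_peak - 0.4),     # text position (arbitrary offset)
            arrowprops=dict(arrowstyle='->'))
ax.set_title('Sine Wave with a Computed Annotation')
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.grid(True)
plt.show()
```

Because the point comes from the sampled data, the label stays correct if the function or the x range changes.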

Real-World Example: Visualizing PyTorch Training Progress

Here's an example of how Matplotlib would be used in a PyTorch project to visualize training and validation loss:

python
# Sample data (this would come from your PyTorch training loop)
epochs = list(range(1, 21))
train_loss = [0.8, 0.65, 0.55, 0.48, 0.42, 0.38, 0.35, 0.33, 0.31, 0.29,
              0.28, 0.27, 0.26, 0.25, 0.24, 0.24, 0.23, 0.23, 0.22, 0.22]
val_loss = [0.9, 0.7, 0.62, 0.58, 0.54, 0.51, 0.5, 0.49, 0.48, 0.47,
            0.47, 0.46, 0.46, 0.45, 0.45, 0.45, 0.44, 0.44, 0.44, 0.44]

plt.figure(figsize=(10, 6))
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Adding an annotation to indicate potential overfitting
plt.annotate('Potential Overfitting',
             xy=(15, val_loss[14]),
             xytext=(16, val_loss[14] + 0.1),
             arrowprops=dict(facecolor='black', shrink=0.05))

plt.show()

Output:
PyTorch Training Visualization
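
Rather than picking the annotated epoch by eye, the curves themselves can suggest where to look. The sketch below uses the same sample data, plots the gap between validation and training loss, and marks the first epoch with the lowest validation loss. Treat a widening gap as a rough heuristic for overfitting, not a definitive test.

```python
import matplotlib.pyplot as plt
import numpy as np

# Same sample data as above (in practice, collected during training)
epochs = np.arange(1, 21)
train_loss = np.array([0.8, 0.65, 0.55, 0.48, 0.42, 0.38, 0.35, 0.33, 0.31, 0.29,
                       0.28, 0.27, 0.26, 0.25, 0.24, 0.24, 0.23, 0.23, 0.22, 0.22])
val_loss = np.array([0.9, 0.7, 0.62, 0.58, 0.54, 0.51, 0.5, 0.49, 0.48, 0.47,
                     0.47, 0.46, 0.46, 0.45, 0.45, 0.45, 0.44, 0.44, 0.44, 0.44])

gap = val_loss - train_loss                      # generalization gap per epoch
best_epoch = int(epochs[np.argmin(val_loss)])    # first epoch reaching the minimum

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, train_loss, 'b-', label='Training Loss')
ax.plot(epochs, val_loss, 'r-', label='Validation Loss')
ax.plot(epochs, gap, 'k--', label='Gap (validation - training)')
ax.axvline(best_epoch, color='gray', linestyle=':',
           label=f'Lowest validation loss (epoch {best_epoch})')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()
```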

Saving Figures

It's often necessary to save your visualizations for reports or presentations:

python
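# Recreate evenly spaced x values (x was overwritten with random data in the scatter example)
x = np.linspace(0, 10, 100)

plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), label='sin(x)')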
plt.title('Sine Wave to be Saved')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.legend()
plt.grid(True)

# Save the figure in different formats
plt.savefig('sine_wave.png', dpi=300, bbox_inches='tight') # PNG format
plt.savefig('sine_wave.pdf', bbox_inches='tight') # PDF format

plt.show()

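The bbox_inches='tight' parameter trims excess whitespace and sizes the saved image to fit everything drawn on the figure, so axis labels and titles are not cut off. Call plt.savefig() before plt.show(): once the plot window is closed, the figure may be discarded, and saving afterwards can produce a blank file.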
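
Training jobs often run on machines without a display. A minimal sketch for that case selects the non-interactive Agg backend, saves through the Figure object, and closes the figure afterwards. The temporary output directory is an assumption made so the example cleans up easily; substitute your own path.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; no window is opened
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, np.sin(x), label='sin(x)')
ax.set_title('Sine Wave to be Saved')
ax.legend()
ax.grid(True)

# Hypothetical output location: a temporary directory
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, 'sine_wave.png')

fig.savefig(png_path, dpi=300, bbox_inches='tight')
plt.close(fig)  # release the figure's memory once it is saved

print('Saved:', os.path.exists(png_path))
```

Closing figures explicitly matters when a script saves many plots in a loop, since open figures otherwise accumulate in memory.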

Summary

In this tutorial, we've covered the basics of Matplotlib for data visualization, which is crucial for working with PyTorch and machine learning:

  • Creating basic plots (line, scatter, histogram, bar charts)
  • Working with figures and axes
  • Creating multiple subplots
  • Customizing plots with colors, styles, and annotations
  • Real-world application for visualizing training progress
  • Saving figures for reports and presentations

Matplotlib provides a versatile foundation for visualizing your data and model results, making it an essential tool in your PyTorch journey.

Exercises

  1. Basic Plot: Create a line plot showing a quadratic function f(x) = x² for x values from -10 to 10.

  2. Data Comparison: Generate two normal distributions with different means and standard deviations. Plot their histograms on the same figure to compare them.

  3. PyTorch Application: Create a visualization that shows training accuracy and validation accuracy across 30 epochs (you can make up sample data).

  4. Advanced Challenge: Create a 2x2 subplot showing different aspects of a sine wave: the wave itself, its derivative, a scatter plot of select points, and a histogram of values.

By mastering these Matplotlib basics, you'll be well-equipped to visualize and understand your PyTorch models and data.


