
PyTorch Audio

Introduction

PyTorch Audio (torchaudio) is a library within the PyTorch ecosystem that provides tools for working with audio data. It's designed to make audio processing, feature extraction, and building audio-based machine learning models as intuitive as possible. With torchaudio, you can load audio files, apply transformations, create spectrograms, and build sophisticated audio processing pipelines directly within the PyTorch framework.

Whether you're interested in speech recognition, music analysis, sound classification, or any audio-based application, torchaudio provides the necessary tools to streamline your workflow.

Getting Started with torchaudio

Installation

First, let's install torchaudio:

bash
pip install torchaudio

For the best compatibility, install a torchaudio release that matches your installed PyTorch version, since the two libraries are released together.
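
If you are unsure which versions you have, a quick check from Python (nothing torchaudio-specific, just the standard version attributes) will confirm the pairing:

python
import torch
import torchaudio

# Verify the installed versions pair up (e.g. torch 2.2.x with torchaudio 2.2.x)
print(torch.__version__)
print(torchaudio.__version__)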

Basic Usage: Loading and Inspecting Audio Files

Let's start with loading an audio file:

python
import torch
import torchaudio
import matplotlib.pyplot as plt

# Load audio file
waveform, sample_rate = torchaudio.load('example.wav')

# Print shape and other information
print(f"Shape of waveform: {waveform.shape}")
print(f"Sample rate of audio: {sample_rate}")
print(f"Duration of audio: {waveform.shape[1] / sample_rate} seconds")

# Plot waveform
plt.figure(figsize=(10, 4))
plt.plot(waveform[0])
plt.title("Audio Waveform")
plt.xlabel("Sample index")
plt.ylabel("Amplitude")
plt.tight_layout()
plt.show()

Output (example):

Shape of waveform: torch.Size([2, 176400])
Sample rate of audio: 44100
Duration of audio: 4.0 seconds

The output indicates that this is a stereo audio file (2 channels) with 176,400 samples at a sampling rate of 44,100 Hz, resulting in a 4-second duration.
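
If you later need single-channel (mono) audio, one simple option is to average the two channels; this is a generic tensor operation, not something torchaudio requires:

python
# Convert stereo to mono by averaging the channels
mono_waveform = waveform.mean(dim=0, keepdim=True)
print(f"Mono shape: {mono_waveform.shape}")  # torch.Size([1, 176400])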

Audio Transformations

torchaudio provides various transformations that can be applied to audio data:

Resampling Audio

python
# Resample audio from 44.1kHz to 16kHz
resample_transform = torchaudio.transforms.Resample(
    orig_freq=sample_rate,
    new_freq=16000
)
resampled_waveform = resample_transform(waveform)
print(f"Original shape: {waveform.shape}")
print(f"Resampled shape: {resampled_waveform.shape}")

Output:

Original shape: torch.Size([2, 176400])
Resampled shape: torch.Size([2, 64000])

Creating Spectrograms

Spectrograms are visual representations of the spectrum of frequencies in a sound signal as they vary with time. They're commonly used as inputs for audio machine learning models.

python
# Create a spectrogram using Short-Time Fourier Transform (STFT)
spectrogram_transform = torchaudio.transforms.Spectrogram()
spectrogram = spectrogram_transform(waveform)

print(f"Spectrogram shape: {spectrogram.shape}")

# Plot the spectrogram
plt.figure(figsize=(10, 4))
spectrogram_db = torchaudio.transforms.AmplitudeToDB()(spectrogram)  # convert power to dB
plt.imshow(spectrogram_db[0].numpy(), cmap='viridis', origin='lower', aspect='auto')
plt.colorbar(format='%+2.0f dB')
plt.title("Spectrogram")
plt.xlabel("Time frame")
plt.ylabel("Frequency bin")
plt.tight_layout()
plt.show()

Output:

Spectrogram shape: torch.Size([2, 201, 883])

Mel Spectrogram

For many audio tasks, Mel spectrograms work better than standard spectrograms because the Mel scale approximates how humans perceive pitch:

python
# Create a Mel spectrogram
mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=1024,
    n_mels=64
)
mel_spectrogram = mel_spectrogram_transform(waveform)

print(f"Mel spectrogram shape: {mel_spectrogram.shape}")

# Plot the Mel spectrogram
plt.figure(figsize=(10, 4))
mel_spectrogram_db = torchaudio.transforms.AmplitudeToDB()(mel_spectrogram)  # convert power to dB
plt.imshow(mel_spectrogram_db[0].numpy(), cmap='viridis', origin='lower', aspect='auto')
plt.colorbar(format='%+2.0f dB')
plt.title("Mel Spectrogram")
plt.xlabel("Time frame")
plt.ylabel("Mel frequency bin")
plt.tight_layout()
plt.show()

Output:

Mel spectrogram shape: torch.Size([2, 64, 345])

Common Audio Transformations

torchaudio provides several transformations that are commonly used in audio processing:

Time Stretching

python
# TimeStretch operates on a complex STFT, so compute one with power=None
complex_spec_transform = torchaudio.transforms.Spectrogram(n_fft=400, power=None)
complex_spec = complex_spec_transform(waveform)

# Time stretch by a fixed factor (rate > 1 speeds the audio up)
stretch = torchaudio.transforms.TimeStretch(
    fixed_rate=1.5,  # 1.5x faster
    n_freq=201       # must match n_fft // 2 + 1
)
stretched_spec = stretch(complex_spec)
print(f"Original spectrogram shape: {complex_spec.shape}")
print(f"Stretched spectrogram shape: {stretched_spec.shape}")

Adding Noise

python
# Add Gaussian noise to the waveform
def add_white_noise(waveform, noise_level=0.005):
    noise = torch.randn_like(waveform) * noise_level
    noisy_waveform = waveform + noise
    return noisy_waveform

noisy_waveform = add_white_noise(waveform)
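
If you want to listen to the augmented result, write it back to disk with torchaudio.save (the output filename here is just an example):

python
# Save the noisy waveform so it can be played back and compared with the original
torchaudio.save('example_noisy.wav', noisy_waveform, sample_rate)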

Practical Example: Audio Classification

Let's walk through a simple example of how to build an audio classification model using torchaudio. We'll create a model that can classify different types of sounds.

Dataset Preparation

First, we'll prepare our dataset:

python
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

class AudioDataset(Dataset):
    def __init__(self, audio_dir, transform=None):
        self.audio_dir = audio_dir
        self.transform = transform
        self.classes = sorted(os.listdir(audio_dir))
        self.file_list = []
        self.labels = []

        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(audio_dir, class_name)
            for audio_file in os.listdir(class_dir):
                if audio_file.endswith('.wav'):
                    self.file_list.append(os.path.join(class_dir, audio_file))
                    self.labels.append(i)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_path = self.file_list[idx]
        label = self.labels[idx]

        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono so every example has a single channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Ensure consistent length (assumes the files are 16 kHz, i.e. one second)
        if waveform.shape[1] < 16000:  # Pad if too short
            waveform = torch.nn.functional.pad(waveform, (0, 16000 - waveform.shape[1]))
        else:  # Truncate if too long
            waveform = waveform[:, :16000]

        # Apply transforms if any (e.g. MelSpectrogram -> [1, n_mels, n_frames])
        if self.transform:
            waveform = self.transform(waveform)

        # Drop the channel dimension; the model adds it back in forward()
        return waveform.squeeze(0), label

# Example usage:
transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=64)
dataset = AudioDataset(audio_dir='path/to/audio_files', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
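
Before training anything, it helps to pull one batch and confirm the shapes; this of course requires that 'path/to/audio_files' points at a real directory of per-class subfolders:

python
# Sanity-check one batch: features should be [batch, n_mels, n_frames]
features, labels = next(iter(dataloader))
print(f"Features: {features.shape}")  # e.g. torch.Size([32, 64, 81]) for 1 s of 16 kHz audio
print(f"Labels: {labels.shape}")      # torch.Size([32])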

Building a Simple Classifier

Now, let's build a simple CNN to classify our audio data:

python
import torch.nn as nn
import torch.nn.functional as F

class AudioClassifier(nn.Module):
    def __init__(self, num_classes):
        super(AudioClassifier, self).__init__()
        # Input shape after unsqueeze: [batch_size, 1, 64, n_frames]
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # Flattened size after three conv/pool stages:
        # 64 mel bins -> 8, 81 frames (1 s at 16 kHz, hop length 200) -> 10
        self.fc1_input = 64 * 8 * 10
        self.fc1 = nn.Linear(self.fc1_input, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Add channel dimension: [batch, 64, n_frames] -> [batch, 1, 64, n_frames]
        x = x.unsqueeze(1)

        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        # Flatten
        x = x.view(-1, self.fc1_input)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize model
model = AudioClassifier(num_classes=len(dataset.classes))
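
A quick forward pass with a dummy batch is a cheap way to confirm that the flattened size matches fc1; the 81 frames below assume one second of 16 kHz audio with the MelSpectrogram settings used above (default hop length of 200):

python
# Dummy batch of 4 Mel spectrograms: [batch, n_mels, n_frames]
dummy = torch.randn(4, 64, 81)
out = model(dummy)
print(f"Output shape: {out.shape}")  # torch.Size([4, num_classes])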

Training Loop

Here's a simple training loop:

python
import torch.optim as optim

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for waveforms, labels in dataloader:
        waveforms, labels = waveforms.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(waveforms)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')

print('Training complete')
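
Once training finishes, the same loop structure can be reused for evaluation. In practice you would evaluate on a held-out validation DataLoader; the sketch below reuses the training loader only to show the pattern:

python
# Evaluate classification accuracy (ideally on a separate validation loader)
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for waveforms, labels in dataloader:
        waveforms, labels = waveforms.to(device), labels.to(device)
        outputs = model(waveforms)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f'Accuracy: {100 * correct / total:.2f}%')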

Pre-trained Models in torchaudio

torchaudio provides access to pre-trained models for tasks like speech recognition:

python
# Load a pre-trained Wav2Vec 2.0 model fine-tuned for speech recognition
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()
labels = bundle.get_labels()

# Function to transcribe audio with greedy CTC decoding
def transcribe_audio(waveform, sample_rate):
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    with torch.no_grad():
        emission, _ = model(waveform)

    # Pick the most likely token per frame, collapse repeats, drop the blank token (index 0)
    indices = torch.unique_consecutive(emission.argmax(dim=-1)[0])
    transcript = "".join(labels[i] for i in indices if i != 0)
    # The label set uses '|' as the word separator
    return transcript.replace("|", " ").strip()

# Example usage
waveform, sample_rate = torchaudio.load("speech.wav")
transcript = transcribe_audio(waveform, sample_rate)
print(f"Transcription: {transcript}")

Datasets in torchaudio

torchaudio includes several built-in datasets that you can use for experiments:

python
# Access the LIBRISPEECH dataset
train_dataset = torchaudio.datasets.LIBRISPEECH(
    root="./data",
    url="train-clean-100",
    download=True
)

# Get an example
waveform, sample_rate, transcript, speaker_id, _, _ = train_dataset[0]
print(f"Sample rate: {sample_rate}")
print(f"Transcript: {transcript}")
print(f"Speaker ID: {speaker_id}")

Other available datasets include:

  • YESNO (simple yes/no utterances)
  • VCTK (multi-speaker speech dataset)
  • SPEECHCOMMANDS (keyword spotting dataset)
  • LJSpeech (single-speaker speech dataset)
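
These all follow the same loading pattern; for instance, the SpeechCommands keyword-spotting dataset can be downloaded and indexed the same way (the ./data root below is just an example path):

python
# Download and index the SpeechCommands dataset
speech_commands = torchaudio.datasets.SPEECHCOMMANDS(root="./data", download=True)
waveform, sample_rate, label, speaker_id, utterance_number = speech_commands[0]
print(f"Keyword: {label}, sample rate: {sample_rate}")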

Summary

In this tutorial, we explored PyTorch Audio (torchaudio), learning how to:

  1. Load and visualize audio data
  2. Apply common transformations like resampling and creating spectrograms
  3. Build a custom audio dataset class
  4. Create a CNN for audio classification
  5. Use pre-trained models for speech recognition
  6. Access built-in datasets

torchaudio is a powerful library that makes audio processing with PyTorch straightforward. By combining it with PyTorch's neural network capabilities, you can build sophisticated audio processing applications, from speech recognition to music generation.

Additional Resources

Exercises

  1. Basic Exercise: Load an audio file of your choice and visualize its waveform and spectrogram.

  2. Intermediate Exercise: Implement data augmentation for audio by applying random transformations (e.g., time stretching, pitch shifting) to an audio dataset.

  3. Advanced Exercise: Build and train a model to classify different musical genres using the GTZAN dataset (available through torchaudio.datasets.GTZAN).

  4. Research Exercise: Implement a simple speech recognition system using torchaudio's pre-trained models and evaluate its performance on different types of speech inputs.


