Pandas Pivot Tables
Introduction
Pivot tables are one of the most powerful data reshaping and summarization tools in pandas. They allow you to transform your data from a long format to a wide format, making it easier to analyze and visualize relationships between variables. If you've used pivot tables in spreadsheet applications like Excel, you'll find pandas pivot tables familiar but with even more flexibility and power.
In this tutorial, we'll explore how to create pivot tables using the pivot_table() function in pandas, understand its various parameters, and see how it can be applied to real-world data analysis tasks.
Basic Concept of Pivot Tables
A pivot table allows you to:
- Reorganize and summarize data in a tabular format
- Aggregate data using functions like sum, mean, count, etc.
- Transform data from a "long" format to a "wide" format
- Create cross-tabulations to see relationships between variables
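To make the "long to wide" idea concrete, here is a minimal sketch on a small made-up table. It shows that a single-function pivot table gives the same result as grouping by the row and column keys, aggregating, and then moving the column key into the columns with unstack(). Treat this as a mental model of what pivot_table() produces, not a description of its exact internals.

```python
import pandas as pd

# A small "long" table: one row per (product, region) observation
long_df = pd.DataFrame({
    'product': ['A', 'B', 'A', 'B', 'A'],
    'region': ['East', 'East', 'West', 'West', 'West'],
    'sales': [100, 80, 120, 60, 40],
})

# Pivot table: products as rows, regions as columns, summed sales in the cells
wide = pd.pivot_table(long_df, values='sales', index='product',
                      columns='region', aggfunc='sum')

# The same table built by hand: group by both keys, sum, then move
# the 'region' level from the row index into the columns
by_hand = long_df.groupby(['product', 'region'])['sales'].sum().unstack('region')

print(wide)
print(wide.equals(by_hand))
```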
Creating Simple Pivot Tables
Let's start with a basic example using a sample dataset:
import pandas as pd
import numpy as np
# Create a sample dataframe
data = {
    'date': pd.date_range(start='2023-01-01', periods=10),
    'product': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C'],
    'region': ['East', 'West', 'West', 'East', 'East', 'West', 'West', 'East', 'West', 'East'],
    'sales': [200, 150, 320, 180, 250, 280, 120, 230, 310, 190]
}
df = pd.DataFrame(data)
print("Original DataFrame:")
print(df)
Output:
Original DataFrame:
date product region sales
0 2023-01-01 A East 200
1 2023-01-02 B West 150
2 2023-01-03 A West 320
3 2023-01-04 C East 180
4 2023-01-05 B East 250
5 2023-01-06 A West 280
6 2023-01-07 C West 120
7 2023-01-08 B East 230
8 2023-01-09 A West 310
9 2023-01-10 C East 190
Now, let's create a simple pivot table to see total sales by product and region:
# Create a basic pivot table
pivot1 = pd.pivot_table(df, values='sales', index='product', columns='region', aggfunc='sum')
print("\nPivot Table (Total Sales by Product and Region):")
print(pivot1)
Output:
Pivot Table (Total Sales by Product and Region):
region East West
product
A 200 910
B 480 150
C 370 120
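The aggfunc argument is optional. When you leave it out, pivot_table() uses its default, 'mean', so duplicate index/column combinations are averaged rather than summed. A quick check on a tiny made-up DataFrame where product A has two East sales:

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'A', 'B'],
    'region': ['East', 'East', 'West'],
    'sales': [200, 100, 150],
})

# No aggfunc given: pandas falls back to its default, 'mean'
default_pivot = pd.pivot_table(df, values='sales', index='product', columns='region')

# The same table with the aggregation spelled out explicitly
mean_pivot = pd.pivot_table(df, values='sales', index='product',
                            columns='region', aggfunc='mean')

# A/East averages 200 and 100; combinations with no rows show up as NaN
print(default_pivot)
```

If you expect totals, always pass aggfunc='sum' explicitly.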
Understanding Pivot Table Parameters
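The pivot_table() function has several important parameters:
- data: the DataFrame to summarize
- values: the column(s) to aggregate
- index: the column(s) whose unique values become the row labels
- columns: the column(s) whose unique values become the column labels
- aggfunc: the aggregation to apply, given as a string such as 'sum', 'mean', or 'count', a callable such as np.sum, a list of these, or a dict mapping each value column to its own function (defaults to 'mean')
- fill_value: the value used to replace missing entries in the resulting table
- margins: whether to add row and column totals
- margins_name: the label for the totals row and column (defaults to 'All')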
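Here is a small sketch of what fill_value does, using made-up data in which product B has no West sales. Without fill_value, the empty cell appears as NaN; with fill_value=0, it is replaced after aggregation.

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'A', 'B'],
    'region': ['East', 'West', 'East'],
    'sales': [200, 320, 250],
})

# Product B has no 'West' rows, so that cell has nothing to aggregate
with_nan = pd.pivot_table(df, values='sales', index='product',
                          columns='region', aggfunc='sum')

# fill_value replaces those empty cells in the finished table
filled = pd.pivot_table(df, values='sales', index='product',
                        columns='region', aggfunc='sum', fill_value=0)

print(with_nan)
print(filled)
```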
Let's see how these parameters work together:
# Advanced pivot table with multiple parameters
pivot2 = pd.pivot_table(
    df,
    values='sales',
    index='product',
    columns='region',
    aggfunc='sum',
    fill_value=0,
    margins=True,
    margins_name='Total'
)
print("\nAdvanced Pivot Table (with totals and filled NaN values):")
print(pivot2)
Output:
Advanced Pivot Table (with totals and filled NaN values):
region East West Total
product
A 200 910 1110
B 480 150 630
C 370 120 490
Total 1050 1180 2230
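Note that margins are computed from the underlying rows, not from the cells already in the table. With aggfunc='sum' the two give the same result, but with 'mean' they differ. This sketch rebuilds the tutorial's sample data and compares the two:

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C'],
    'region': ['East', 'West', 'West', 'East', 'East',
               'West', 'West', 'East', 'West', 'East'],
    'sales': [200, 150, 320, 180, 250, 280, 120, 230, 310, 190],
})

# Average sales per cell, plus 'All' margins
avg = pd.pivot_table(df, values='sales', index='product', columns='region',
                     aggfunc='mean', margins=True)
print(avg.round(2))

# The 'All' value for product A is the mean of A's four raw sales rows...
print(df.loc[df['product'] == 'A', 'sales'].mean())
# ...not the mean of the two cells (East, West) shown in A's row
print(avg.loc['A', ['East', 'West']].mean())
```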
Multiple Aggregation Functions
You can apply multiple aggregation functions at once:
# Pivot table with multiple aggregation functions
pivot3 = pd.pivot_table(
    df,
    values='sales',
    index='product',
    columns='region',
    aggfunc=['sum', 'mean', 'count'],
    fill_value=0
)
print("\nPivot Table with Multiple Aggregation Functions:")
print(pivot3)
Output:
Pivot Table with Multiple Aggregation Functions:
sum mean count
region East West East West East West
product
A 200 910 200.00 303.33 1 3
B 480 150 240.00 150.00 2 1
C 370 120 185.00 120.00 2 1
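The result has a two-level column index, with the aggregation function on top and the region underneath. This short sketch rebuilds the same sample data and shows two ways to pull pieces back out:

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C'],
    'region': ['East', 'West', 'West', 'East', 'East',
               'West', 'West', 'East', 'West', 'East'],
    'sales': [200, 150, 320, 180, 250, 280, 120, 230, 310, 190],
})

multi = pd.pivot_table(df, values='sales', index='product', columns='region',
                       aggfunc=['sum', 'mean', 'count'])

# The top column level holds the function names, so one statistic
# comes back as an ordinary product x region table
counts = multi['count']

# xs() slices on the inner 'region' level instead: every statistic for East
east = multi.xs('East', axis=1, level='region')

print(counts)
print(east)
```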
Multiple Index and Column Values
Pivot tables become even more powerful when you use multiple indices or columns:
# Let's add a 'quarter' column to our data
df['quarter'] = df['date'].dt.to_period('Q').astype(str)
print("\nDataFrame with Quarter Information:")
print(df)
# Pivot table with multiple indices
pivot4 = pd.pivot_table(
    df,
    values='sales',
    index=['product', 'quarter'],
    columns='region',
    aggfunc='sum',
    fill_value=0
)
print("\nPivot Table with Multiple Indices:")
print(pivot4)
Output:
DataFrame with Quarter Information:
date product region sales quarter
0 2023-01-01 A East 200 2023Q1
1 2023-01-02 B West 150 2023Q1
2 2023-01-03 A West 320 2023Q1
3 2023-01-04 C East 180 2023Q1
4 2023-01-05 B East 250 2023Q1
5 2023-01-06 A West 280 2023Q1
6 2023-01-07 C West 120 2023Q1
7 2023-01-08 B East 230 2023Q1
8 2023-01-09 A West 310 2023Q1
9 2023-01-10 C East 190 2023Q1
Pivot Table with Multiple Indices:
region East West
product quarter
A 2023Q1 200 910
B 2023Q1 480 150
C 2023Q1 370 120
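In this sample every date falls in 2023Q1, so the quarter level holds only one value. The sketch below uses a few made-up rows spread over two quarters, builds the quarter label with .dt.to_period('Q'), and shows how to select from a two-level row index with .loc and xs():

```python
import pandas as pd

df = pd.DataFrame({
    'date': pd.to_datetime(['2023-02-10', '2023-03-05', '2023-04-12',
                            '2023-05-20', '2023-06-01', '2023-02-28']),
    'product': ['A', 'A', 'A', 'B', 'B', 'B'],
    'region': ['East', 'West', 'East', 'West', 'East', 'East'],
    'sales': [200, 150, 300, 120, 90, 60],
})
# Quarter labels such as '2023Q1'
df['quarter'] = df['date'].dt.to_period('Q').astype(str)

by_quarter = pd.pivot_table(df, values='sales', index=['product', 'quarter'],
                            columns='region', aggfunc='sum', fill_value=0)

# .loc with an outer label returns that product's rows, indexed by quarter
print(by_quarter.loc['A'])

# xs() selects on the inner level: every product in 2023Q2
print(by_quarter.xs('2023Q2', level='quarter'))
```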
Real-world Application: Sales Analysis
Let's apply pivot tables to a more realistic sales dataset:
# Create a more realistic sales dataset
np.random.seed(42)
dates = pd.date_range('2023-01-01', '2023-12-31', freq='D')
products = ['Laptop', 'Phone', 'Tablet', 'Monitor', 'Keyboard']
regions = ['North', 'South', 'East', 'West', 'Central']
channels = ['Online', 'Store', 'Distributor']
n_rows = 1000
sales_data = {
    'date': np.random.choice(dates, n_rows),
    'product': np.random.choice(products, n_rows),
    'region': np.random.choice(regions, n_rows),
    'channel': np.random.choice(channels, n_rows),
    'units': np.random.randint(1, 50, n_rows),
    'unit_price': np.random.uniform(100, 1500, n_rows).round(2)
}
sales_df = pd.DataFrame(sales_data)
sales_df['revenue'] = sales_df['units'] * sales_df['unit_price']
sales_df['month'] = sales_df['date'].dt.strftime('%Y-%m')
print("\nSales Data Sample:")
print(sales_df.head())
# Analysis 1: Monthly revenue by product
monthly_product_sales = pd.pivot_table(
    sales_df,
    values='revenue',
    index='month',
    columns='product',
    aggfunc='sum',
    fill_value=0
)
print("\nMonthly Revenue by Product:")
print(monthly_product_sales.head())
# Analysis 2: Channel performance by region
channel_region_performance = pd.pivot_table(
    sales_df,
    values=['revenue', 'units'],
    index='channel',
    columns='region',
    aggfunc={'revenue': 'sum', 'units': 'sum'},
    fill_value=0,
    margins=True
)
print("\nChannel Performance by Region:")
print(channel_region_performance)
This example answers two typical sales-analysis questions with pivot tables: how each product's revenue changes from month to month, and how each sales channel performs in each region, with overall totals added by margins=True.
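The dictionary form of aggfunc is most useful when each column needs a different statistic. For example, totals make sense for revenue, but an average is more meaningful for a price. Here is a minimal sketch with a handful of made-up orders:

```python
import pandas as pd

orders = pd.DataFrame({
    'channel': ['Online', 'Online', 'Store', 'Store', 'Store'],
    'region': ['North', 'South', 'North', 'North', 'South'],
    'units': [10, 4, 6, 2, 8],
    'unit_price': [100.0, 250.0, 120.0, 80.0, 200.0],
})
orders['revenue'] = orders['units'] * orders['unit_price']

# Each value column gets its own aggregation: summed revenue, averaged price
summary = pd.pivot_table(
    orders,
    values=['revenue', 'unit_price'],
    index='channel',
    columns='region',
    aggfunc={'revenue': 'sum', 'unit_price': 'mean'},
    fill_value=0,
)
print(summary)
```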
Working with Pivot Table Results
Once you've created a pivot table, you can manipulate it like any other DataFrame:
# Reorder the product columns by their January 2023 revenue, highest first
sorted_monthly = monthly_product_sales.sort_values(by='2023-01', ascending=False, axis=1)
print("\nSorted Monthly Product Sales:")
print(sorted_monthly.head())
# Add calculated columns
monthly_product_sales['Total'] = monthly_product_sales.sum(axis=1)
monthly_product_sales['Laptop_Percentage'] = (monthly_product_sales['Laptop'] / monthly_product_sales['Total'] * 100).round(2)
print("\nMonthly Sales with Calculated Columns:")
print(monthly_product_sales[['Laptop', 'Total', 'Laptop_Percentage']].head())
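Percentages are another common follow-up. The sketch below, on small made-up data, divides a pivot table by its row totals and by its column totals. The axis argument of div() determines which totals each cell is matched against.

```python
import pandas as pd

sales = pd.DataFrame({
    'month': ['2023-01', '2023-01', '2023-02', '2023-02'],
    'product': ['Laptop', 'Phone', 'Laptop', 'Phone'],
    'revenue': [300.0, 100.0, 200.0, 200.0],
})
monthly = pd.pivot_table(sales, values='revenue', index='month',
                         columns='product', aggfunc='sum')

# Share of each month's total: axis=0 matches the totals to the row labels
row_share = monthly.div(monthly.sum(axis=1), axis=0) * 100

# Share of each product's total: plain division aligns the totals on the columns
col_share = monthly / monthly.sum() * 100

print(row_share.round(2))
print(col_share.round(2))
```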
Reshaping a Pivot Table with stack() and unstack()
After creating a pivot table, you can reshape it further with stack(), which moves a column level into the row index, and unstack(), which moves a row index level into the columns:
# Let's create a simple pivot table first
simple_pivot = pd.pivot_table(
    df,
    values='sales',
    index=['product'],
    columns=['region', 'quarter'],
    aggfunc='sum',
    fill_value=0
)
print("\nSimple Multi-level Pivot Table:")
print(simple_pivot)
# Stack the innermost column level
stacked = simple_pivot.stack(level=1)
print("\nStacked Pivot Table (by quarter):")
print(stacked)
# Unstack the outer index level ('product') back into the columns
unstacked = stacked.unstack(level=0)
print("\nUnstacked Pivot Table (by product):")
print(unstacked)
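When no cells are missing, stack() and unstack() undo each other, so you can move a level out and back without losing anything. Here is a small sketch with made-up data spanning two quarters. Recent pandas versions may also print a FutureWarning from stack() about its new implementation; it does not affect this example.

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'A', 'B', 'B'],
    'region': ['East', 'West', 'East', 'West'],
    'quarter': ['2023Q1', '2023Q2', '2023Q2', '2023Q1'],
    'sales': [200, 320, 250, 150],
})

# Two column levels: region on the outside, quarter on the inside
wide = pd.pivot_table(df, values='sales', index='product',
                      columns=['region', 'quarter'], aggfunc='sum', fill_value=0)

# stack(): move the 'quarter' column level into the row index
tall = wide.stack(level='quarter')

# unstack(): move it back into the columns, restoring the original layout
restored = tall.unstack(level='quarter')

print(tall)
print(restored)
```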
Difference Between pivot_table() and pivot()
Pandas offers two similar functions: pivot_table() and pivot(). The main differences are:
- pivot_table() can aggregate data when there are duplicate entries, using functions like sum or mean.
- pivot() doesn't perform aggregation and raises a ValueError if any index/column combination appears more than once.
# This will work with pivot() because there are no duplicates
unique_data = df.drop_duplicates(['product', 'region'])
pivot_result = unique_data.pivot(index='product', columns='region', values='sales')
print("\nPivot Result (without duplicates):")
print(pivot_result)
# This would raise an error with pivot() but works with pivot_table()
# df.pivot(index='product', columns='region', values='sales') # Would raise ValueError
pivot_table_result = pd.pivot_table(df, index='product', columns='region', values='sales', aggfunc='mean')
print("\nPivot Table Result (with duplicates aggregated by mean):")
print(pivot_table_result)
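Here is a sketch of that failure mode on three made-up rows where ('A', 'West') appears twice. It also shows a quick duplicate check with duplicated(), which tells you in advance whether pivot() will work:

```python
import pandas as pd

df = pd.DataFrame({
    'product': ['A', 'A', 'B'],
    'region': ['West', 'West', 'East'],
    'sales': [320, 280, 250],
})

# True if any (product, region) pair occurs more than once
has_duplicates = df.duplicated(subset=['product', 'region']).any()
print('Duplicate pairs present:', has_duplicates)

try:
    df.pivot(index='product', columns='region', values='sales')
    pivot_failed = False
except ValueError as err:
    # pivot() has no rule for placing two values in the single ('A', 'West') cell
    pivot_failed = True
    print('pivot() failed:', err)

# pivot_table() resolves the clash by aggregating the duplicate rows
summed = pd.pivot_table(df, index='product', columns='region',
                        values='sales', aggfunc='sum')
print(summed)
```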
Summary
Pandas pivot tables are a powerful tool for data reshaping and analysis. They allow you to:
- Transform data from long to wide format
- Aggregate data using multiple functions
- Create multi-level indices and columns
- Add row and column totals with margins=True
- Fill empty cells with fill_value
- Perform complex data summarization with minimal code
Mastering pivot tables will significantly enhance your data analysis capabilities, making it easier to identify patterns and extract insights from your data.
Exercises
- Create a pivot table showing the average sales by product and region.
- Build a pivot table with multiple aggregation functions (min, max, and mean) for sales grouped by product.
- Create a pivot table that includes quarterly totals as both a row and column.
- Use a pivot table to find which product had the highest sales in each region.
- Generate a pivot table with both region and quarter as column indices, and calculate the percentage of sales for each product per quarter.
With practice, pivot tables will become an essential tool in your data analysis toolkit, enabling you to efficiently explore and summarize complex datasets.