# Pandas Groupby Filtering
When analyzing data with pandas, you'll often need to not just group your data, but also filter out certain groups based on aggregate conditions. This is where groupby filtering comes in: a powerful technique that lets you keep only the groups that meet specific criteria after aggregation.
## Introduction to Groupby Filtering
Groupby filtering allows you to:
- Group data based on one or more columns
- Apply aggregate functions to each group
- Filter groups based on the results of those aggregations
This is particularly useful when you want to answer questions like:
- "Which product categories have an average price greater than $50?"
- "Which countries have more than 1000 customers?"
- "Which students scored above the class average in all subjects?"
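As a quick preview, here is a minimal sketch of the second question, using made-up customer records and a threshold of 2 instead of 1000 so the toy data has a match:

```python
import pandas as pd

# Hypothetical customer records: one row per customer
customers = pd.DataFrame({
    'country': ['US'] * 3 + ['DE'] * 2 + ['FR'],
    'customer_id': range(6),
})

# Count customers per country, then keep countries above the threshold
counts = customers.groupby('country')['customer_id'].count()
large_markets = counts[counts > 2].index.tolist()
print(large_markets)  # ['US']
```

The rest of this article builds this pattern up step by step.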
## Basic Groupby Filtering
Let's start with a simple example. We'll create a DataFrame with sales data and filter to find product categories with more than 100 total sales:
```python
import pandas as pd
import numpy as np

# Create sample data
data = {
    'category': ['Electronics', 'Clothing', 'Books', 'Electronics', 'Clothing',
                 'Books', 'Electronics', 'Clothing', 'Books'],
    'product': ['TV', 'T-shirt', 'Novel', 'Laptop', 'Jeans',
                'Textbook', 'Phone', 'Dress', 'Comic'],
    'sales': [150, 45, 30, 200, 60, 20, 80, 35, 15]
}
df = pd.DataFrame(data)

print("Original DataFrame:")
print(df)
```
Output:

```
Original DataFrame:
      category  product  sales
0  Electronics       TV    150
1     Clothing  T-shirt     45
2        Books    Novel     30
3  Electronics   Laptop    200
4     Clothing    Jeans     60
5        Books Textbook     20
6  Electronics    Phone     80
7     Clothing    Dress     35
8        Books    Comic     15
```
Now, let's find categories with total sales greater than 100:
```python
# Group by category and calculate the sum of sales
category_sales = df.groupby('category')['sales'].sum()
print("\nTotal sales per category:")
print(category_sales)

# Filter for categories with sales > 100
high_volume_categories = category_sales[category_sales > 100]
print("\nCategories with sales > 100:")
print(high_volume_categories)
```
Output:

```
Total sales per category:
category
Books           65
Clothing       140
Electronics    430
Name: sales, dtype: int64

Categories with sales > 100:
category
Clothing       140
Electronics    430
Name: sales, dtype: int64
```
## Filtering with the filter() Method

The `filter()` method provides a more elegant way to filter groups based on a condition. It returns the original DataFrame with only the rows belonging to groups that satisfy the condition.
```python
# Filter categories with total sales > 100
filtered_df = df.groupby('category').filter(lambda x: x['sales'].sum() > 100)
print("\nFiltered DataFrame (categories with sales > 100):")
print(filtered_df)
```
Output:

```
Filtered DataFrame (categories with sales > 100):
      category  product  sales
0  Electronics       TV    150
1     Clothing  T-shirt     45
3  Electronics   Laptop    200
4     Clothing    Jeans     60
6  Electronics    Phone     80
7     Clothing    Dress     35
```
Notice that all rows from the 'Books' category have been removed since its total sales (65) is less than 100.
### How filter() Works

The `filter()` method:
- Groups the DataFrame by the specified column(s)
- Applies the function to each group
- Returns only the rows belonging to groups where the function returns `True`
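The three steps above can be sketched by hand with `groupby` iteration and `concat`. This is only an illustration of the semantics, not how pandas implements it internally, and the data is a small made-up sample:

```python
import pandas as pd

toy = pd.DataFrame({
    'category': ['Electronics', 'Clothing', 'Books', 'Electronics'],
    'sales': [150, 45, 30, 200],
})

def cond(group):
    # Group-level condition: total sales above 100
    return group['sales'].sum() > 100

# Manual equivalent of toy.groupby('category').filter(cond):
# keep each group's rows only when cond(group) is True
kept = [group for _, group in toy.groupby('category') if cond(group)]
manual = pd.concat(kept).sort_index()

expected = toy.groupby('category').filter(cond)
print(manual.equals(expected))  # True
```

Sorting by the original index at the end matters: `filter()` preserves the original row order, while iterating over groups visits them in sorted key order.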
## Advanced Filtering Techniques

### Multiple Conditions

You can use multiple conditions in your filter function:
```python
# Filter categories with total sales > 100 AND at least 3 products
filtered_df = df.groupby('category').filter(
    lambda x: (x['sales'].sum() > 100) & (len(x) >= 3)
)
print("\nFiltered DataFrame (sales > 100 AND at least 3 products):")
print(filtered_df)
```
Output:

```
Filtered DataFrame (sales > 100 AND at least 3 products):
      category  product  sales
0  Electronics       TV    150
1     Clothing  T-shirt     45
3  Electronics   Laptop    200
4     Clothing    Jeans     60
6  Electronics    Phone     80
7     Clothing    Dress     35
```
### Percentile-based Filtering

You can also filter groups using statistical measures, such as the median (the 50th percentile):
```python
# Filter for categories where the mean of sales is above the group's
# own median (its 50th percentile)
def above_median(x):
    return x['sales'].mean() > x['sales'].median()

filtered_df = df.groupby('category').filter(above_median)
print("\nCategories with average sales above median:")
print(filtered_df)
```
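For a condition that genuinely uses a percentile cutoff across groups, `Series.quantile` helps. A minimal sketch (rebuilding the sample sales figures, with the 75th percentile as a hypothetical threshold) keeps only categories whose total sales exceed the 75th percentile of all category totals; since the cutoff depends on every group at once, a two-step approach is natural:

```python
import pandas as pd

demo = pd.DataFrame({
    'category': ['Electronics', 'Clothing', 'Books'] * 3,
    'sales': [150, 45, 30, 200, 60, 20, 80, 35, 15],
})

# Per-category totals: Books 65, Clothing 140, Electronics 430
totals = demo.groupby('category')['sales'].sum()

# Keep categories whose total exceeds the 75th percentile of all totals
cutoff = totals.quantile(0.75)
top_categories = totals[totals > cutoff].index
print(list(top_categories))  # ['Electronics']
```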
## Using transform() with Filtering

The `transform()` method can be combined with filtering for powerful operations:
```python
# Find products that have sales above the average for their category
category_avg = df.groupby('category')['sales'].transform('mean')
above_category_avg = df[df['sales'] > category_avg]
print("\nProducts with above-average sales in their category:")
print(above_category_avg)
```
Output:

```
Products with above-average sales in their category:
      category product  sales
0  Electronics      TV    150
2        Books   Novel     30
3  Electronics  Laptop    200
4     Clothing   Jeans     60
```

For example, the Books average is about 21.67, so Novel (30) qualifies while Textbook (20) and Comic (15) do not.
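Because `transform` broadcasts a group-level aggregate back onto every row, it also gives a vectorized alternative to `filter()` for group-level conditions. A sketch on rebuilt sample data (category and sales columns only):

```python
import pandas as pd

demo = pd.DataFrame({
    'category': ['Electronics', 'Clothing', 'Books'] * 3,
    'sales': [150, 45, 30, 200, 60, 20, 80, 35, 15],
})

# Equivalent to demo.groupby('category').filter(lambda x: x['sales'].sum() > 100):
# broadcast each group's total onto its rows, then mask
mask = demo.groupby('category')['sales'].transform('sum') > 100
result = demo[mask]
print(sorted(result['category'].unique()))  # ['Clothing', 'Electronics']
```

On large frames this boolean-mask form typically outperforms `filter()`, which calls a Python function once per group.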
## Real-world Applications

### Application 1: Customer Segmentation

Let's say we want to identify high-value customer segments, defined here as segments whose average spending exceeds $100:
```python
# Sample customer data
customer_data = {
    'customer_id': range(1, 11),
    'segment': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'C', 'A'],
    'purchases': [5, 2, 8, 3, 7, 4, 6, 1, 2, 9],
    'amount_spent': [120, 45, 230, 70, 190, 85, 110, 30, 50, 280]
}
customers = pd.DataFrame(customer_data)

# Find high-value segments (average spending > 100)
high_value_segments = customers.groupby('segment').filter(
    lambda x: x['amount_spent'].mean() > 100
)
print("\nHigh-value customer segments:")
print(high_value_segments)
```
### Application 2: Sales Performance Analysis

Let's identify product categories that missed their sales target in at least one month:
```python
# Create sample sales data
np.random.seed(42)
months = ['Jan', 'Feb', 'Mar', 'Apr'] * 3
categories = ['Electronics', 'Clothing', 'Home'] * 4
sales_data = {
    'month': months,
    'category': categories,
    'sales': np.random.randint(50, 200, size=12),
    'target': np.random.randint(100, 150, size=12)
}
sales_df = pd.DataFrame(sales_data)

# Find categories that missed their target in at least one month
underperforming = sales_df.groupby('category').filter(
    lambda x: (x['sales'] < x['target']).any()
)
print("\nUnderperforming categories:")
print(underperforming)
```
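Note the semantics of `.any()` here: a category is kept if it missed its target in at least one month. To find categories that never met their target, swap in `.all()`. A minimal sketch with hand-made data showing the difference:

```python
import pandas as pd

perf = pd.DataFrame({
    'category': ['A', 'A', 'B', 'B'],
    'sales':    [90, 160, 80, 95],
    'target':   [100, 150, 100, 100],
})

# Missed target in at least one month
missed_sometimes = perf.groupby('category').filter(
    lambda x: (x['sales'] < x['target']).any()
)
# Missed target in every month
missed_always = perf.groupby('category').filter(
    lambda x: (x['sales'] < x['target']).all()
)

print(sorted(missed_sometimes['category'].unique()))  # ['A', 'B']
print(sorted(missed_always['category'].unique()))     # ['B']
```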
## Common Patterns and Best Practices

### Pattern 1: Two-step Filter

For complex conditions, it's often clearer to use a two-step approach:
```python
# Step 1: Calculate the aggregates
agg_results = df.groupby('category').agg({
    'sales': ['sum', 'mean', 'count']
})

# Step 2: Filter based on the aggregates
selected_categories = agg_results[
    (agg_results[('sales', 'sum')] > 100) &
    (agg_results[('sales', 'count')] >= 3)
].index

# Step 3: Filter the original DataFrame
result = df[df['category'].isin(selected_categories)]
```
### Pattern 2: Combining filter() with Other GroupBy Operations
```python
# First filter groups, then calculate aggregates on the filtered data
result = (df.groupby('category')
            .filter(lambda x: x['sales'].sum() > 100)  # Filter step
            .groupby('category')                       # Re-group
            .agg({'sales': ['mean', 'sum']}))          # Aggregate
```
### Best Practices

- Performance: for large datasets, filtering on precomputed aggregates (as in the two-step pattern above) is usually faster than `filter()` with a Python lambda, which runs once per group
- Readability: for complex conditions, break the logic into clear, named steps
- Chaining: be cautious when chaining many operations; long chains can reduce readability
## Handling Common Challenges

### Challenge 1: Filtering with NaN Values
```python
# Create data with a NaN value
df_with_nan = df.copy()
df_with_nan.loc[0, 'sales'] = np.nan

# Drop NaN inside the lambda before aggregating
filtered_df = df_with_nan.groupby('category').filter(
    lambda x: x['sales'].dropna().sum() > 100
)
```
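Worth knowing: pandas skips NaN in aggregations by default, so `sum()` already ignores the missing value and the `dropna()` above is defensive. The edge case where the distinction matters is an all-NaN group, where `sum()` returns 0.0 unless you pass `min_count`:

```python
import pandas as pd
import numpy as np

s = pd.Series([150.0, np.nan, 200.0])
print(s.sum())            # 350.0 -- NaN skipped by default
print(s.dropna().sum())   # 350.0 -- same result

# An all-NaN series sums to 0.0 by default;
# min_count=1 requires at least one valid value, yielding NaN instead
all_nan = pd.Series([np.nan, np.nan])
print(all_nan.sum())             # 0.0
print(all_nan.sum(min_count=1))  # nan
```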
### Challenge 2: Filtering with Multi-level Groups

```python
# Create a multi-level grouping example
df['year'] = [2020, 2021, 2022] * 3

# Filter with multi-level grouping: the condition applies
# per (year, category) group
filtered_df = df.groupby(['year', 'category']).filter(
    lambda x: x['sales'].sum() > 50
)
```
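With two grouping keys, each (year, category) pair forms its own group and the threshold applies per pair. A quick way to see which pairs would pass is to inspect the group sums directly; a sketch with made-up panel data:

```python
import pandas as pd

panel = pd.DataFrame({
    'year': [2020, 2020, 2021, 2021],
    'category': ['Books', 'Books', 'Books', 'Clothing'],
    'sales': [30, 25, 10, 60],
})

# Sum per (year, category) pair, then check which pairs clear the threshold
sums = panel.groupby(['year', 'category'])['sales'].sum()
passing = list(sums[sums > 50].index)
print(passing)  # [(2020, 'Books'), (2021, 'Clothing')]
```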
## Summary
Groupby filtering is a powerful pandas technique that allows you to:
- Group your data by one or more columns
- Apply aggregate functions to each group
- Filter groups based on the results of those aggregations
Key methods covered:
- Basic filtering with boolean indexing on grouped results
- Using the `filter()` method to keep groups that satisfy conditions
- Combining `transform()` with filtering for more complex operations
- Real-world applications in customer segmentation and sales analysis
By mastering groupby filtering, you can perform sophisticated data analysis with concise code, selecting only the data that matters for your analysis.
## Exercises
- Filter the categories where the maximum sales value is at least twice the minimum sales value.
- Create a DataFrame with student scores across multiple subjects and find students who scored above average in all subjects.
- Use sales data to identify products that consistently performed above target for at least 3 consecutive months.
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)