TensorFlow Data Validation
Introduction
TensorFlow Data Validation (TFDV) is a library in the TensorFlow Extended (TFX) ecosystem that helps you understand, validate, and monitor your ML data. When building machine learning pipelines, the quality and consistency of your data are critical for model performance and reliability. TFDV provides tools to:
- Generate descriptive statistics from your data
- Infer a schema that describes the expected data
- Detect anomalies and validate new data against your schema
- Visualize and compare statistics across different datasets
For beginners working on ML projects, TFDV helps you catch data issues early and ensures your models are trained on high-quality, consistent data.
Why Data Validation Matters
Imagine training a model on a dataset where age is represented in years (values like 25, 30, 42), and then trying to use that model on new data where age is suddenly in months (values like 300, 360, 504). Your model's performance would suffer dramatically! TFDV helps catch these kinds of data inconsistencies before they impact your models.
Getting Started with TFDV
Let's start by installing TensorFlow Data Validation:
pip install tensorflow-data-validation
Basic Imports
import tensorflow_data_validation as tfdv
import pandas as pd
import tensorflow as tf
Generating Data Statistics
The first step in understanding your data is to generate descriptive statistics. TFDV can analyze your dataset and create a comprehensive report of statistics.
Example: Analyzing a CSV Dataset
Let's analyze a simple census dataset:
# Load a CSV file
csv_file = 'census_data.csv'
# Generate statistics
stats = tfdv.generate_statistics_from_csv(csv_file)
# Visualize the statistics
tfdv.visualize_statistics(stats)
Output: This generates an interactive visualization showing distributions of each feature, missing value counts, and other statistics.
Example: Analyzing a DataFrame
You can also analyze a Pandas DataFrame directly:
# Sample census data
data = {
    'age': [39, 40, 27, 52, 31],
    'workclass': ['State-gov', 'Private', 'Private', 'Self-emp', 'Private'],
    'education': ['Bachelors', 'Some-college', 'Masters', 'High-school', 'Bachelors'],
    'income': ['<=50K', '>50K', '>50K', '<=50K', '>50K']
}
df = pd.DataFrame(data)
# Generate statistics from DataFrame
stats = tfdv.generate_statistics_from_dataframe(df)
# Visualize the statistics
tfdv.visualize_statistics(stats)
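To make the report less of a black box, here is a minimal plain-Python sketch of the kind of per-feature summary it contains. The `summarize` helper is hypothetical and is not part of TFDV, which computes a much richer set of statistics; the sample columns are adapted from the example above, with one missing value added for illustration.

```python
from collections import Counter
from statistics import mean

# Two sample columns; None marks a missing value (added for illustration).
data = {
    'age': [39, 40, 27, 52, 31],
    'workclass': ['State-gov', 'Private', None, 'Self-emp', 'Private'],
}


def summarize(values):
    """Return a small summary for one column (hypothetical helper, not TFDV)."""
    present = [v for v in values if v is not None]
    summary = {'count': len(present), 'missing': len(values) - len(present)}
    if not present:
        return summary
    if all(isinstance(v, (int, float)) for v in present):
        # Numeric column: report range and mean.
        summary.update(min=min(present), max=max(present), mean=mean(present))
    else:
        # Categorical column: report cardinality and the most frequent value.
        top_value, top_count = Counter(present).most_common(1)[0]
        summary.update(unique=len(set(present)), top=top_value, top_count=top_count)
    return summary


report = {name: summarize(column) for name, column in data.items()}
for name, summary in report.items():
    print(name, summary)
```

The interactive visualization shows the same categories of information (counts, missing values, ranges, top values), plus histograms for each feature.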
Inferring a Schema
After generating statistics, you can infer a schema that captures the expected properties of valid data:
# Infer schema from statistics
schema = tfdv.infer_schema(stats)
# Display the schema
tfdv.display_schema(schema)
The schema captures information like:
- Feature names and types
- Value domains (categorical values)
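- Expected ranges for numerical features (these are not inferred automatically; you add them yourself, as described under "Modifying a Schema")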
- Required vs. optional features
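As a rough mental model of what inference does, the sketch below derives a tiny schema from a few rows in plain Python. The `infer_simple_schema` function and its dictionary layout are hypothetical, chosen only for illustration; TFDV's real schema is a protocol buffer with many more fields.

```python
def infer_simple_schema(rows):
    """Derive type, domain, and presence per feature (illustrative only)."""
    feature_names = sorted({name for row in rows for name in row})
    schema = {}
    for name in feature_names:
        # Collect the non-missing values seen for this feature.
        values = [row[name] for row in rows if row.get(name) is not None]
        is_string = all(isinstance(v, str) for v in values)
        schema[name] = {
            'type': 'STRING' if is_string else 'NUMERIC',
            # Only string features get a value domain in this sketch.
            'domain': sorted(set(values)) if is_string else None,
            # A feature is "required" if every row had a value for it.
            'required': len(values) == len(rows),
        }
    return schema


rows = [
    {'age': 39, 'workclass': 'State-gov', 'education': 'Bachelors'},
    {'age': 40, 'workclass': 'Private', 'education': 'Masters'},
    {'age': 27, 'workclass': 'Private'},  # 'education' is missing here
]

simple_schema = infer_simple_schema(rows)
for name, spec in simple_schema.items():
    print(name, spec)
```

Because an inferred schema only reflects the data it was inferred from, it is worth reviewing and adjusting by hand before relying on it for validation.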
Validating Data Against a Schema
Now that we have a schema, we can validate new data to check for anomalies:
# Generate statistics for validation data
validation_csv = 'new_census_data.csv'
validation_stats = tfdv.generate_statistics_from_csv(validation_csv)
# Check for anomalies
anomalies = tfdv.validate_statistics(validation_stats, schema)
# Display anomalies
tfdv.display_anomalies(anomalies)
Common Anomalies TFDV Can Detect
- Missing Features: Features present in training but missing in serving
- Extra Features: New features appearing in serving data
- Type Mismatches: Changes in data types (e.g., int to float)
- Domain Violations: New categorical values
- Distribution Shifts: Statistical changes in feature distributions
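The first four anomaly types can be illustrated with a small plain-Python check. This is a sketch under stated assumptions: the `EXPECTED` dictionary and `find_anomalies` function are hypothetical, and the check runs on a single record, whereas TFDV validates aggregate statistics against a schema. Distribution shifts need a comparison between two datasets and are not covered here.

```python
# Hypothetical expectations: feature name -> Python type and allowed values.
EXPECTED = {
    'age': {'type': int, 'domain': None},
    'workclass': {'type': str, 'domain': {'State-gov', 'Private', 'Self-emp'}},
    'income': {'type': str, 'domain': {'<=50K', '>50K'}},
}


def find_anomalies(record, expected=EXPECTED):
    """Return a sorted list of (feature, problem) pairs for one record."""
    problems = []
    for name in expected.keys() - record.keys():
        problems.append((name, 'missing feature'))
    for name in record.keys() - expected.keys():
        problems.append((name, 'extra feature'))
    for name in expected.keys() & record.keys():
        spec, value = expected[name], record[name]
        if not isinstance(value, spec['type']):
            problems.append((name, 'type mismatch'))
        elif spec['domain'] is not None and value not in spec['domain']:
            problems.append((name, 'domain violation'))
    return sorted(problems)


# 'age' arrives as a string, 'workclass' has an unseen value,
# 'income' is absent, and 'zipcode' was never part of the expectations.
record = {'age': '39', 'workclass': 'Government', 'zipcode': '94043'}
for feature, problem in find_anomalies(record):
    print(f'{feature}: {problem}')
```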
Modifying a Schema
You can update your schema to adapt to changes in your data:
# Add a new categorical value to a feature
tfdv.get_domain(schema, 'workclass').value.append('Government')
# Make a feature optional
tfdv.get_feature(schema, 'education').presence.min_fraction = 0.0
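# Restrict the value range of a numeric feature.
# Domain classes live in tensorflow_metadata, not in the tfdv namespace;
# 'age' holds integers here, so IntDomain matches its inferred type.
from tensorflow_metadata.proto.v0 import schema_pb2
tfdv.set_domain(schema, 'age',
                schema_pb2.IntDomain(name='age', min=16, max=100))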
Comparing Training and Serving Data
A common ML issue is "training-serving skew," where your production data differs from training data:
# Generate statistics for training and serving data
train_stats = tfdv.generate_statistics_from_csv('train_data.csv')
serving_stats = tfdv.generate_statistics_from_csv('serving_data.csv')
# Compare datasets
tfdv.visualize_statistics(
    lhs_statistics=train_stats,
    rhs_statistics=serving_stats,
    lhs_name='Training Data',
    rhs_name='Serving Data'
)
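Visual comparison is useful, but skew can also be quantified. For categorical features, one measure TFDV supports in its skew and drift comparators is the L-infinity distance: the largest absolute difference between the relative frequencies of any single value. A minimal plain-Python sketch follows; the `l_infinity_distance` helper and the sample values are hypothetical and only illustrate the metric, not TFDV's implementation.

```python
from collections import Counter


def l_infinity_distance(train_values, serving_values):
    """Largest per-value gap in relative frequency between two samples."""
    if not train_values or not serving_values:
        raise ValueError('both samples must be non-empty')
    train_counts = Counter(train_values)
    serving_counts = Counter(serving_values)
    distance = 0.0
    # Consider every value seen in either sample; Counter returns 0 if absent.
    for value in train_counts.keys() | serving_counts.keys():
        train_freq = train_counts[value] / len(train_values)
        serving_freq = serving_counts[value] / len(serving_values)
        distance = max(distance, abs(train_freq - serving_freq))
    return distance


# Made-up samples: 'Private' drops from 60% to 30%, 'Self-emp' rises from 10% to 40%.
train = ['Private'] * 6 + ['State-gov'] * 3 + ['Self-emp'] * 1
serving = ['Private'] * 3 + ['State-gov'] * 3 + ['Self-emp'] * 4

skew = l_infinity_distance(train, serving)
print(f'L-infinity distance: {skew:.2f}')
```

A distance near 0 means the two samples have almost the same value frequencies; what counts as too large is a threshold you choose per feature.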
Handling Different Environments
ML systems often need different validation rules for training, evaluation, and serving environments:
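from tensorflow_metadata.proto.v0 import schema_pb2
# Infer one schema from the training data
schema = tfdv.infer_schema(train_stats)
# Constrain the label to its two known values
tfdv.set_domain(schema, 'income',
                schema_pb2.StringDomain(name='income', value=['<=50K', '>50K']))
# By default, every feature is expected in both environments
schema.default_environment.append('TRAINING')
schema.default_environment.append('SERVING')
# Specify that the 'income' feature is not available during serving
tfdv.get_feature(schema, 'income').not_in_environment.append('SERVING')
# Validate serving data against the SERVING environment
serving_anomalies = tfdv.validate_statistics(
    serving_stats, schema, environment='SERVING')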
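Conceptually, an environment is a label that lets a feature be excluded from validation in some contexts. The plain-Python sketch below shows that idea in isolation; the `NOT_IN_ENVIRONMENT` mapping and the `expected_features` helper are hypothetical and do not mirror how TFDV stores this information in its schema.

```python
# Hypothetical mapping: feature name -> environments where it is NOT expected.
NOT_IN_ENVIRONMENT = {
    'age': set(),
    'workclass': set(),
    'income': {'SERVING'},  # the label is unknown at serving time
}


def expected_features(environment):
    """List the features that should be present in the given environment."""
    return sorted(
        name
        for name, excluded in NOT_IN_ENVIRONMENT.items()
        if environment not in excluded
    )


print('TRAINING:', expected_features('TRAINING'))
print('SERVING:', expected_features('SERVING'))
```

With this setup, a serving dataset that lacks `income` is not flagged, while a training dataset that lacks it still is.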
Real-World Example: Building a Production Pipeline
Let's put all these concepts together in a more complete example that shows how TFDV fits into a production ML workflow:
from tensorflow_metadata.proto.v0 import schema_pb2
# Step 1: Analyze training data and create schema
train_csv = 'train_data.csv'
train_stats = tfdv.generate_statistics_from_csv(train_csv)
schema = tfdv.infer_schema(train_stats)
# Step 2: Refine schema based on domain knowledge
# Make 'email' feature optional
tfdv.get_feature(schema, 'email').presence.min_fraction = 0.0
# Add range constraints for 'age' (assumed to be an integer feature)
tfdv.set_domain(schema, 'age',
                schema_pb2.IntDomain(name='age', min=18, max=120))
# Step 3: Define environment-specific rules
schema.default_environment.append('TRAINING')
schema.default_environment.append('SERVING')
# 'label' is not available at serving time
tfdv.get_feature(schema, 'label').not_in_environment.append('SERVING')
# Step 4: Validate new data before training or serving
# (pass environment='SERVING' instead when checking serving data)
new_data_stats = tfdv.generate_statistics_from_csv('new_data.csv')
anomalies = tfdv.validate_statistics(
    new_data_stats, schema, environment='TRAINING')
if anomalies.anomaly_info:
    print("Data contains anomalies!")
    tfdv.display_anomalies(anomalies)
else:
    print("Data validation passed! Proceeding to next step in ML pipeline.")
Working with TensorFlow Example Format
TFDV integrates well with TensorFlow's native tf.Example format:
# Generate statistics from a TFRecord file of serialized tf.Example protos
tfrecord_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='data.tfrecord')
# Note: this function decodes records as tf.Example, and StatsOptions has
# no sequence-specific setting. tf.SequenceExample data needs a different
# input path, for example a custom Apache Beam pipeline that feeds
# decoded record batches to tfdv.GenerateStatistics.
Advanced Features
Detecting Data Drift
Data drift occurs when your production data gradually changes over time, affecting model performance:
# Compare statistics between time periods
january_stats = tfdv.generate_statistics_from_csv('january_data.csv')
february_stats = tfdv.generate_statistics_from_csv('february_data.csv')
# Drift thresholds are set per feature on the schema.
# Numeric feature: Jensen-Shannon divergence threshold
tfdv.get_feature(schema, 'age').drift_comparator.jensen_shannon_divergence.threshold = 0.01
# Categorical feature: L-infinity distance threshold
tfdv.get_feature(schema, 'workclass').drift_comparator.infinity_norm.threshold = 0.001
# Validate the newer data, using the older data as the baseline
drift_anomalies = tfdv.validate_statistics(
    february_stats, schema, previous_statistics=january_stats)
tfdv.display_anomalies(drift_anomalies)
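For numeric features, drift is commonly measured with the Jensen-Shannon divergence between the two periods' histograms, which is one of the comparators TFDV offers. The sketch below computes it in plain Python to show what the threshold is compared against. The function names and bucket counts are hypothetical, and base-2 logarithms are an assumption made here so that the result falls between 0 and 1; TFDV's internal binning and computation may differ.

```python
import math


def normalize(histogram):
    """Turn raw bucket counts into probabilities that sum to 1."""
    total = sum(histogram)
    return [count / total for count in histogram]


def kl_divergence(p, q):
    """Kullback-Leibler divergence in bits; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jensen_shannon_divergence(hist_a, hist_b):
    """Symmetric divergence between two histograms with identical buckets."""
    p, q = normalize(hist_a), normalize(hist_b)
    # The mixture m is positive wherever p or q is, so the KL terms are finite.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return (kl_divergence(p, m) + kl_divergence(q, m)) / 2


# Made-up age histograms over the same four buckets for two months.
january_hist = [10, 30, 40, 20]
february_hist = [5, 20, 40, 35]

drift = jensen_shannon_divergence(january_hist, february_hist)
print(f'Jensen-Shannon divergence: {drift:.4f}')
```

Identical histograms give a divergence of 0, and histograms with no overlapping buckets give 1, so a threshold such as 0.01 flags even fairly small shifts.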
Working with Sparse Features
TFDV can also handle sparse features (commonly used in recommender systems):
# Configure options for sparse features
stats_options = tfdv.StatsOptions(
    vocab_paths={
        'item_id': 'item_vocab.txt'  # Path to vocabulary file
    },
    num_top_values=1000,  # Number of most frequent values to keep
    num_rank_histogram_buckets=10  # Number of buckets for rank histogram
)
# Generate statistics with these options
sparse_stats = tfdv.generate_statistics_from_csv(
    'sparse_data.csv',
    stats_options=stats_options
)
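To see what "top values" and a "rank histogram" refer to, here is a small plain-Python sketch. The `top_values` helper and the item IDs are hypothetical; the point is only that values are ranked by frequency and the list is truncated to the requested size.

```python
from collections import Counter


def top_values(values, k):
    """Return the k most frequent values with their counts, most frequent first."""
    return Counter(values).most_common(k)


# Made-up item IDs from a recommender-style dataset.
item_ids = ['a', 'b', 'a', 'c', 'a', 'b', 'd']

ranked = top_values(item_ids, k=2)
for rank, (value, count) in enumerate(ranked, start=1):
    print(f'rank {rank}: {value} appears {count} times')
```

With high-cardinality features, limiting the number of tracked values keeps the statistics compact at the cost of detail about rare values.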
Integration with TFX Pipelines
In a complete TFX pipeline, TFDV powers three standard components: StatisticsGen computes the statistics, SchemaGen infers the schema, and ExampleValidator checks for anomalies:
# Sketch of TFX pipeline integration (paths and names are placeholders)
from tfx import v1 as tfx
# Ingest CSV files from a directory
example_gen = tfx.components.CsvExampleGen(input_base='data/')
# Generate statistics
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples']
)
# Generate schema
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics']
)
# Validate examples
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
# Define the pipeline
pipeline = tfx.dsl.Pipeline(
    pipeline_name='tfdv_example',
    pipeline_root='pipeline_output/',
    components=[
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        # ...other components...
    ]
)
Summary
TensorFlow Data Validation is a powerful tool that helps you:
- Understand your data through comprehensive statistics
- Create schemas that define what "good data" looks like
- Detect anomalies that could impact model performance
- Compare datasets to identify drift and training-serving skew
- Ensure data quality across different environments
By integrating TFDV into your ML pipelines, you create more robust and reliable machine learning systems that can handle real-world data challenges.
Practice Exercises
- Download a public dataset and use TFDV to generate and visualize statistics.
- Create a schema for your dataset and intentionally introduce anomalies in a copy of the data. Use TFDV to detect these anomalies.
- Set up environment-specific feature constraints for a hypothetical ML system with training and serving environments.
- Compare two versions of a dataset to detect data drift.
- Integrate TFDV into a simple TFX pipeline using the TFX template.
Happy data validating!