An Introduction to TensorFlow Extended (TFX)

What is TensorFlow Extended (TFX)?

TensorFlow Extended (TFX) is an end-to-end platform for deploying production machine learning pipelines. When you're ready to move beyond simple model training and experimentation to building production ML systems, TFX provides the components you need for a complete ML pipeline.

Unlike core TensorFlow, which focuses on model development, TFX addresses the entire machine learning lifecycle: from data ingestion and preprocessing to model training, evaluation, deployment, and serving.

Why Use TFX?

When moving from experimental machine learning to production systems, developers face numerous challenges:

  • Data validation: Ensuring consistent, high-quality data
  • Feature engineering: Transforming raw data at scale
  • Model versioning: Tracking model lineage and metadata
  • Model deployment: Serving models reliably at scale
  • Pipeline orchestration: Automating the entire workflow

TFX solves these problems by providing a suite of components that work together to create production-ready ML systems.

Key Components of TFX

TFX provides several core components that handle different aspects of the ML pipeline:

  1. ExampleGen: Ingests and splits the input dataset
  2. StatisticsGen: Calculates statistics for the dataset
  3. SchemaGen: Examines the statistics and creates a data schema
  4. ExampleValidator: Looks for anomalies and missing values
  5. Transform: Performs feature engineering on the dataset
  6. Trainer: Trains the model using TensorFlow
  7. Evaluator: Performs deep analysis of training results
  8. InfraValidator: Validates model serving infrastructure
  9. Pusher: Deploys the model to a serving infrastructure
  10. BulkInferrer: Performs batch inference on unlabeled examples

Your First TFX Pipeline

Let's create a simple TFX pipeline to get started. First, we need to install TFX:

bash
pip install tfx
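
You can confirm the installation and check which version was installed:

bash
python -c "import tfx; print(tfx.__version__)"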

Here's an example of a basic TFX pipeline that trains a model on the Chicago Taxi dataset:

python
from tfx.components import CsvExampleGen, StatisticsGen, SchemaGen
from tfx.orchestration import metadata, pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Pipeline name and root directories for data, artifacts, and metadata
PIPELINE_NAME = "chicago_taxi_simple"
PIPELINE_ROOT = "/tmp/pipelines/chicago_taxi_simple"
DATA_ROOT = "/tmp/data/chicago_taxi_simple"
METADATA_PATH = "/tmp/metadata/chicago_taxi_simple/metadata.db"

# Define the pipeline
def _create_pipeline():
    # Ingest the CSV files under DATA_ROOT as tf.Example records
    example_gen = CsvExampleGen(input_base=DATA_ROOT)

    # Compute descriptive statistics over the ingested examples
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Infer a schema (feature types, ranges, expected values) from the statistics
    schema_gen = SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=True)

    return pipeline.Pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        components=[example_gen, statistics_gen, schema_gen],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            METADATA_PATH))

# Run the pipeline with the local orchestrator
if __name__ == '__main__':
    LocalDagRunner().run(_create_pipeline())

This simple pipeline has three components:

  1. CsvExampleGen - reads the CSV data
  2. StatisticsGen - calculates statistics about our dataset
  3. SchemaGen - infers a schema based on the statistics
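
After the run finishes, each component's output artifacts are written under PIPELINE_ROOT, and the run itself is recorded in the SQLite database at METADATA_PATH.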

Understanding TFX Pipeline Execution

When you run a TFX pipeline, each component:

  1. Takes inputs (either from external sources or outputs from other components)
  2. Performs its task (e.g., generating statistics or training a model)
  3. Produces outputs for the next components in the pipeline

The pipeline execution is tracked by TFX's metadata store, which records information about each run, including inputs, outputs, and parameters.
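
You can query this store directly with the ML Metadata (MLMD) library to see what a run recorded. Here is a minimal sketch, assuming the SQLite metadata path from the simple pipeline above:

python
from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

# Connect to the same SQLite database the pipeline wrote to
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = (
    "/tmp/metadata/chicago_taxi_simple/metadata.db")
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
store = metadata_store.MetadataStore(connection_config)

# Artifact types registered by the pipeline (Examples, ExampleStatistics, Schema, ...)
for artifact_type in store.get_artifact_types():
    print(artifact_type.name)

# Executions recorded for each component run
for execution in store.get_executions():
    print(execution.id, execution.type_id)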

Practical Example: Chicago Taxi Trip Duration Prediction

Let's look at a more complete example using the Chicago Taxi dataset to predict trip durations.

Step 1: Pipeline Setup and Data Ingestion

python
import os
import tensorflow_model_analysis as tfma
from tfx.components import (CsvExampleGen, StatisticsGen, SchemaGen,
                            ExampleValidator, Transform, Trainer, Evaluator,
                            Pusher)
from tfx.proto import example_gen_pb2, pusher_pb2, trainer_pb2
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Define pipeline constants
PIPELINE_NAME = "chicago_taxi_pipeline"
PIPELINE_ROOT = os.path.join("pipelines", PIPELINE_NAME)
DATA_ROOT = os.path.join("data", "chicago_taxi")
SERVING_MODEL_DIR = os.path.join("serving_model", PIPELINE_NAME)
MODULE_FILE = os.path.join("modules", "taxi_utils.py")

# Define pipeline components
def create_pipeline():
    # Data ingestion with an explicit 80/20 train/eval split
    # (8 of 10 hash buckets go to 'train', 2 to 'eval')
    output = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=8),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2)
        ]))
    example_gen = CsvExampleGen(input_base=DATA_ROOT, output_config=output)

Step 2: Data Validation

python
    # Statistics generation for understanding the data
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Schema generation based on statistics
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

    # Validate examples against the schema, flagging anomalies and missing values
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
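
The same checks are available interactively through TensorFlow Data Validation (TFDV), which is useful for exploring a dataset before wiring it into a pipeline. A minimal sketch, assuming a local CSV file (the path here is illustrative):

python
import tensorflow_data_validation as tfdv

# Compute statistics, infer a schema, and check for anomalies in one pass
stats = tfdv.generate_statistics_from_csv(
    data_location='data/chicago_taxi/data.csv')
schema = tfdv.infer_schema(statistics=stats)
anomalies = tfdv.validate_statistics(statistics=stats, schema=schema)

# In a notebook these render as interactive visualizations
tfdv.display_schema(schema)
tfdv.display_anomalies(anomalies)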

Step 3: Feature Engineering

python
    # Feature engineering via the preprocessing_fn defined in MODULE_FILE
    transform = Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_gen.outputs['schema'],
        module_file=MODULE_FILE)
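
The feature engineering itself lives in the preprocessing_fn defined in MODULE_FILE; the taxi_utils.py module shown later in this tutorial is an example of what that file contains.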

Step 4: Model Training and Evaluation

python
    # Train the model using the training code in MODULE_FILE
    trainer = Trainer(
        module_file=MODULE_FILE,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    # Evaluate model performance, overall and sliced by trip start hour
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='trip_duration_minutes')],
        slicing_specs=[
            tfma.SlicingSpec(),  # overall metrics
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        eval_config=eval_config)

Step 5: Model Deployment

python
    # Push the trained model to a filesystem destination for serving
    pusher = Pusher(
        model=trainer.outputs['model'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=SERVING_MODEL_DIR)))

    # Define and return the pipeline
    return pipeline.Pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        components=[
            example_gen, statistics_gen, schema_gen, example_validator,
            transform, trainer, evaluator, pusher
        ],
        enable_cache=True)

# Run the pipeline
if __name__ == '__main__':
    LocalDagRunner().run(create_pipeline())
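
In practice, you would usually gate deployment on the Evaluator's verdict: the Evaluator emits a "blessing" artifact, and a Pusher wired to it only exports models that pass evaluation. A variant of the Pusher above:

python
    # Variant: only push models the Evaluator has blessed
    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=SERVING_MODEL_DIR)))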

The taxi_utils.py Module

This module would contain the preprocessing and modeling code:

python
# taxi_utils.py
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils

# Define feature columns
NUMERIC_FEATURES = ['trip_miles', 'fare', 'trip_seconds']
CATEGORICAL_FEATURES = ['pickup_community_area', 'dropoff_community_area',
                        'trip_start_hour', 'trip_start_day', 'trip_start_month']
LABEL_KEY = 'trip_duration_minutes'

# Illustrative (feature, vocabulary size) pairs; in a real module these sizes
# would be derived from the vocabularies computed by Transform.
CATEGORICAL_FEATURES_WITH_VOCAB_SIZE = [(f, 100) for f in CATEGORICAL_FEATURES]

def preprocessing_fn(inputs):
    """Preprocessing function for feature transformation."""
    outputs = {}

    # Scale numeric features to z-scores (zero mean, unit variance)
    for feature in NUMERIC_FEATURES:
        outputs[feature] = tft.scale_to_z_score(inputs[feature])

    # Generate a vocabulary and map each categorical feature to integer ids
    for feature in CATEGORICAL_FEATURES:
        outputs[feature] = tft.compute_and_apply_vocabulary(
            inputs[feature], vocab_filename=feature)

    # Pass through the label unchanged
    outputs[LABEL_KEY] = inputs[LABEL_KEY]

    return outputs

def _build_estimator(config, hidden_units=None):
    """Build a wide-and-deep estimator for predicting taxi trip duration."""
    numeric_columns = [tf.feature_column.numeric_column(feature)
                       for feature in NUMERIC_FEATURES]

    categorical_columns = [
        tf.feature_column.categorical_column_with_identity(
            feature, num_buckets=vocab_size, default_value=0)
        for feature, vocab_size in CATEGORICAL_FEATURES_WITH_VOCAB_SIZE
    ]

    return tf.estimator.DNNLinearCombinedRegressor(
        config=config,
        linear_feature_columns=categorical_columns,
        dnn_feature_columns=numeric_columns,
        dnn_hidden_units=hidden_units or [100, 50])

# Entry point for the estimator-based Trainer component
def trainer_fn(trainer_fn_args, schema):
    """Build the estimator using the high-level API."""
    # Training will need the Transform outputs
    # (contents not shown for brevity)
    # ...
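
Note that trainer_fn belongs to the older, estimator-based Trainer. Recent TFX versions use a generic Trainer that looks for a run_fn entry point written with native Keras. A minimal sketch of that shape (the one-layer model is a placeholder, not a real taxi model):

python
import tensorflow as tf
from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs):
    """Entry point for the generic (Keras-based) Trainer."""
    # A trivial placeholder model; real code would build tf.data input
    # pipelines from fn_args.train_files and the Transform graph.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer='adam', loss='mse')

    # The Trainer expects the SavedModel at fn_args.serving_model_dir
    model.save(fn_args.serving_model_dir, save_format='tf')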

Orchestrating TFX Pipelines

TFX can run on various orchestrators:

  1. Local: For development and testing
  2. Apache Airflow: For production scheduling
  3. Kubeflow Pipelines: For Kubernetes-based execution
  4. Apache Beam: For distributed processing

For example, to run on Apache Airflow:

python
import datetime

from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

# Airflow DAG settings: run daily, starting January 1, 2020
airflow_config = AirflowPipelineConfig(airflow_dag_config={
    'schedule_interval': '@daily',
    'start_date': datetime.datetime(2020, 1, 1),
})

# Assign at module level so Airflow discovers the DAG
DAG = AirflowDagRunner(airflow_config).run(create_pipeline())
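
For this to work, the pipeline file must live in Airflow's DAGs folder: the runner returns an Airflow DAG object, which Airflow discovers via the module-level assignment. Switching orchestrators mostly means switching runners; for example, running the same pipeline on Apache Beam:

python
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(create_pipeline())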

TFX in Production

In production environments, TFX provides several key advantages:

  1. Reproducibility: Each run is tracked with full lineage and data provenance
  2. Scalability: Components designed to handle large datasets
  3. Monitoring: Continuous validation of data and models
  4. Integration: Works with TensorFlow Serving, ML Metadata, and other ML tools

Key TFX Concepts to Remember

  • Component: A single operation in your pipeline (e.g., ingesting data or training a model)
  • Artifact: Data passed between components (e.g., examples, statistics, models)
  • Pipeline: A directed acyclic graph of components working together
  • Metadata Store: A database that tracks artifacts and executions
  • Orchestrator: The system that schedules and executes pipeline components

Summary

TensorFlow Extended (TFX) is a powerful platform for building production ML pipelines. It provides components for every stage of the ML lifecycle, from data ingestion to model deployment. With TFX, you can build robust, scalable, and reproducible machine learning systems.

Key benefits of TFX include:

  • End-to-end ML pipeline capabilities
  • Production-ready components that work at scale
  • Data validation and monitoring
  • Model versioning and metadata tracking
  • Integration with popular orchestration platforms

Exercises

  1. Install TFX and run the simple pipeline example provided in this tutorial.
  2. Modify the pipeline to use a different dataset (like the Iris dataset).
  3. Add a new component to the pipeline (like a Transform component for feature engineering).
  4. Try orchestrating the pipeline with a different runner (like Apache Beam).
  5. Add model analysis to explore how well your model performs across different data slices.

