TensorFlow Extended Orchestration
Introduction
TensorFlow Extended (TFX) is an end-to-end platform for deploying production ML pipelines. One of its most powerful aspects is orchestration, which lets you automate and manage complex machine learning workflows. In this tutorial, we'll explore how to orchestrate TFX pipelines to create reliable, reproducible, and production-ready machine learning systems.
Orchestration in TFX refers to the process of coordinating the execution of different components in your ML workflow. Think of it as the conductor of an orchestra, ensuring each section plays at the right time and in harmony with others.
TFX Orchestration Basics
What is Pipeline Orchestration?
Pipeline orchestration is the process of:
- Defining the sequence and dependencies between components
- Scheduling when components should run
- Handling resource allocation
- Managing the flow of data between components
- Providing monitoring and logging capabilities
TFX supports multiple orchestration platforms, including Apache Airflow, Kubeflow Pipelines, and Apache Beam.
Components of a TFX Pipeline
Before diving into orchestration, let's briefly review the key components that make up a TFX pipeline:
- ExampleGen: Ingests and splits the input dataset
- StatisticsGen: Computes statistics on the dataset
- SchemaGen: Infers a schema for the dataset
- ExampleValidator: Validates examples against the schema
- Transform: Performs feature engineering
- Trainer: Trains the model
- Evaluator: Evaluates the model
- Pusher: Deploys the model to a serving environment
Creating Your First TFX Pipeline
Let's start by creating a simple TFX pipeline. First, we need to install TFX:
pip install tfx
Now, let's define a simple pipeline:
from tfx import v1 as tfx
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Pipeline name and root directory for metadata and pipeline artifacts
PIPELINE_NAME = "my_first_pipeline"
PIPELINE_ROOT = "/tmp/tfx_pipeline_example"
METADATA_PATH = f"{PIPELINE_ROOT}/metadata.db"
DATA_PATH = "data/mnist"
# Define our components
example_gen = tfx.components.CsvExampleGen(input_base=DATA_PATH)

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

trainer = tfx.components.Trainer(
    module_file="trainer.py",
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=1000),
    eval_args=tfx.proto.EvalArgs(num_steps=500))
# Define the pipeline
p = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        trainer,
    ],
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        METADATA_PATH))

# Run the pipeline
LocalDagRunner().run(p)
This example creates a simple pipeline that processes CSV data, generates statistics, creates a schema, validates examples against the schema, and trains a model.
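The Trainer above points at a module file, trainer.py, which TFX loads and calls at training time; it must define a run_fn. Below is a minimal sketch modeled on the TFX Keras examples. The feature names, label key, and model architecture are assumptions for illustration and would need to match your actual CSV columns.

# trainer.py -- a minimal sketch of the module file referenced by the Trainer.
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx import v1 as tfx
from tfx_bsl.public import tfxio

_LABEL_KEY = 'label'                       # assumed label column
_FEATURE_KEYS = ['feature_1', 'feature_2'] # assumed numeric feature columns


def _input_fn(file_pattern, data_accessor, schema, batch_size):
    # Build a tf.data.Dataset of (features, label) batches from the examples artifact.
    return data_accessor.tf_dataset_factory(
        file_pattern,
        tfxio.TensorFlowDatasetOptions(batch_size=batch_size, label_key=_LABEL_KEY),
        schema=schema).repeat()


def _build_keras_model():
    # A tiny illustrative model; replace with an architecture suited to your data.
    inputs = [tf.keras.layers.Input(shape=(1,), name=key) for key in _FEATURE_KEYS]
    x = tf.keras.layers.concatenate(inputs)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def run_fn(fn_args: tfx.components.FnArgs):
    # TFX calls this with the train/eval file patterns, step counts, and the
    # directory where the serving model should be written.
    schema = tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema_pb2.Schema())
    train_ds = _input_fn(fn_args.train_files, fn_args.data_accessor, schema, batch_size=32)
    eval_ds = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema, batch_size=32)

    model = _build_keras_model()
    model.fit(train_ds,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_ds,
              validation_steps=fn_args.eval_steps)
    model.save(fn_args.serving_model_dir, save_format='tf')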
Local Orchestration vs. Production Orchestration
Local Orchestration
For development and testing, TFX provides a local orchestrator:
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
# Run pipeline locally
LocalDagRunner().run(pipeline)
The LocalDagRunner executes each component synchronously on your machine, in dependency order, and logs its progress to the console. This is great for prototyping and debugging, but it offers no scheduling, retries, or distributed execution, so it is not suitable for production environments.
Apache Airflow Orchestration
For production-grade orchestration, Apache Airflow is a popular choice:
import datetime

from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

# Airflow-specific settings are passed as a dict of standard Airflow DAG arguments
airflow_dag_config = {
    'schedule_interval': '0 0 * * *',  # Run daily at midnight
    'start_date': datetime.datetime(2023, 1, 1),
}

# Create an Airflow DAG from the TFX pipeline
AirflowDagRunner(AirflowPipelineConfig(airflow_dag_config)).run(pipeline)
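AirflowDagRunner does not execute the pipeline itself; run() builds an Airflow DAG that the Airflow scheduler discovers and triggers. In practice you place a small module like the sketch below in Airflow's dags/ folder; the create_pipeline helper and its module name are assumptions standing in for however you construct the pipeline object:

# ~/airflow/dags/tfx_pipeline_dag.py
import datetime

from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

from my_pipeline import create_pipeline  # hypothetical module that builds the TFX pipeline

_airflow_dag_config = {
    'schedule_interval': '0 0 * * *',  # daily at midnight
    'start_date': datetime.datetime(2023, 1, 1),
}

# A module-level DAG object is what the Airflow scheduler looks for
DAG = AirflowDagRunner(AirflowPipelineConfig(_airflow_dag_config)).run(create_pipeline())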
Kubeflow Pipelines Orchestration
For Kubernetes-based orchestration:
from tfx.orchestration.kubeflow.kubeflow_dag_runner import KubeflowDagRunner
from tfx.orchestration.kubeflow.kubeflow_dag_runner import KubeflowDagRunnerConfig

# Specify the TFX container image that each pipeline step will run in
kubeflow_config = KubeflowDagRunnerConfig(
    tfx_image="gcr.io/tfx-oss-public/tfx:1.12.0")

# Compile the pipeline into a Kubeflow Pipelines package
KubeflowDagRunner(config=kubeflow_config).run(pipeline)
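Note that KubeflowDagRunner.run() compiles the pipeline into a package file (named after the pipeline, e.g. my_first_pipeline.tar.gz) rather than executing it directly; you then upload that package to your Kubeflow Pipelines cluster. A minimal sketch using the kfp SDK, where the endpoint URL, package filename, and experiment name are placeholders:

import kfp

# Connect to the Kubeflow Pipelines endpoint (hypothetical URL)
client = kfp.Client(host="https://your-kfp-endpoint.example.com")

# Upload and launch the compiled TFX pipeline package
client.create_run_from_pipeline_package(
    pipeline_file="my_first_pipeline.tar.gz",
    arguments={},                # runtime parameters, if the pipeline defines any
    experiment_name="tfx-demo")  # hypothetical experiment name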
Pipeline Configuration Parameters
To make your pipelines more flexible, you can use pipeline parameters:
from tfx.orchestration import data_types

# Define a pipeline parameter that can be set when a run is launched
num_train_steps = data_types.RuntimeParameter(
    name='train_steps',
    default=1000,
    ptype=int)

trainer = tfx.components.Trainer(
    module_file="trainer.py",
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    # RuntimeParameters are passed in dict form because a concrete proto
    # cannot hold a placeholder value
    train_args={'num_steps': num_train_steps},
    eval_args=tfx.proto.EvalArgs(num_steps=500))
This allows you to specify the number of training steps at runtime rather than hardcoding it.
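On Kubeflow Pipelines, runtime parameters surface as pipeline parameters that you can set in the UI or pass when submitting a run. A sketch, reusing the hypothetical kfp client and compiled package from the Kubeflow section above:

# Override the 'train_steps' runtime parameter for this particular run
client.create_run_from_pipeline_package(
    pipeline_file="my_first_pipeline.tar.gz",
    arguments={'train_steps': 5000})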
Real-World Example: Continuous Training Pipeline
Let's look at a more complete example of a pipeline that continuously retrains a model as new data arrives:
import datetime
import os

import tensorflow_model_analysis as tfma
from tfx import v1 as tfx
from tfx.orchestration import pipeline
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

# Pipeline constants
PIPELINE_NAME = "continuous_training"
PIPELINE_ROOT = os.path.join("gs://your-gcs-bucket", PIPELINE_NAME)
DATA_ROOT = "gs://your-gcs-bucket/data"
MODULE_FILE = "gs://your-gcs-bucket/modules/trainer.py"
SERVING_MODEL_DIR = os.path.join(PIPELINE_ROOT, "serving_model")
# Define components
example_gen = tfx.extensions.google_cloud_big_query.BigQueryExampleGen(
    query="SELECT * FROM your_dataset.your_table WHERE date > '2023-01-01'")

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'])

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=MODULE_FILE)

trainer = tfx.components.Trainer(
    module_file=MODULE_FILE,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000))

# Fetch the latest blessed model to use as a baseline for evaluation
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(
        type=tfx.types.standard_artifacts.ModelBlessing)
).with_id('latest_blessed_model_resolver')

evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=tfma.EvalConfig(...))  # see the sketch after this example

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=SERVING_MODEL_DIR)))
# Define pipeline
tfx_pipeline = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ],
    enable_cache=True)

# Orchestrate with Airflow: DAG settings are passed as a dict of Airflow DAG arguments
airflow_dag_config = {
    'schedule_interval': '0 0 * * *',  # Daily at midnight
    'start_date': datetime.datetime(2023, 1, 1),
}

# Build the Airflow DAG for the pipeline
DAG = AirflowDagRunner(AirflowPipelineConfig(airflow_dag_config)).run(tfx_pipeline)
In this example:
- We use BigQuery as our data source
- We incorporate a transform step for feature engineering
- We use a model resolver to fetch the latest blessed model for comparison
- We evaluate the new model against the baseline using a TFMA evaluation config (sketched after this list)
- If the new model passes evaluation, we push it to serving
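The eval_config above was elided. For reference, here is a minimal sketch of what it could look like with TensorFlow Model Analysis; the label key, metric, and thresholds are assumptions chosen for illustration:

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label')],  # assumed label column
    slicing_specs=[tfma.SlicingSpec()],               # evaluate on the whole dataset
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(
                class_name='SparseCategoricalAccuracy',
                threshold=tfma.MetricThreshold(
                    # Bless the model only if accuracy clears an absolute floor...
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.6}),
                    # ...and does not regress relative to the baseline model.
                    change_threshold=tfma.GenericChangeThreshold(
                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                        absolute={'value': -1e-10}))),
        ]),
    ])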
Advanced Orchestration Features
Pipeline Caching
TFX can cache the outputs of pipeline components to avoid redundant computations:
# Enable caching in the pipeline
tfx_pipeline = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[...],
    enable_cache=True)  # This enables caching
Custom Component Orchestration
You can create custom components to fit your specific needs:
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import executor_spec
from tfx.types import channel_utils
from tfx.types import standard_artifacts


class CustomComponent(base_component.BaseComponent):
    """My custom TFX component.

    Assumes MyCustomComponentSpec (a ComponentSpec subclass) and MyCustomExecutor
    (a BaseExecutor subclass) are defined elsewhere; see the sketch below.
    """
    SPEC_CLASS = MyCustomComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(MyCustomExecutor)

    def __init__(self, input_data, parameter1=None):
        # Outputs must be wrapped in channels; here the component emits a Model artifact.
        output = channel_utils.as_channel([standard_artifacts.Model()])
        spec = self.SPEC_CLASS(
            input_data=input_data,
            parameter1=parameter1,
            output=output)
        super().__init__(spec=spec)
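For completeness, here is a hedged sketch of what the assumed MyCustomComponentSpec and MyCustomExecutor could look like; the input/output names, artifact types, and the no-op executor body are illustrative only:

from tfx.dsl.components.base import base_executor
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.component_spec import ComponentSpec
from tfx.types.component_spec import ExecutionParameter


class MyCustomComponentSpec(ComponentSpec):
    """Declares the inputs, outputs, and parameters of CustomComponent."""
    PARAMETERS = {
        'parameter1': ExecutionParameter(type=str, optional=True),
    }
    INPUTS = {
        'input_data': ChannelParameter(type=standard_artifacts.Examples),
    }
    OUTPUTS = {
        'output': ChannelParameter(type=standard_artifacts.Model),
    }


class MyCustomExecutor(base_executor.BaseExecutor):
    """Runs the component's business logic."""

    def Do(self, input_dict, output_dict, exec_properties):
        # Placeholder logic: a real executor would read artifacts from input_dict,
        # do its work, and write results to the URIs in output_dict.
        del input_dict, output_dict, exec_properties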
Conditionals and Dynamic Pipelines
TFX supports conditional execution of components in newer versions via tfx.dsl.Cond. The condition is a predicate built from another component's output artifacts rather than an arbitrary Python callable; a common pattern is to run the Pusher only when the Evaluator has blessed the model:

# Reusing the trainer and evaluator defined in the continuous training example above

# Only push the model if the Evaluator marked it as blessed
with tfx.dsl.Cond(
        evaluator.outputs['blessing'].future()[0].custom_property('blessed') == 1):
    pusher = tfx.components.Pusher(
        model=trainer.outputs['model'],
        push_destination=tfx.proto.PushDestination(
            filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=SERVING_MODEL_DIR)))
Common Orchestration Issues & Troubleshooting
Dependency Conflicts
When running in production, make sure all dependencies are installed on your orchestration platform:
# In your Dockerfile or requirements.txt
tfx==1.12.0
apache-airflow==2.5.1
Monitoring Pipeline Execution
Use TFX's metadata store to monitor your pipeline:
from tfx.orchestration import metadata

# Connect to the metadata store (reusing the SQLite config from the first pipeline)
metadata_connection_config = metadata.sqlite_metadata_connection_config(METADATA_PATH)

with metadata.Metadata(metadata_connection_config) as metadata_handler:
    # Query execution history from ML Metadata
    for execution in metadata_handler.store.get_executions():
        print(f"Execution ID: {execution.id}, State: {execution.last_known_state}")
Restarting Failed Pipelines
In production environments, you often need to handle failures gracefully:
import logging

try:
    runner.run(pipeline)
except Exception as e:
    # Log the error
    logging.error(f"Pipeline failed: {e}")
    # Clean up temporary resources (cleanup_resources is your own helper)
    cleanup_resources()
    # Optionally restart the pipeline
    if should_retry:
        runner.run(pipeline)
Summary
TFX orchestration provides a powerful framework for automating machine learning workflows, from development to production. By understanding the various orchestration options and how to configure them, you can build reliable, reproducible ML pipelines that scale with your needs.
In this tutorial, we've covered:
- The basics of TFX pipeline orchestration
- Creating and running pipelines with different orchestrators (local, Airflow, Kubeflow)
- Configuring pipelines with parameters
- Building a real-world continuous training pipeline
- Advanced features like caching and custom components
- Troubleshooting common issues
With these tools and techniques, you're well-equipped to orchestrate your own TFX pipelines for any ML use case.
Additional Resources
- Official TFX Documentation
- TFX GitHub Repository
- Apache Airflow Documentation
- Kubeflow Pipelines Documentation
Exercises
- Create a simple TFX pipeline that processes a CSV dataset and trains a basic model.
- Modify your pipeline to add a Transform component for feature engineering.
- Convert your local pipeline to run on Apache Airflow or Kubeflow Pipelines.
- Implement a custom component that sends notifications when training completes.
- Build a pipeline that implements A/B testing between two model architectures.
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)