TensorFlow Extended Orchestration

Introduction

TensorFlow Extended (TFX) is an end-to-end platform for deploying production ML pipelines. One of the most powerful aspects of TFX is its orchestration capabilities, which allow you to automate and manage complex machine learning workflows. In this tutorial, we'll explore how to orchestrate TFX pipelines to create reliable, reproducible, and production-ready machine learning systems.

Orchestration in TFX refers to the process of coordinating the execution of different components in your ML workflow. Think of it as the conductor of an orchestra, ensuring each section plays at the right time and in harmony with others.

TFX Orchestration Basics

What is Pipeline Orchestration?

Pipeline orchestration is the process of:

  1. Defining the sequence and dependencies between components
  2. Scheduling when components should run
  3. Handling resource allocation
  4. Managing the flow of data between components
  5. Providing monitoring and logging capabilities

TFX supports multiple orchestration platforms, including Apache Airflow, Kubeflow Pipelines, and Apache Beam.

Components of a TFX Pipeline

Before diving into orchestration, let's briefly review the key components that make up a TFX pipeline:

  • ExampleGen: Ingests and splits the input dataset
  • StatisticsGen: Computes statistics on the dataset
  • SchemaGen: Infers a schema for the dataset
  • ExampleValidator: Validates examples against the schema
  • Transform: Performs feature engineering
  • Trainer: Trains the model
  • Evaluator: Evaluates the model
  • Pusher: Deploys the model to a serving environment
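
To make the dependencies among these components concrete, here is a minimal sketch in plain Python (no TFX import) of how an orchestrator could derive a run order from the artifacts each component consumes and produces. The `WIRING` table and both helper functions are illustrative assumptions rather than TFX APIs, and the artifact names are simplified.

```python
from graphlib import TopologicalSorter

# Hypothetical wiring table: component -> (artifacts consumed, artifacts produced).
WIRING = {
    "ExampleGen": ((), ("examples",)),
    "StatisticsGen": (("examples",), ("statistics",)),
    "SchemaGen": (("statistics",), ("schema",)),
    "ExampleValidator": (("statistics", "schema"), ("anomalies",)),
    "Transform": (("examples", "schema"), ("transformed_examples", "transform_graph")),
    "Trainer": (("transformed_examples", "transform_graph", "schema"), ("model",)),
    "Evaluator": (("examples", "model"), ("blessing",)),
    "Pusher": (("model", "blessing"), ("pushed_model",)),
}


def upstream_components(wiring):
    """Map each component to the set of components that produce its inputs."""
    producer = {art: name for name, (_, outs) in wiring.items() for art in outs}
    return {name: {producer[art] for art in ins} for name, (ins, _) in wiring.items()}


def execution_order(wiring):
    """Return one valid order in which every component runs after its upstreams."""
    return list(TopologicalSorter(upstream_components(wiring)).static_order())


print(execution_order(WIRING))
```

Several orders can be valid at once (for example, ExampleValidator and Transform do not depend on each other), which is what lets an orchestrator run independent components in parallel.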

Creating Your First TFX Pipeline

Let's start by creating a simple TFX pipeline. First, we need to install TFX:

bash
pip install tfx

Now, let's define a simple pipeline:

python
import tfx.v1 as tfx
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Pipeline name and root directory for metadata and pipelines
PIPELINE_NAME = "my_first_pipeline"
PIPELINE_ROOT = "/tmp/tfx_pipeline_example"
METADATA_PATH = f"{PIPELINE_ROOT}/metadata.db"
DATA_PATH = "data/mnist"

# Define our components
example_gen = tfx.components.CsvExampleGen(input_base=DATA_PATH)
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
trainer = tfx.components.Trainer(
    module_file="trainer.py",
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=1000),
    eval_args=tfx.proto.EvalArgs(num_steps=500))

# Define the pipeline
p = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        trainer,
    ],
    metadata_connection_config=(
        tfx.orchestration.metadata.sqlite_metadata_connection_config(METADATA_PATH)),
)

# Run the pipeline
LocalDagRunner().run(p)

This example creates a simple pipeline that processes CSV data, generates statistics, creates a schema, validates examples against the schema, and trains a model.

Local Orchestration vs. Production Orchestration

Local Orchestration

For development and testing, TFX provides a local orchestrator:

python
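from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Run the pipeline object `p` defined above in the current process
LocalDagRunner().run(p)

The local runner executes the components one at a time, in dependency order, and logs its progress as it goes; the exact log lines vary by TFX version.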

This is great for prototyping, but not suitable for production environments.

Apache Airflow Orchestration

For production-grade orchestration, Apache Airflow is a popular choice:

python
import datetime

from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

# Airflow-specific settings, passed through to the underlying Airflow DAG
airflow_config = {
    'schedule_interval': '0 0 * * *',  # Run daily at midnight
    'start_date': datetime.datetime(2023, 1, 1),
}

# Build the DAG; Airflow discovers it through the module-level DAG variable
DAG = AirflowDagRunner(AirflowPipelineConfig(airflow_config)).run(p)
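
The schedule string `'0 0 * * *'` is a five-field cron expression. As a rough illustration of what it means, the sketch below names the five fields and computes the next trigger time for this particular daily-at-midnight schedule. Both helper functions are hypothetical and use only the standard library; this is not a general cron parser and not how Airflow evaluates schedules internally.

```python
import datetime

CRON_FIELDS = ("minute", "hour", "day_of_month", "month", "day_of_week")


def describe_cron(expression):
    """Split a five-field cron expression into named fields."""
    parts = expression.split()
    if len(parts) != len(CRON_FIELDS):
        raise ValueError(f"expected 5 fields, got {len(parts)}")
    return dict(zip(CRON_FIELDS, parts))


def next_daily_midnight(now):
    """Next trigger for '0 0 * * *': the first midnight strictly after `now`."""
    tomorrow = now.date() + datetime.timedelta(days=1)
    return datetime.datetime.combine(tomorrow, datetime.time.min)


print(describe_cron("0 0 * * *"))
print(next_daily_midnight(datetime.datetime(2023, 1, 1, 15, 30)))  # 2023-01-02 00:00:00
```

Reading the fields left to right, minute 0 of hour 0 on every day of every month gives one run per day at midnight.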

Kubeflow Pipelines Orchestration

For Kubernetes-based orchestration:

python
from tfx.orchestration.kubeflow import kubeflow_dag_runner

# Kubeflow-specific configuration: where ML Metadata lives and which
# container image runs each component
runner_config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
    kubeflow_metadata_config=(
        kubeflow_dag_runner.get_default_kubeflow_metadata_config()),
    tfx_image="tensorflow/tfx:1.12.0",  # assumed tag; match your TFX version
)

# Compile the pipeline object `p` into a Kubeflow Pipelines package
kubeflow_dag_runner.KubeflowDagRunner(config=runner_config).run(p)

Pipeline Configuration Parameters

To make your pipelines more flexible, you can use pipeline parameters:

python
from tfx.orchestration import data_types

# Define pipeline parameters
num_train_steps = data_types.RuntimeParameter(
    name='train_steps',
    default=1000,
    ptype=int,
)

trainer = tfx.components.Trainer(
    module_file="trainer.py",
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    # Pass a dict here: a TrainArgs proto field cannot hold a RuntimeParameter
    train_args={'num_steps': num_train_steps},
    eval_args=tfx.proto.EvalArgs(num_steps=500),
)

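This allows you to specify the number of training steps at runtime rather than hardcoding it. Runtime parameters are resolved by the orchestrator when a run starts, so they only take effect on runners that support them, such as Kubeflow Pipelines.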
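
The idea behind a runtime parameter can be sketched without TFX: the pipeline definition holds a named placeholder with a type and a default, and the value is filled in only when a run starts. The `Param` class and `resolve` function below are hypothetical stand-ins for illustration, not the TFX implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a runtime parameter: name, type, and default."""
    name: str
    ptype: type
    default: object


def resolve(param, overrides):
    """Return the run-time value for `param`, falling back to its default."""
    value = overrides.get(param.name, param.default)
    if not isinstance(value, param.ptype):
        raise TypeError(
            f"{param.name} must be {param.ptype.__name__}, got {type(value).__name__}")
    return value


train_steps = Param(name="train_steps", ptype=int, default=1000)
print(resolve(train_steps, {}))                     # 1000
print(resolve(train_steps, {"train_steps": 5000}))  # 5000
```

Keeping the type alongside the name lets a bad override fail at launch time rather than deep inside training.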

Real-World Example: Continuous Training Pipeline

Let's look at a more complete example of a pipeline that continuously retrains a model as new data arrives:

python
import datetime
import os

import tensorflow_model_analysis as tfma
import tfx.v1 as tfx
from tfx.orchestration import pipeline
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig

# Pipeline constants
PIPELINE_NAME = "continuous_training"
PIPELINE_ROOT = os.path.join("gs://your-gcs-bucket", PIPELINE_NAME)
DATA_ROOT = "gs://your-gcs-bucket/data"
MODULE_FILE = "gs://your-gcs-bucket/modules/trainer.py"
SERVING_MODEL_DIR = os.path.join(PIPELINE_ROOT, "serving_model")

# Define components
# BigQueryExampleGen ships with the Google Cloud BigQuery extension
example_gen = tfx.extensions.google_cloud_big_query.BigQueryExampleGen(
    query="SELECT * FROM your_dataset.your_table WHERE date > '2023-01-01'"
)

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples']
)

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics']
)

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)

transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=MODULE_FILE
)

trainer = tfx.components.Trainer(
    module_file=MODULE_FILE,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000)
)

# Resolve the most recent model that passed evaluation, to use as a baseline
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing)
).with_id('latest_blessed_model_resolver')

evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    # EvalConfig comes from TensorFlow Model Analysis, not tfx.proto
    eval_config=tfma.EvalConfig(...)
)

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=SERVING_MODEL_DIR
        )
    )
)

# Define pipeline
tfx_pipeline = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher
    ],
    enable_cache=True
)

# Orchestrate with Airflow: the config is a dict of Airflow DAG arguments
airflow_config = {
    'schedule_interval': '0 0 * * *',  # Daily at midnight
    'start_date': datetime.datetime(2023, 1, 1),
}

# Build the DAG; Airflow discovers it through the module-level DAG variable
DAG = AirflowDagRunner(AirflowPipelineConfig(airflow_config)).run(tfx_pipeline)

In this example:

  1. We use BigQuery as our data source
  2. We incorporate a transform step for feature engineering
  3. We use a model resolver to fetch the latest blessed model for comparison
  4. We evaluate the new model against the baseline
  5. If the new model passes evaluation, we push it to serving
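
Steps 4 and 5 amount to a gate in front of the Pusher. The sketch below shows that decision logic in plain Python under simplifying assumptions: a single accuracy metric, an absolute floor, and a minimum improvement over the baseline. The function name and threshold values are hypothetical; in a real pipeline the thresholds live in the evaluation config and the Evaluator produces the blessing.

```python
def is_blessed(candidate_accuracy, baseline_accuracy=None,
               min_accuracy=0.8, min_improvement=0.0):
    """Decide whether a candidate model may be pushed to serving.

    The candidate must clear an absolute floor and, when a baseline exists,
    must not fall short of it by more than the allowed margin.
    """
    if candidate_accuracy < min_accuracy:
        return False
    if baseline_accuracy is None:  # first run: nothing to compare against
        return True
    return candidate_accuracy - baseline_accuracy >= min_improvement


print(is_blessed(0.85, baseline_accuracy=0.82))  # True
print(is_blessed(0.81, baseline_accuracy=0.85))  # False: worse than baseline
print(is_blessed(0.70))                          # False: below the floor
```

Treating a missing baseline as a pass is what lets the very first run of a continuous training pipeline deploy a model at all.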

Advanced Orchestration Features

Pipeline Caching

TFX can cache the outputs of pipeline components to avoid redundant computations:

python
# Enable caching in the pipeline
tfx_pipeline = pipeline.Pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    components=[...],
    enable_cache=True  # This enables caching
)
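
As a conceptual model of why caching avoids redundant work, the sketch below keys each component run on its identity, its input artifacts, and its parameters, and reuses the stored output when the same key appears again. The `cache_key` function and `OutputCache` class are hypothetical and kept in memory for illustration; TFX's actual caching is backed by the metadata store and differs in detail.

```python
import hashlib
import json


def cache_key(component_id, input_artifact_ids, parameters):
    """Fingerprint one component run from its identity, inputs, and parameters."""
    payload = json.dumps(
        {"component": component_id,
         "inputs": sorted(input_artifact_ids),
         "params": parameters},
        sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


class OutputCache:
    """In-memory cache of component outputs, keyed by run fingerprint."""

    def __init__(self):
        self._store = {}
        self.hits = 0

    def run(self, component_id, input_artifact_ids, parameters, compute):
        """Return the cached output if present; otherwise compute and store it."""
        key = cache_key(component_id, input_artifact_ids, parameters)
        if key in self._store:
            self.hits += 1
            return self._store[key]
        result = compute()
        self._store[key] = result
        return result
```

Any change to an input artifact or a parameter yields a different key, so the component reruns; everything downstream of an unchanged component can be skipped.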

Custom Component Orchestration

You can create custom components to fit your specific needs:

python
from tfx import types
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import executor_spec
from tfx.types import standard_artifacts

# MyCustomComponentSpec and MyCustomExecutor are placeholders for classes
# you define yourself (a ComponentSpec subclass and an executor class)


class CustomComponent(base_component.BaseComponent):
    """My custom TFX component."""

    SPEC_CLASS = MyCustomComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(MyCustomExecutor)

    def __init__(self, input_data, parameter1=None):
        spec = self.SPEC_CLASS(
            input_data=input_data,
            parameter1=parameter1,
            # Outputs are channels of artifacts, not bare artifact instances
            output=types.Channel(type=standard_artifacts.Model),
        )
        super().__init__(spec=spec)

Conditionals and Dynamic Pipelines

TFX supports conditional execution in pipelines (available in newer versions). Components created inside a conditional block run only when a predicate built from an upstream component's output evaluates to true:

python
from tfx.dsl.experimental.conditionals import conditional

# Create the trainer and evaluator as before, then gate the Pusher on the
# Evaluator's blessing so that it runs only for blessed models
with conditional.Cond(
        evaluator.outputs['blessing'].future()[0].custom_property('blessed') == 1):
    pusher = tfx.components.Pusher(
        model=trainer.outputs['model'],
        push_destination=tfx.proto.PushDestination(
            filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=SERVING_MODEL_DIR)))

Common Orchestration Issues & Troubleshooting

Dependency Conflicts

When running in production, make sure all dependencies are installed on your orchestration platform:

text
# In your Dockerfile or requirements.txt
tfx==1.12.0
apache-airflow==2.5.1
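
One way to catch a mismatch early is to compare the pinned versions against what is actually installed on the orchestration host before any pipeline runs. The sketch below does this with the standard library's `importlib.metadata`; the helper names `parse_pins` and `find_mismatches` are hypothetical, and the parser only handles exact `name==version` pins.

```python
from importlib import metadata


def parse_pins(lines):
    """Parse 'name==version' lines, skipping blank lines and comments."""
    pins = {}
    for line in lines:
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        name, _, version = line.partition("==")
        if not version:
            raise ValueError(f"not an exact pin: {line!r}")
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins):
    """Return {package: (expected, installed or None)} for each unsatisfied pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


pins = parse_pins(["# requirements.txt", "tfx==1.12.0", "apache-airflow==2.5.1"])
print(find_mismatches(pins))
```

An empty dictionary means every pin is satisfied; running the check as the first step of a deployment makes version drift fail fast instead of surfacing mid-pipeline.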

Monitoring Pipeline Execution

Use TFX's metadata store to monitor your pipeline:

python
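from tfx.orchestration import metadata

# Connect to the same metadata store the pipeline writes to
# (METADATA_PATH is the SQLite file configured for the pipeline)
connection_config = metadata.sqlite_metadata_connection_config(METADATA_PATH)

# Metadata is a context manager; the store is only available inside the block
with metadata.Metadata(connection_config) as metadata_handler:
    # Query execution history
    for execution in metadata_handler.store.get_executions():
        print(f"Execution ID: {execution.id}, State: {execution.last_known_state}")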
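
Once execution records are available, a short summary is often more useful than a raw listing. The sketch below works on hypothetical `(execution_id, state_name)` tuples rather than on real metadata store objects, so it runs without TFX; the `summarize` helper is an assumption made for illustration.

```python
from collections import Counter


def summarize(executions):
    """Count executions per state and collect the IDs of failed ones."""
    counts = Counter(state for _, state in executions)
    failed_ids = [execution_id for execution_id, state in executions
                  if state == "FAILED"]
    return counts, failed_ids


# Hypothetical records standing in for rows read from the metadata store
records = [(1, "COMPLETE"), (2, "COMPLETE"), (3, "FAILED"), (4, "CACHED")]
counts, failed_ids = summarize(records)
print(dict(counts))  # {'COMPLETE': 2, 'FAILED': 1, 'CACHED': 1}
print(failed_ids)    # [3]
```

The failed IDs are the natural starting point for digging into logs or deciding which runs to retry.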

Restarting Failed Pipelines

In production environments, you often need to handle failures gracefully:

python
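import logging

# `runner`, `tfx_pipeline`, `cleanup_resources`, and `should_retry` are
# placeholders for your own runner instance, pipeline object, and helpers
try:
    runner.run(tfx_pipeline)
except Exception as e:
    # Log the error
    logging.error(f"Pipeline failed: {e}")

    # Clean up temporary resources
    cleanup_resources()

    # Optionally restart the pipeline
    if should_retry:
        runner.run(tfx_pipeline)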
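
A single unconditional retry can hammer a dependency that is still recovering. A common refinement is a bounded number of attempts with exponential backoff between them, sketched below with the standard library only. The `run_with_retries` helper and its parameters are hypothetical; the `sleep` argument exists so the wait can be replaced in tests.

```python
import logging
import time


def run_with_retries(run, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call `run` until it succeeds, waiting base_delay * 2**n between attempts.

    Re-raises the last exception once `max_attempts` have failed.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return run()
        except Exception as error:
            logging.error("Attempt %d/%d failed: %s", attempt, max_attempts, error)
            if attempt == max_attempts:
                raise
            sleep(base_delay * 2 ** (attempt - 1))
```

With a runner and pipeline object in hand, this could be called as `run_with_retries(lambda: runner.run(tfx_pipeline))`. Retrying is only safe when reruns are idempotent, which is one reason to combine it with pipeline caching.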

Summary

TFX orchestration provides a powerful framework for automating machine learning workflows, from development to production. By understanding the various orchestration options and how to configure them, you can build reliable, reproducible ML pipelines that scale with your needs.

In this tutorial, we've covered:

  • The basics of TFX pipeline orchestration
  • Creating and running pipelines with different orchestrators (local, Airflow, Kubeflow)
  • Configuring pipelines with parameters
  • Building a real-world continuous training pipeline
  • Advanced features like caching and custom components
  • Troubleshooting common issues

With these tools and techniques, you're well-equipped to orchestrate your own TFX pipelines for any ML use case.

Exercises

  1. Create a simple TFX pipeline that processes a CSV dataset and trains a basic model.
  2. Modify your pipeline to add a Transform component for feature engineering.
  3. Convert your local pipeline to run on Apache Airflow or Kubeflow Pipelines.
  4. Implement a custom component that sends notifications when training completes.
  5. Build a pipeline that implements A/B testing between two model architectures.

