
TensorFlow Extended Deployment

Introduction

TensorFlow Extended (TFX) is Google's end-to-end platform for deploying production machine learning pipelines. One of the most crucial steps in the machine learning lifecycle is deployment - making your trained model available for use in real-world applications. In this guide, we'll explore how TFX helps you deploy models reliably and efficiently in production environments.

Model deployment brings its own set of challenges, including scaling to handle varying loads, ensuring consistent performance, versioning, and monitoring. TFX provides several components specifically designed to address these deployment challenges.

Key TFX Deployment Components

1. TensorFlow Serving

TensorFlow Serving is the primary deployment solution within TFX. It's designed to serve machine learning models in production environments with high performance and scalability.

Features of TensorFlow Serving:

  • Model versioning: Manage multiple versions of your model simultaneously
  • High performance: Optimized for CPU and GPU environments
  • Model updates: Update models without downtime
  • Standardized API: REST and gRPC interfaces
  • Batching: Process multiple requests efficiently

Let's look at a basic example of how to export a model for TensorFlow Serving:

python
import tensorflow as tf
import tempfile

# Create a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse')

# Train with dummy data
x = tf.random.normal((100, 5))
y = tf.random.normal((100, 1))
model.fit(x, y, epochs=2)

# Export the model
MODEL_DIR = tempfile.gettempdir()
version = 1
export_path = f"{MODEL_DIR}/{version}"

tf.keras.models.save_model(
    model,
    export_path,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None
)

print(f"Model saved to: {export_path}")

2. Pusher Component

The Pusher component in TFX is responsible for deploying trained models to a serving infrastructure. It's the final stage in the TFX pipeline that takes a validated model and "pushes" it to your deployment target.

Here's an example of how to configure a Pusher component:

python
from tfx.components import Pusher
from tfx.proto import pusher_pb2

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='/serving_model/taxi_simple'
        )
    )
)

In this code:

  • We reference the model from the training component
  • We include a blessing from the evaluator component (deployment only happens if the model passes quality checks)
  • We define where to push the model (in this case, to a filesystem location)

Deployment Options with TFX

TFX provides flexibility in where and how you deploy your models. Here are the common deployment targets:

1. Local Filesystem Deployment

This is the simplest deployment method, where the model is exported to a directory on the local filesystem. TensorFlow Serving expects each numbered subdirectory under the model's base directory (e.g. 1/, 2/) to contain one SavedModel version, and by default it serves the newest one it finds.

python
push_destination = pusher_pb2.PushDestination(
    filesystem=pusher_pb2.PushDestination.Filesystem(
        base_directory='/path/to/serving/models/dir'
    )
)

2. Docker Container Deployment

For containerized deployments, you can use the TensorFlow Serving Docker container:

bash
docker pull tensorflow/serving

docker run -p 8501:8501 \
--mount type=bind,source=/path/to/my_model,target=/models/my_model \
-e MODEL_NAME=my_model -t tensorflow/serving
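
Once the container is up, you can confirm the model loaded correctly through TensorFlow Serving's model status REST endpoint. A small Python sketch, using the port and model name from the docker run command above:

python
import requests

# Query TensorFlow Serving's model status endpoint
status = requests.get("http://localhost:8501/v1/models/my_model")
print(status.json())  # reports the loaded version and its state, e.g. "AVAILABLE"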

3. Cloud-Based Deployments

Google Cloud AI Platform

Pushing to Cloud AI Platform (now Vertex AI) goes through the Pusher executor from TFX's Google Cloud extension rather than a PushDestination field. The exact imports vary by TFX version, but the configuration looks roughly like this:

python
from tfx.dsl.components.base import executor_spec
from tfx.extensions.google_cloud_ai_platform.pusher import executor as ai_platform_pusher_executor

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    custom_executor_spec=executor_spec.ExecutorClassSpec(ai_platform_pusher_executor.Executor),
    custom_config={'ai_platform_serving_args': {
        'model_name': 'taxi_model',
        'project_id': 'your-gcp-project',  # assumed project ID
    }}
)

Kubernetes with TF Serving

yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: tf-serving
spec:
  replicas: 3
  selector:
    matchLabels:
      app: tf-serving
  template:
    metadata:
      labels:
        app: tf-serving
    spec:
      containers:
      - name: tf-serving
        image: tensorflow/serving:latest
        ports:
        - containerPort: 8501
        volumeMounts:
        - name: model-volume
          mountPath: /models/my_model
        env:
        - name: MODEL_NAME
          value: "my_model"
      volumes:
      - name: model-volume
        hostPath:
          path: /path/to/model

Real-World Example: Customer Churn Prediction API

Let's create a more complete example of deploying a customer churn prediction model with TFX:

1. Define TFX Pipeline Components

python
import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx.components import (
    CsvExampleGen, StatisticsGen, SchemaGen, ExampleValidator,
    Transform, Trainer, Evaluator, Pusher
)
from tfx.proto import pusher_pb2, trainer_pb2
from tfx.orchestration import pipeline

# Pipeline inputs
data_root = 'gs://your-bucket/customer_data'
pipeline_root = 'gs://your-bucket/pipelines'
serving_model_dir = 'gs://your-bucket/serving_models/churn_model'

# Define components
example_gen = CsvExampleGen(input_base=data_root)
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='preprocessing.py'
)
trainer = Trainer(
    module_file='model.py',
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000)
)
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(signature_name='serving_default', label_key='churn')],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(class_name='BinaryAccuracy'),
                tfma.MetricConfig(class_name='AUC'),
            ]
        )
    ],
    slicing_specs=[tfma.SlicingSpec()]
)
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    eval_config=eval_config
)
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir
        )
    )
)

# Define the pipeline
p = pipeline.Pipeline(
    pipeline_name='customer_churn_pipeline',
    pipeline_root=pipeline_root,
    components=[
        example_gen, statistics_gen, schema_gen, example_validator,
        transform, trainer, evaluator, pusher
    ]
)
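
To execute the pipeline you hand it to an orchestrator. A minimal sketch using the local runner bundled with recent TFX releases (the Kubeflow Pipelines and Airflow runners are used the same way):

python
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Run all pipeline components, in dependency order, on the local machine
LocalDagRunner().run(p)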

2. Create a Preprocessing Module

python
# preprocessing.py
import tensorflow as tf
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    """Preprocess input features into transformed features."""

    # Feature groups
    numericals = ['age', 'tenure', 'monthly_charges', 'total_charges']
    categoricals = ['gender', 'partner', 'dependents', 'phone_service']

    outputs = {}

    # Scale numerical features to zero mean and unit variance
    for feature in numericals:
        outputs[feature] = tft.scale_to_z_score(inputs[feature])

    # Map categorical strings to integer vocabulary indices
    for feature in categoricals:
        outputs[feature] = tft.compute_and_apply_vocabulary(
            inputs[feature], vocab_filename=feature)

    # Convert target to float
    outputs['churn'] = tf.cast(inputs['churn'], tf.float32)

    return outputs

3. Create the Model Training Module

python
# model.py
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow.keras import layers
from tfx.components.trainer.fn_args_utils import FnArgs


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a function that parses a serialized tf.Example."""

    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        """Returns the output to be used in the serving signature."""
        feature_spec = tf_transform_output.raw_feature_spec()
        # Drop the label: clients don't send it at serving time
        feature_spec.pop('churn')
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn


def run_fn(fn_args: FnArgs):
    """Train the model."""
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=fn_args.train_files,
        batch_size=64,
        features=tf_transform_output.transformed_feature_spec(),
        reader=tf.data.TFRecordDataset,
        label_key='churn',
        shuffle=True)

    eval_dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=fn_args.eval_files,
        batch_size=64,
        features=tf_transform_output.transformed_feature_spec(),
        reader=tf.data.TFRecordDataset,
        label_key='churn',
        shuffle=False)

    # Get feature columns
    feature_spec = tf_transform_output.transformed_feature_spec()
    feature_columns = []

    # Create feature columns for each feature
    for key, spec in feature_spec.items():
        if key != 'churn':  # Skip the label
            feature_columns.append(tf.feature_column.numeric_column(key, shape=spec.shape))

    # Create the model
    feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

    model = tf.keras.Sequential([
        feature_layer,
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(64, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])

    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC()]
    )

    # Train the model (the datasets repeat indefinitely, so bound training explicitly)
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps
    )

    # Save the model with signatures
    signatures = {
        'serving_default': _get_serve_tf_examples_fn(
            model, tf_transform_output).get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
    }

    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

4. Make API Requests to the Deployed Model

Once your model is deployed, you can make requests to it. Because the serving signature exported above takes serialized tf.Example protos, the client packs the raw features into a tf.Example and base64-encodes it for the REST API:

python
import base64
import json
import requests
import tensorflow as tf

def make_example(features):
    """Pack raw feature values into a serialized tf.Example."""
    feature = {}
    for key, value in features.items():
        if isinstance(value, str):
            feature[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))
        elif isinstance(value, int):
            feature[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
        else:
            feature[key] = tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Sample customer data
customer = {
    "age": 42,
    "tenure": 24,
    "monthly_charges": 65.5,
    "total_charges": 1556.0,
    "gender": "Male",
    "partner": "Yes",
    "dependents": "No",
    "phone_service": "Yes"
}

# Binary (string) inputs must be sent base64-encoded to the REST API
data = {"instances": [{"b64": base64.b64encode(make_example(customer)).decode('utf-8')}]}

# Make a prediction request
response = requests.post(
    "http://localhost:8501/v1/models/churn_model:predict",
    data=json.dumps(data)
)

# Process the response
predictions = response.json()["predictions"]
churn_probability = predictions[0][0]

print(f"Churn probability: {churn_probability:.2%}")
# Output example: Churn probability: 23.45%

# Business logic based on prediction
if churn_probability > 0.5:
    print("High risk customer: Recommend retention strategy")
else:
    print("Low risk customer: Standard engagement")

Best Practices for TFX Deployments

  1. Model Versioning: Always version your models to facilitate rollbacks if a new model doesn't perform as expected.

  2. A/B Testing: Deploy multiple model versions and send a portion of traffic to each to evaluate performance (a request sketch for targeting a version label follows this list):

python
# Configure A/B testing in TF Serving
model_config = """
model_config_list {
  config {
    name: 'churn_model'
    base_path: '/models/churn_model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
    version_labels {
      key: 'stable'
      value: 1
    }
    version_labels {
      key: 'canary'
      value: 2
    }
  }
}
"""

with open('/models/model_config.txt', 'w') as f:
    f.write(model_config)

  3. Monitoring: Implement monitoring for your deployed models to detect performance degradation. Within the pipeline, pre-push quality gating is handled by model validation (note that the standalone ModelValidator component below has been deprecated in newer TFX releases, where its checks live in the Evaluator):

python
from tfx.components import ModelValidator

# Add model validation to the pipeline (deprecated; newer TFX folds this into Evaluator)
model_validator = ModelValidator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model']
)

  4. Scaling: Use Kubernetes or cloud-based solutions to scale with demand.

  5. CI/CD Integration: Automate the deployment process with continuous integration/continuous deployment pipelines.
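
Building on the versioning and A/B testing practices above, TensorFlow Serving's REST API lets a client pin a request to a specific version number or version label. A minimal sketch, assuming the model_config shown earlier is in use and the payload is built as in the churn example:

python
import requests

# Route the request to whichever version carries the 'canary' label
canary_resp = requests.post(
    "http://localhost:8501/v1/models/churn_model/labels/canary:predict",
    json={"instances": [{"b64": "..."}]}  # placeholder payload; build it as in the churn example
)

# Or pin an explicit version number (version 1 is labelled 'stable' above)
stable_resp = requests.post(
    "http://localhost:8501/v1/models/churn_model/versions/1:predict",
    json={"instances": [{"b64": "..."}]}
)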

Deployment Patterns

Pattern 1: Online Prediction

This pattern is used when you need real-time predictions, such as product recommendations during a user session.

python
# Client code for online prediction
import requests

def get_prediction(instances):
    # Format instances for TF Serving
    request_data = {"instances": instances}

    # Make prediction request
    response = requests.post(
        "http://model-service:8501/v1/models/model:predict",
        json=request_data
    )

    # Parse and return results
    return response.json()["predictions"]
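
A call then looks like the following; the instance payload must match whatever the deployed model's serving signature expects (raw feature dicts, or base64-encoded tf.Example protos as in the churn example earlier):

python
# Hypothetical usage: one instance in, one prediction out
predictions = get_prediction([{"b64": "..."}])
print(predictions[0])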

Pattern 2: Batch Prediction

This pattern is suitable for scenarios where predictions can be made in bulk, such as daily risk assessments.

python
# Batch prediction by loading the exported SavedModel directly
import tensorflow as tf

batch_predictor = tf.saved_model.load('/models/churn_model/1')
# Use the serving signature, which parses serialized tf.Example records itself
infer = batch_predictor.signatures['serving_default']

# Load a batch of serialized examples
batch_data = tf.data.TFRecordDataset(['gs://your-bucket/data/batch_records.tfrecord'])

# Make predictions
predictions = []
for batch in batch_data.batch(100):
    result = infer(examples=batch)
    # The signature returns a dict with a single output tensor; grab it regardless of key
    predictions.extend(list(result.values())[0].numpy())

# Save results
with tf.io.TFRecordWriter('gs://your-bucket/predictions.tfrecord') as writer:
    for pred in predictions:
        example = create_example(pred)  # Your function to create TF Examples
        writer.write(example.SerializeToString())

Summary

TensorFlow Extended provides a comprehensive framework for deploying machine learning models to production. In this guide, we've covered:

  1. The key TFX deployment components like TensorFlow Serving and Pusher
  2. Various deployment options including local filesystem, Docker containers, and cloud platforms
  3. A real-world example of building and deploying a customer churn prediction model
  4. Best practices and common deployment patterns for TFX deployments

By following the practices outlined in this guide, you'll be able to deploy machine learning models reliably and efficiently using TFX, making them accessible for real-world applications.

Exercises

  1. Exercise 1: Deploy a simple image classification model using TensorFlow Serving in a Docker container.
  2. Exercise 2: Set up a monitoring system for a deployed model using TensorFlow Model Analysis.
  3. Exercise 3: Create a TFX pipeline that includes all the components from data ingestion to deployment.
  4. Exercise 4: Implement an A/B testing scenario with two model versions using TensorFlow Serving.
  5. Exercise 5: Build a simple client application that makes prediction requests to your deployed model using the REST API.

