TensorFlow Extended Deployment
Introduction
TensorFlow Extended (TFX) is Google's end-to-end platform for deploying production machine learning pipelines. One of the most crucial steps in the machine learning lifecycle is deployment: making your trained model available for use in real-world applications. In this guide, we'll explore how TFX helps you deploy models reliably and efficiently in production environments.
Model deployment brings its own set of challenges, including scaling to handle varying loads, ensuring consistent performance, versioning, and monitoring. TFX provides several components specifically designed to address these deployment challenges.
Key TFX Deployment Components
1. TensorFlow Serving
TensorFlow Serving is the serving system most commonly used as the deployment target for TFX pipelines. It is a separate project designed to serve machine learning models in production environments with high performance and scalability.
Features of TensorFlow Serving:
- Model versioning: Manage multiple versions of your model simultaneously
- High performance: Optimized for CPU and GPU environments
- Model updates: Update models without downtime
- Standardized API: REST and gRPC interfaces
- Batching: Process multiple requests efficiently
Let's look at a basic example of how to export a model for TensorFlow Serving:
import os
import tempfile

import tensorflow as tf

# Create a simple model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

# Train with dummy data
x = tf.random.normal((100, 5))
y = tf.random.normal((100, 1))
model.fit(x, y, epochs=2)

# Export the model in the SavedModel format. TF Serving expects the layout
# <base_path>/<model_name>/<version>/, where <version> is an integer directory.
MODEL_DIR = os.path.join(tempfile.gettempdir(), "my_model")
version = 1
export_path = os.path.join(MODEL_DIR, str(version))

# tf.saved_model.save writes a SavedModel with a default serving signature;
# optimizer state is not needed for inference
tf.saved_model.save(model, export_path)

print(f"Model saved to: {export_path}")
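With its default version policy, TF Serving watches the model's base directory and serves only the latest version, meaning the subdirectory with the highest integer name. That is why the export path above ends in a number. The helper below is a hypothetical standard-library sketch of that selection rule (it is not TF Serving code), useful for sanity-checking an export directory before pointing a server at it.

```python
import os
import tempfile


def latest_version(base_path):
    """Return the highest integer-named subdirectory of base_path, or None.

    Hypothetical sketch of the 'latest' rule: only directories whose names
    are integers count as versions; files and other names are ignored.
    """
    versions = [
        int(name)
        for name in os.listdir(base_path)
        if name.isdigit() and os.path.isdir(os.path.join(base_path, name))
    ]
    return max(versions) if versions else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as base:
        for name in ["1", "2", "10", "tmp_export"]:
            os.makedirs(os.path.join(base, name))
        print(latest_version(base))  # 10 (numeric, not lexicographic, ordering)
```

Comparing versions as integers matters: sorted as strings, `"2"` would rank above `"10"`.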
2. Pusher Component
The Pusher component in TFX is responsible for deploying trained models to a serving infrastructure. It's the final stage in the TFX pipeline that takes a validated model and "pushes" it to your deployment target.
Here's an example of how to configure a Pusher component:
from tfx.components import Pusher
from tfx.proto import pusher_pb2

# `trainer` and `evaluator` are the Trainer and Evaluator components
# defined earlier in the same pipeline
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='/serving_model/taxi_simple'
        )
    )
)
In this code:
- We reference the model from the training component
- We include a blessing from the evaluator component (deployment only happens if the model passes quality checks)
- We define where to push the model (in this case, to a filesystem location)
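Conceptually, a filesystem push is a guarded copy: nothing happens unless the model is blessed, and a blessed model lands in a fresh numbered subdirectory of `base_directory`, which a running TF Serving instance then picks up on its next filesystem poll. The function below is a simplified, hypothetical sketch of that behavior using only the standard library. It is not the Pusher's implementation, and defaulting the version name to a Unix timestamp is an assumption made for illustration.

```python
import os
import shutil
import tempfile
import time


def push_if_blessed(model_dir, blessed, base_directory, version=None):
    """Copy model_dir to base_directory/<version> only when blessed.

    Hypothetical sketch: `version` defaults to the current Unix timestamp so
    newer pushes sort above older ones. Returns the destination path, or
    None when the model was not blessed.
    """
    if not blessed:
        return None
    version = str(version if version is not None else int(time.time()))
    destination = os.path.join(base_directory, version)
    # Copy under a non-numeric staging name first, then rename, so a server
    # watching base_directory never loads a half-written version
    staging = destination + ".tmp"
    shutil.copytree(model_dir, staging)
    os.rename(staging, destination)
    return destination


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        model_dir = os.path.join(root, "trained")
        os.makedirs(model_dir)
        with open(os.path.join(model_dir, "saved_model.pb"), "wb") as f:
            f.write(b"placeholder")
        serving = os.path.join(root, "serving")
        print(push_if_blessed(model_dir, False, serving))  # None
        print(push_if_blessed(model_dir, True, serving, version=1))
```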
Deployment Options with TFX
TFX provides flexibility in where and how you deploy your models. Here are the common deployment targets:
1. Local Filesystem Deployment
This is the simplest deployment method, where the model is exported to a directory on the local filesystem.
push_destination = pusher_pb2.PushDestination(
    filesystem=pusher_pb2.PushDestination.Filesystem(
        base_directory='/path/to/serving/models/dir'
    )
)
2. Docker Container Deployment
For containerized deployments, you can use the TensorFlow Serving Docker container:
docker pull tensorflow/serving
docker run -p 8501:8501 \
--mount type=bind,source=/path/to/my_model,target=/models/my_model \
-e MODEL_NAME=my_model -t tensorflow/serving
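Once the container is running, TF Serving's REST API listens on port 8501. The helper below is a small sketch that assembles its documented URL patterns: model status at `/v1/models/<name>`, optionally pinned with `/versions/<n>` or `/labels/<label>`, plus `/metadata` for the signature description and `:predict` (or `:classify` / `:regress`) for inference. The function name and the host/port defaults are assumptions for illustration.

```python
def serving_url(model_name, action=None, version=None, label=None,
                host="localhost", port=8501):
    """Build a TF Serving REST URL.

    action: None for model status, "metadata" for signature metadata,
    or "predict" / "classify" / "regress" for inference.
    version and label are mutually exclusive ways to pin a model version.
    """
    if version is not None and label is not None:
        raise ValueError("Specify either version or label, not both")
    url = f"http://{host}:{port}/v1/models/{model_name}"
    if version is not None:
        url += f"/versions/{version}"
    elif label is not None:
        url += f"/labels/{label}"
    if action == "metadata":
        url += "/metadata"
    elif action in ("predict", "classify", "regress"):
        url += f":{action}"
    elif action is not None:
        raise ValueError(f"Unknown action: {action}")
    return url


print(serving_url("my_model", "predict"))
# http://localhost:8501/v1/models/my_model:predict
print(serving_url("my_model", "predict", version=1))
# http://localhost:8501/v1/models/my_model/versions/1:predict
```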
3. Cloud-Based Deployments
Google Cloud AI Platform
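# Cloud deployment uses the Google Cloud extension's Pusher (not a PushDestination);
# the project ID is a placeholder
from tfx import v1 as tfx
pusher = tfx.extensions.google_cloud_ai_platform.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    custom_config={
        tfx.extensions.google_cloud_ai_platform.SERVING_ARGS_KEY: {
            'model_name': 'taxi_model', 'project_id': 'your-gcp-project'}
    }
)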
Kubernetes with TF Serving
apiVersion: apps/v1
kind: Deployment
metadata:
  name: tf-serving
spec:
  replicas: 3
  selector:
    matchLabels:
      app: tf-serving
  template:
    metadata:
      labels:
        app: tf-serving
    spec:
      containers:
      - name: tf-serving
        image: tensorflow/serving:latest
        ports:
        - containerPort: 8501
        volumeMounts:
        - name: model-volume
          mountPath: /models/my_model
        env:
        - name: MODEL_NAME
          value: "my_model"
      volumes:
      - name: model-volume
        hostPath:
          path: /path/to/model
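A readiness check or a rollout script usually needs to know whether a model version has finished loading on each replica. TF Serving's model status endpoint (`GET /v1/models/<name>`) returns JSON with a `model_version_status` list whose entries carry a `version` and a `state` such as `AVAILABLE`. The sketch below parses a response of that shape; the sample payload is written by hand to mirror the format, not captured from a live server, and the helper names are hypothetical.

```python
import json


def available_versions(status_json):
    """Return the sorted integer versions whose state is AVAILABLE."""
    status = json.loads(status_json)
    return sorted(
        int(entry["version"])
        for entry in status.get("model_version_status", [])
        if entry.get("state") == "AVAILABLE"
    )


def is_ready(status_json, required_version=None):
    """True if any version (or the required one) is AVAILABLE."""
    versions = available_versions(status_json)
    if required_version is None:
        return bool(versions)
    return required_version in versions


sample = """{
  "model_version_status": [
    {"version": "2", "state": "LOADING", "status": {"error_code": "OK", "error_message": ""}},
    {"version": "1", "state": "AVAILABLE", "status": {"error_code": "OK", "error_message": ""}}
  ]
}"""
print(available_versions(sample))  # [1]
print(is_ready(sample), is_ready(sample, required_version=2))  # True False
```

Checking for a specific version lets a rollout wait until the new model is actually serving instead of treating "some version is up" as success.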
Real-World Example: Customer Churn Prediction API
Let's create a more complete example of deploying a customer churn prediction model with TFX:
1. Define TFX Pipeline Components
import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx.components import (
    CsvExampleGen, StatisticsGen, SchemaGen, ExampleValidator,
    Transform, Trainer, Evaluator, Pusher
)
from tfx.proto import pusher_pb2, trainer_pb2
from tfx.orchestration import pipeline

# Pipeline inputs
data_root = 'gs://your-bucket/customer_data'
pipeline_root = 'gs://your-bucket/pipelines'
serving_model_dir = 'gs://your-bucket/serving_models/churn_model'

# Define components
# CsvExampleGen reads the CSV files under data_root
example_gen = CsvExampleGen(input_base=data_root)
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='preprocessing.py'
)
trainer = Trainer(
    module_file='model.py',
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000)
)

# label_key tells TFMA which feature holds the ground truth. The AUC
# threshold (0.7 is an illustrative value) decides whether the model is blessed.
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(signature_name='serving_default', label_key='churn')],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(class_name='BinaryAccuracy'),
                tfma.MetricConfig(
                    class_name='AUC',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.7}))),
            ]
        )
    ],
    slicing_specs=[tfma.SlicingSpec()]
)
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    eval_config=eval_config
)

# Pushes to serving_model_dir only if the Evaluator blessed the model
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir
        )
    )
)

# Define the pipeline
p = pipeline.Pipeline(
    pipeline_name='customer_churn_pipeline',
    pipeline_root=pipeline_root,
    components=[
        example_gen, statistics_gen, schema_gen, example_validator,
        transform, trainer, evaluator, pusher
    ]
)
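The blessing that gates the Pusher comes from comparing the evaluation metrics against any thresholds configured in `eval_config`. The function below is a deliberately simplified, hypothetical model of a value-threshold check (every listed metric must stay within its lower and upper bounds). TFMA's real validation also supports change thresholds relative to a baseline model and per-slice results, which this sketch ignores, and the metric values in the usage lines are made up.

```python
def is_blessed(metrics, thresholds):
    """Return (blessed, failures) for simple value thresholds.

    metrics:    {"AUC": 0.74, ...}
    thresholds: {"AUC": {"lower_bound": 0.7}, ...}
    A metric listed in thresholds but missing from metrics counts as a failure.
    """
    failures = []
    for name, bounds in thresholds.items():
        value = metrics.get(name)
        if value is None:
            failures.append(f"{name}: missing")
            continue
        lower = bounds.get("lower_bound")
        upper = bounds.get("upper_bound")
        if lower is not None and value < lower:
            failures.append(f"{name}: {value} < {lower}")
        if upper is not None and value > upper:
            failures.append(f"{name}: {value} > {upper}")
    return not failures, failures


print(is_blessed({"AUC": 0.74, "BinaryAccuracy": 0.81}, {"AUC": {"lower_bound": 0.7}}))
# (True, [])
print(is_blessed({"AUC": 0.65}, {"AUC": {"lower_bound": 0.7}}))
# (False, ['AUC: 0.65 < 0.7'])
```

Treating a missing metric as a failure errs on the side of not deploying, which is usually the safer default for a deployment gate.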
2. Create a Preprocessing Module
# preprocessing.py
import tensorflow as tf
import tensorflow_transform as tft

def preprocessing_fn(inputs):
    """Preprocess input features into transformed features."""
    # Feature names (assumed to match the CSV column headers)
    numericals = ['age', 'tenure', 'monthly_charges', 'total_charges']
    categoricals = ['gender', 'partner', 'dependents', 'phone_service']

    outputs = {}

    # Scale numerical features to zero mean and unit variance
    for feature in numericals:
        outputs[feature] = tft.scale_to_z_score(inputs[feature])

    # Map categorical strings to integer indices from a vocabulary
    # computed over the whole dataset
    for feature in categoricals:
        outputs[feature] = tft.compute_and_apply_vocabulary(
            inputs[feature], vocab_filename=feature)

    # Convert the label to float (assumes churn is stored as 0/1, not "Yes"/"No")
    outputs['churn'] = tf.cast(inputs['churn'], tf.float32)

    return outputs
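Both transforms above run in two phases: an analyze pass over the full training data computes constants (a mean and variance, a vocabulary), and those constants are then frozen into the transform graph and applied identically at training and serving time. The plain-Python sketch below imitates that split for one numeric and one categorical column. It illustrates the idea and is not TensorFlow Transform's implementation: it assumes the population variance, orders the vocabulary by descending frequency with ties broken alphabetically (an arbitrary choice here), and maps unseen values to -1, which matches the default `default_value` of `compute_and_apply_vocabulary`.

```python
import math
from collections import Counter


def analyze(numeric_values, categorical_values):
    """Analyze phase: compute constants from the full training data."""
    n = len(numeric_values)
    mean = sum(numeric_values) / n
    # Population variance (divide by n), an assumption of this sketch
    var = sum((v - mean) ** 2 for v in numeric_values) / n
    counts = Counter(categorical_values)
    # Most frequent value gets index 0; ties broken alphabetically
    vocab = sorted(counts, key=lambda token: (-counts[token], token))
    return {"mean": mean, "std": math.sqrt(var),
            "vocab": {token: i for i, token in enumerate(vocab)}}


def transform(constants, numeric_value, categorical_value):
    """Transform phase: apply the frozen constants to a single example."""
    std = constants["std"]
    # Guard against zero variance (a choice of this sketch)
    z = (numeric_value - constants["mean"]) / std if std > 0 else 0.0
    # Values never seen during analysis map to -1
    index = constants["vocab"].get(categorical_value, -1)
    return z, index


constants = analyze([12, 24, 36], ["Yes", "No", "Yes"])
print(constants["vocab"])              # {'Yes': 0, 'No': 1}
print(transform(constants, 24, "No"))  # (0.0, 1)
print(transform(constants, 36, "Maybe"))
```

Because serving reuses the constants computed during analysis, a request is scaled with the training mean and variance rather than statistics of the request itself, which is what prevents training/serving skew.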
3. Create the Model Training Module
# model.py
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow.keras import layers
from tfx.components.trainer.fn_args_utils import FnArgs

_LABEL_KEY = 'churn'


def _gzip_reader_fn(filenames):
    """Transform writes its output as GZIP-compressed TFRecord files."""
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a function that parses a serialized tf.Example."""
    # Attach the transform layer to the model so it is saved along with it
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        """Returns the output to be used in the serving signature."""
        feature_spec = tf_transform_output.raw_feature_spec()
        # The label is not available at serving time
        feature_spec.pop(_LABEL_KEY)
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn


def _input_fn(file_pattern, tf_transform_output, shuffle):
    """Returns (features, label) batches read from the transformed examples."""
    return tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        batch_size=64,
        features=tf_transform_output.transformed_feature_spec(),
        reader=_gzip_reader_fn,
        label_key=_LABEL_KEY,
        shuffle=shuffle)


def run_fn(fn_args: FnArgs):
    """Train the model."""
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files, tf_transform_output, shuffle=True)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, shuffle=False)

    # Create a numeric feature column for every transformed feature except the
    # label (tf.feature_column and DenseFeatures are TF2 / Keras 2 APIs)
    feature_spec = tf_transform_output.transformed_feature_spec()
    feature_columns = [
        tf.feature_column.numeric_column(key, shape=spec.shape)
        for key, spec in feature_spec.items()
        if key != _LABEL_KEY
    ]

    # Create the model
    model = tf.keras.Sequential([
        tf.keras.layers.DenseFeatures(feature_columns),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(64, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])

    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC()]
    )

    # The datasets repeat indefinitely, so bound training and evaluation
    # with the step counts passed in through TrainArgs and EvalArgs
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps
    )

    # Save the model with a serving signature that accepts serialized tf.Examples
    signatures = {
        'serving_default': _get_serve_tf_examples_fn(
            model, tf_transform_output).get_concrete_function(
                tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
    }
    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
4. Make API Requests to the Deployed Model
Once your model is deployed, you can make requests to it:
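import base64

import requests
import tensorflow as tf

# Sample customer data (raw, untransformed features)
customer = {
    "age": 42,
    "tenure": 24,
    "monthly_charges": 65.5,
    "total_charges": 1556.0,
    "gender": "Male",
    "partner": "Yes",
    "dependents": "No",
    "phone_service": "Yes"
}


def to_feature(value):
    """Wrap a Python value in the tf.train.Feature type for its dtype.

    The types must match the schema that SchemaGen inferred for each column.
    """
    if isinstance(value, str):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))
    if isinstance(value, int):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


# The serving_default signature from model.py takes serialized tf.Example
# protos under the input name "examples"; the REST API expects binary
# values as base64 strings wrapped in a {"b64": ...} object
example = tf.train.Example(features=tf.train.Features(
    feature={name: to_feature(value) for name, value in customer.items()}))
serialized = base64.b64encode(example.SerializeToString()).decode('utf-8')
data = {"instances": [{"examples": {"b64": serialized}}]}

# Make a prediction request
response = requests.post(
    "http://localhost:8501/v1/models/churn_model:predict",
    json=data,
    timeout=10
)
response.raise_for_status()

# Process the response
predictions = response.json()["predictions"]
churn_probability = predictions[0][0]
print(f"Churn probability: {churn_probability:.2%}")
# Prints something like: Churn probability: 23.45%

# Business logic based on prediction
if churn_probability > 0.5:
    print("High risk customer: Recommend retention strategy")
else:
    print("Low risk customer: Standard engagement")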
Best Practices for TFX Deployments
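- Model Versioning: Always version your models to facilitate rollbacks if a new model doesn't perform as expected.
- A/B Testing: Deploy multiple model versions side by side and send a portion of traffic to each to compare performance. TF Serving can load several versions at once and expose them under labels, but it does not split traffic itself; the client or a router in front of the server decides which label each request goes to: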
# Serve versions 1 and 2 of churn_model side by side under the labels
# 'stable' and 'canary'; start TF Serving with
# --model_config_file=/models/model_config.txt to load this file
model_config = """
model_config_list {
  config {
    name: 'churn_model'
    base_path: '/models/churn_model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
    version_labels {
      key: 'stable'
      value: 1
    }
    version_labels {
      key: 'canary'
      value: 2
    }
  }
}
"""
with open('/models/model_config.txt', 'w') as f:
    f.write(model_config)
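- Monitoring: Implement monitoring for your deployed models to detect performance degradation. The older ModelValidator component is deprecated; in current TFX releases the Evaluator handles model validation and can compare each candidate with the latest blessed model before it is pushed (add `model_resolver` to the pipeline's components). This catches regressions before deployment; monitoring live traffic still requires logging requests and comparing their statistics with the training data.
from tfx import v1 as tfx
# Resolve the latest blessed model as a baseline for comparison
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing)
).with_id('latest_blessed_model_resolver')
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)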
- Scaling: Use Kubernetes or cloud-based solutions to scale with demand.
- CI/CD Integration: Automate the deployment process with continuous integration/continuous deployment pipelines.
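Because TF Serving only exposes the labeled versions, the traffic split for an A/B test has to happen in the client or in a router in front of the server. A common approach is to hash a stable identifier, such as a customer ID, into a bucket so each customer consistently hits the same version. The sketch below assumes the `stable` and `canary` labels from the model config above and a 10% canary share; both the share and the helper names are illustrative choices, while the URL follows TF Serving's `/labels/<label>:predict` REST pattern.

```python
import hashlib


def choose_label(customer_id, canary_fraction=0.1):
    """Deterministically assign a customer to 'canary' or 'stable'.

    Hashing (rather than random sampling) keeps each customer on the same
    model version across requests, which keeps per-version metrics clean.
    """
    digest = hashlib.sha256(str(customer_id).encode("utf-8")).hexdigest()
    bucket = int(digest[:8], 16) / 0xFFFFFFFF  # roughly uniform in [0, 1]
    return "canary" if bucket < canary_fraction else "stable"


def predict_url(customer_id, model_name="churn_model", host="localhost", port=8501):
    """Build the label-pinned predict URL for this customer."""
    label = choose_label(customer_id)
    return f"http://{host}:{port}/v1/models/{model_name}/labels/{label}:predict"


labels = [choose_label(i) for i in range(10_000)]
print(labels.count("canary") / len(labels))  # close to 0.1
print(predict_url("customer-42"))
```

Raising `canary_fraction` gradually turns the same mechanism into a progressive rollout.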
Deployment Patterns
Pattern 1: Online Prediction
This pattern is used when you need real-time predictions, such as product recommendations during a user session.
# Client code for online prediction
import requests

def get_prediction(instances):
    # Format instances for TF Serving's REST API
    request_data = {"instances": instances}
    # Make prediction request; the timeout keeps a slow server from blocking the caller
    response = requests.post(
        "http://model-service:8501/v1/models/model:predict",
        json=request_data,
        timeout=5
    )
    response.raise_for_status()
    # Parse and return results
    return response.json()["predictions"]
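Online callers often wrap a call like `get_prediction` with retries, since a replica restart or a model reload can make a few requests fail. The helper below is a generic, hypothetical sketch of retrying with exponential backoff and jitter. It takes the request as a callable, so it can wrap `get_prediction` and can be exercised without a live server; the `flaky_predict` function is a stand-in for illustration.

```python
import random
import time


def call_with_retries(fn, *args, attempts=3, base_delay=0.2, max_delay=2.0,
                      retry_on=(Exception,), sleep=time.sleep):
    """Call fn(*args), retrying on the given exceptions with exponential backoff.

    The delay doubles after each failure (capped at max_delay) and is scaled
    by random jitter so many clients do not retry in lockstep. The last
    exception is re-raised once all attempts are used.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(attempts):
        try:
            return fn(*args)
        except retry_on:
            if attempt == attempts - 1:
                raise
            delay = min(max_delay, base_delay * (2 ** attempt))
            sleep(delay * random.uniform(0.5, 1.0))


# Usage with a fake request that fails twice before succeeding
calls = {"n": 0}

def flaky_predict(instances):
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("model server unavailable")
    return [[0.23] for _ in instances]

print(call_with_retries(flaky_predict, [{"x": 1}], sleep=lambda s: None))
# [[0.23]]
```

With the real client, narrowing `retry_on` to transient errors, for example `(requests.exceptions.ConnectionError, requests.exceptions.Timeout)`, avoids retrying requests that will never succeed, such as malformed inputs.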
Pattern 2: Batch Prediction
This pattern is suitable for scenarios where predictions can be made in bulk, such as daily risk assessments.
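# Batch prediction by loading the exported SavedModel directly (no TF Serving needed)
import tensorflow as tf

loaded = tf.saved_model.load('/models/churn_model/1')
serving_fn = loaded.signatures['serving_default']

# The serving signature parses serialized tf.Example records itself,
# so raw TFRecord strings can be fed in without a separate parsing step
batch_data = tf.data.TFRecordDataset(['gs://your-bucket/data/batch_records.tfrecord'])

def create_example(pred):
    """Wrap one prediction in a tf.train.Example (feature name is illustrative)."""
    return tf.train.Example(features=tf.train.Features(feature={
        'churn_probability': tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(pred[0])]))}))

# Make predictions
predictions = []
for batch in batch_data.batch(100):
    # A signature that returns a single tensor exposes it as 'output_0'
    batch_predictions = serving_fn(examples=batch)['output_0']
    predictions.extend(batch_predictions.numpy())

# Save results
with tf.io.TFRecordWriter('gs://your-bucket/predictions.tfrecord') as writer:
    for pred in predictions:
        example = create_example(pred)
        writer.write(example.SerializeToString())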
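Stripped of TensorFlow specifics, batch prediction is a loop that reads records in fixed-size chunks, scores each chunk in one call, and writes results as it goes so memory stays bounded. The standard-library sketch below shows that loop. Here `score_batch` is a hypothetical stand-in for the model call, and JSON Lines output is an assumption chosen for readability.

```python
import io
import json
from itertools import islice


def batched(iterable, size):
    """Yield lists of up to `size` items from any iterable."""
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk


def score_to_jsonl(records, score_batch, out, batch_size=100):
    """Score records chunk by chunk and write one JSON line per record."""
    written = 0
    for chunk in batched(records, batch_size):
        scores = score_batch(chunk)
        for record, score in zip(chunk, scores):
            out.write(json.dumps({"id": record["id"], "churn_probability": score}) + "\n")
            written += 1
    return written


# Usage with a fake scorer and an in-memory output file
records = ({"id": i, "tenure": i % 72} for i in range(250))
fake_score = lambda chunk: [round(1 / (1 + r["tenure"]), 3) for r in chunk]
buffer = io.StringIO()
print(score_to_jsonl(records, fake_score, buffer))  # 250
print(buffer.getvalue().splitlines()[0])  # {"id": 0, "churn_probability": 1.0}
```

Because `records` can be a generator, the whole input never has to fit in memory; only one chunk and its scores are held at a time.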
Summary
TensorFlow Extended provides a comprehensive framework for deploying machine learning models to production. In this guide, we've covered:
- The key TFX deployment components like TensorFlow Serving and Pusher
- Various deployment options including local filesystem, Docker containers, and cloud platforms
- A real-world example of building and deploying a customer churn prediction model
- Best practices and common deployment patterns for TFX deployments
By following the practices outlined in this guide, you'll be able to deploy machine learning models reliably and efficiently using TFX, making them accessible for real-world applications.
Additional Resources
- TensorFlow Serving Documentation
- TFX Pusher Component Guide
- Model Deployment on Google Cloud AI Platform
- TensorFlow Extended: End-to-End ML Pipelines
Exercises
- Exercise 1: Deploy a simple image classification model using TensorFlow Serving in a Docker container.
- Exercise 2: Set up a monitoring system for a deployed model using TensorFlow Model Analysis.
- Exercise 3: Create a TFX pipeline that includes all the components from data ingestion to deployment.
- Exercise 4: Implement an A/B testing scenario with two model versions using TensorFlow Serving.
- Exercise 5: Build a simple client application that makes prediction requests to your deployed model using the REST API.