TensorFlow Metadata

TensorFlow Metadata (TFMD) is a crucial component of the TensorFlow Extended (TFX) ecosystem that helps you track and manage metadata for your machine learning workflows. It provides infrastructure and tools for recording and retrieving metadata about your datasets, models, and the execution of your ML pipelines.

Introduction to TensorFlow Metadata

In machine learning projects, keeping track of datasets, models, and the various transformations they undergo is essential for reproducibility, debugging, and collaboration. TensorFlow Metadata solves this challenge by providing a structured way to store and retrieve this information.

The core component of TensorFlow Metadata is ML Metadata (MLMD), which offers a standardized way to track:

  • Artifacts: Datasets, models, and other outputs of your pipeline components
  • Executions: Records of component runs in your ML pipeline
  • Contexts: Groupings of artifacts and executions (e.g., experiments, pipeline runs)

Let's dive into how you can use TensorFlow Metadata in your ML projects.

Setting Up TensorFlow Metadata

First, let's install the necessary packages:

bash
pip install tfx ml-metadata tensorflow-data-validation

Once installed, you can set up a metadata store which will be the central repository for your ML metadata:

python
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

# Define the connection configuration
connection_config = metadata_store_pb2.ConnectionConfig()

# For an in-memory database (for testing)
connection_config.fake_database.SetInParent()

# For a SQLite database
# connection_config.sqlite.filename_uri = "metadata.sqlite"
# connection_config.sqlite.connection_mode = 3  # READWRITE_OPENCREATE

# Create the metadata store
store = metadata_store.MetadataStore(connection_config)

Key Concepts in TensorFlow Metadata

1. Artifacts, Types, and Properties

An artifact represents a data object in your ML pipeline, such as a dataset or model. Let's see how to define and create artifacts:

python
# Define an artifact type
dataset_type = metadata_store_pb2.ArtifactType()
dataset_type.name = "Dataset"
dataset_type.properties["format"] = metadata_store_pb2.STRING
dataset_type.properties["size"] = metadata_store_pb2.INT
type_id = store.put_artifact_type(dataset_type)

# Create an artifact
dataset = metadata_store_pb2.Artifact()
dataset.type_id = type_id
dataset.properties["format"].string_value = "csv"
dataset.properties["size"].int_value = 1000
dataset_id = store.put_artifacts([dataset])[0]

# Retrieve artifact
retrieved_dataset = store.get_artifacts_by_id([dataset_id])[0]
print(f"Dataset format: {retrieved_dataset.properties['format'].string_value}")
print(f"Dataset size: {retrieved_dataset.properties['size'].int_value}")

Output:

Dataset format: csv
Dataset size: 1000

2. Executions and Events

Executions represent specific runs of your pipeline components. Events connect artifacts to executions:

python
# Define an execution type (properties must be declared on the type before use)
preprocessing_type = metadata_store_pb2.ExecutionType()
preprocessing_type.name = "DataPreprocessing"
preprocessing_type.properties["state"] = metadata_store_pb2.STRING
execution_type_id = store.put_execution_type(preprocessing_type)

# Create an execution
execution = metadata_store_pb2.Execution()
execution.type_id = execution_type_id
execution.properties["state"].string_value = "RUNNING"
execution_id = store.put_executions([execution])[0]

# Create an event to link artifact and execution
event = metadata_store_pb2.Event()
event.artifact_id = dataset_id
event.execution_id = execution_id
event.type = metadata_store_pb2.Event.INPUT
store.put_events([event])

# Mark execution as complete
execution = store.get_executions_by_id([execution_id])[0]
execution.properties["state"].string_value = "COMPLETED"
store.put_executions([execution])

3. Contexts

Contexts help group related artifacts and executions:

python
# Define a context type
experiment_type = metadata_store_pb2.ContextType()
experiment_type.name = "Experiment"
context_type_id = store.put_context_type(experiment_type)

# Create a context
context = metadata_store_pb2.Context()
context.type_id = context_type_id
context.name = "customer_churn_prediction"
context_id = store.put_contexts([context])[0]

# Associate artifacts and executions with this context
attribution = metadata_store_pb2.Attribution()
attribution.context_id = context_id
attribution.artifact_id = dataset_id
association = metadata_store_pb2.Association()
association.context_id = context_id
association.execution_id = execution_id
store.put_attributions_and_associations([attribution], [association])

Using Schema with TensorFlow Data Validation

TensorFlow Metadata works closely with TensorFlow Data Validation (TFDV) to define and enforce schema for your datasets:

python
import tensorflow_data_validation as tfdv
import pandas as pd
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

# Sample dataset
data = pd.DataFrame({
    'feature1': [1, 2, 3, 4, 5],
    'feature2': ['a', 'b', 'c', 'd', 'e'],
    'label': [0, 1, 0, 1, 0]
})

# Convert to TFRecord format and write to a file
tfrecord_path = "/tmp/data.tfrecord"
with tf.io.TFRecordWriter(tfrecord_path) as writer:
    for _, row in data.iterrows():
        feature = {
            'feature1': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['feature1']])),
            'feature2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[row['feature2'].encode()])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['label']]))
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

# Generate statistics from the TFRecord file
statistics = tfdv.generate_statistics_from_tfrecord(data_location=tfrecord_path)

# Infer schema
schema = tfdv.infer_schema(statistics)
print("Schema features:")
for feature in schema.feature:
    print(f"- {feature.name}: {schema_pb2.FeatureType.Name(feature.type)}")

# Save schema as an artifact
schema_artifact = metadata_store_pb2.Artifact()
schema_artifact.type_id = type_id  # Reusing dataset_type from earlier
# "content_type" is not declared on the Dataset type, so store it as a custom property
schema_artifact.custom_properties["content_type"].string_value = "schema"
schema_uri = "/tmp/schema.pbtxt"
tfdv.write_schema_text(schema, schema_uri)
schema_artifact.uri = schema_uri
schema_artifact_id = store.put_artifacts([schema_artifact])[0]

Output:

Schema features:
- feature1: INT
- feature2: BYTES
- label: INT
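
The section above mentions enforcing a schema, not just inferring one. Here is a minimal sketch (reusing the statistics and schema objects from the previous block) of how TFDV can validate statistics against a schema and report anomalies:

python
# Validate statistics against the schema; anomalies are reported whenever
# the data no longer matches the schema's expectations.
anomalies = tfdv.validate_statistics(statistics=statistics, schema=schema)

if anomalies.anomaly_info:
    for feature_name, info in anomalies.anomaly_info.items():
        print(f"Anomaly in {feature_name}: {info.description}")
else:
    print("No anomalies found.")

Because the same statistics were used to infer the schema, no anomalies are expected here; in practice you would validate statistics computed from a new batch of data.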

Real-World Application: ML Pipeline with Metadata Tracking

Let's build a simple ML pipeline that tracks metadata at each stage:

python
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from datetime import datetime

# 1. Define artifact types
model_type = metadata_store_pb2.ArtifactType()
model_type.name = "Model"
model_type.properties["framework"] = metadata_store_pb2.STRING
model_type.properties["accuracy"] = metadata_store_pb2.DOUBLE
model_type_id = store.put_artifact_type(model_type)

# 2. Create pipeline context
pipeline_type = metadata_store_pb2.ContextType()
pipeline_type.name = "Pipeline"
pipeline_type_id = store.put_context_type(pipeline_type)

pipeline = metadata_store_pb2.Context()
pipeline.type_id = pipeline_type_id
pipeline.name = f"prediction_pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
pipeline_id = store.put_contexts([pipeline])[0]

# 3. Data preparation with metadata tracking
data_prep_type = metadata_store_pb2.ExecutionType()
data_prep_type.name = "DataPreparation"
data_prep_type.properties["start_time"] = metadata_store_pb2.STRING
data_prep_type.properties["end_time"] = metadata_store_pb2.STRING
data_prep_type.properties["state"] = metadata_store_pb2.STRING
data_prep_type_id = store.put_execution_type(data_prep_type)

data_prep = metadata_store_pb2.Execution()
data_prep.type_id = data_prep_type_id
data_prep.properties["start_time"].string_value = datetime.now().isoformat()
data_prep_id = store.put_executions([data_prep])[0]

# Generate some synthetic data
X = np.random.rand(1000, 5).astype(np.float32)
y = (np.sum(X, axis=1) > 2.5).astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Create output artifact for prepared data
train_data = metadata_store_pb2.Artifact()
train_data.type_id = type_id  # Reusing dataset_type from earlier
train_data.uri = "/tmp/train_data"
# "split" and "num_examples" are not declared on the Dataset type, so use custom properties
train_data.custom_properties["split"].string_value = "train"
train_data.custom_properties["num_examples"].int_value = len(X_train)
train_data_id = store.put_artifacts([train_data])[0]

# Link data preparation execution to output artifact
event = metadata_store_pb2.Event()
event.execution_id = data_prep_id
event.artifact_id = train_data_id
event.type = metadata_store_pb2.Event.OUTPUT
store.put_events([event])

# Associate with pipeline context
attribution = metadata_store_pb2.Attribution()
attribution.context_id = pipeline_id
attribution.artifact_id = train_data_id
association = metadata_store_pb2.Association()
association.context_id = pipeline_id
association.execution_id = data_prep_id
store.put_attributions_and_associations([attribution], [association])

# Mark data preparation as complete (set the id so this updates the existing record)
data_prep.id = data_prep_id
data_prep.properties["end_time"].string_value = datetime.now().isoformat()
data_prep.properties["state"].string_value = "COMPLETED"
store.put_executions([data_prep])

# 4. Model training with metadata tracking
training_type = metadata_store_pb2.ExecutionType()
training_type.name = "Training"
training_type.properties["start_time"] = metadata_store_pb2.STRING
training_type.properties["end_time"] = metadata_store_pb2.STRING
training_type.properties["state"] = metadata_store_pb2.STRING
training_type.properties["accuracy"] = metadata_store_pb2.DOUBLE
training_type_id = store.put_execution_type(training_type)

training = metadata_store_pb2.Execution()
training.type_id = training_type_id
training.properties["start_time"].string_value = datetime.now().isoformat()
training_id = store.put_executions([training])[0]

# Link training execution to input data
input_event = metadata_store_pb2.Event()
input_event.execution_id = training_id
input_event.artifact_id = train_data_id
input_event.type = metadata_store_pb2.Event.INPUT
store.put_events([input_event])

# Build a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(X_train, y_train, epochs=5, verbose=0)

# Evaluate the model
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)

# Create model artifact
model_artifact = metadata_store_pb2.Artifact()
model_artifact.type_id = model_type_id
model_artifact.uri = "/tmp/model"
model_artifact.properties["framework"].string_value = "tensorflow"
model_artifact.properties["accuracy"].double_value = test_accuracy
model_id = store.put_artifacts([model_artifact])[0]

# Save the model
model.save(model_artifact.uri)

# Link training execution to model artifact
output_event = metadata_store_pb2.Event()
output_event.execution_id = training_id
output_event.artifact_id = model_id
output_event.type = metadata_store_pb2.Event.OUTPUT
store.put_events([output_event])

# Associate with pipeline context
attribution = metadata_store_pb2.Attribution()
attribution.context_id = pipeline_id
attribution.artifact_id = model_id
association = metadata_store_pb2.Association()
association.context_id = pipeline_id
association.execution_id = training_id
store.put_attributions_and_associations([attribution], [association])

# Mark training as complete (set the id so this updates the existing record)
training.id = training_id
training.properties["end_time"].string_value = datetime.now().isoformat()
training.properties["state"].string_value = "COMPLETED"
training.properties["accuracy"].double_value = test_accuracy
store.put_executions([training])

print(f"Model trained with accuracy: {test_accuracy:.4f}")
print(f"Pipeline ID: {pipeline_id}")
print(f"Model artifact ID: {model_id}")

Output:

Model trained with accuracy: 0.8650
Pipeline ID: 1
Model artifact ID: 3

Querying Metadata for Insights

One of the key benefits of tracking metadata is the ability to query it to gain insights about your ML workflows:

python
# Get all models with accuracy > 0.8
models = store.get_artifacts_by_type("Model")
good_models = [m for m in models if m.properties["accuracy"].double_value > 0.8]
print(f"Found {len(good_models)} models with accuracy > 0.8")

# Get the latest pipeline run
contexts = store.get_contexts_by_type("Pipeline")
latest_pipeline = max(contexts, key=lambda c: c.name)
print(f"Latest pipeline: {latest_pipeline.name}")

# Get all artifacts produced in the latest pipeline
pipeline_artifacts = store.get_artifacts_by_context(latest_pipeline.id)
print(f"Artifacts in latest pipeline: {len(pipeline_artifacts)}")
for artifact in pipeline_artifacts:
    artifact_type = store.get_artifact_types_by_id([artifact.type_id])[0]
    print(f"- {artifact_type.name} ({artifact.uri})")

Output:

Found 1 models with accuracy > 0.8
Latest pipeline: prediction_pipeline_20231005_152022
Artifacts in latest pipeline: 2
- Dataset (/tmp/train_data)
- Model (/tmp/model)

Integration with TFX Components

TensorFlow Metadata seamlessly integrates with TFX components. Here's a simple example of how metadata works with TFX's ExampleGen component:

python
from tfx import v1 as tfx
from tfx.components import CsvExampleGen
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

# Define pipeline
def _create_pipeline():
    example_gen = CsvExampleGen(input_base='/path/to/data')

    return tfx.dsl.Pipeline(
        pipeline_name='metadata_example_pipeline',
        pipeline_root='/tmp/tfx_pipeline_output',
        components=[example_gen],
        enable_cache=True,
        metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
            '/tmp/tfx_metadata.db')
    )

# The pipeline would typically be executed like this:
# LocalDagRunner().run(_create_pipeline())

When running this pipeline, TFX automatically creates and manages metadata for every component and artifact in the configured metadata store, which can then be queried programmatically with the same ML Metadata APIs shown earlier.
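
As a minimal sketch (assuming the pipeline above has been run and /tmp/tfx_metadata.db exists), the TFX-managed store can be opened and inspected directly with MLMD:

python
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

# Connect to the SQLite metadata store that the TFX pipeline wrote to
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = '/tmp/tfx_metadata.db'
connection_config.sqlite.connection_mode = 3  # READWRITE_OPENCREATE
tfx_store = metadata_store.MetadataStore(connection_config)

# List every artifact the pipeline registered, with its type and location
for artifact in tfx_store.get_artifacts():
    artifact_type = tfx_store.get_artifact_types_by_id([artifact.type_id])[0]
    print(f"{artifact_type.name}: {artifact.uri}")

This is the same API used throughout this post; TFX simply populates the store for you.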

Summary

TensorFlow Metadata provides a robust system for tracking and managing metadata throughout your ML workflows. Key benefits include:

  • Traceability: Track the lineage of datasets and models (see the lineage sketch below)
  • Reproducibility: Record execution parameters and outcomes
  • Collaboration: Share consistent metadata across team members
  • Automation: Enable automated pipelines to make decisions based on metadata
  • Compliance: Document data sources and transformations for regulatory requirements

By integrating metadata tracking in your ML pipelines, you can build more maintainable, reliable, and transparent systems.
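
As a minimal sketch of the traceability point (reusing the store and model_id from the pipeline example above), events can be walked backwards from a model to the executions and input artifacts that produced it:

python
# Find the executions that produced the model...
producer_ids = [
    event.execution_id
    for event in store.get_events_by_artifact_ids([model_id])
    if event.type == metadata_store_pb2.Event.OUTPUT
]

# ...then find the input artifacts those executions consumed
for event in store.get_events_by_execution_ids(producer_ids):
    if event.type == metadata_store_pb2.Event.INPUT:
        parent = store.get_artifacts_by_id([event.artifact_id])[0]
        print(f"Model artifact {model_id} was produced from artifact {parent.id} ({parent.uri})")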

Exercises

  1. Create a metadata store using SQLite and track a simple dataset preprocessing workflow.
  2. Use TFDV to generate statistics for a dataset and store the schema in ML Metadata.
  3. Implement a TFX pipeline that uses metadata to conditionally train a model only if the data statistics match expectations.
  4. Build a query system to find the best-performing model across multiple experimental runs.
  5. Create a visualization that shows the lineage of a model (all the artifacts and executions that contributed to it).

