TensorFlow ML Metadata

Introduction

ML Metadata (MLMD) is a crucial component of TensorFlow Extended (TFX) that helps you track and manage the metadata associated with your machine learning workflows. In any machine learning system, keeping track of datasets, models, experiments, and their relationships is essential for reproducibility, debugging, and compliance.

MLMD serves as the "memory" of your ML pipeline by:

  1. Recording the lineage of ML artifacts (datasets, models, etc.)
  2. Tracking execution of components in your ML workflow
  3. Managing versioning of your ML assets
  4. Enabling experiment tracking and model governance

By the end of this tutorial, you'll understand how to use ML Metadata to enhance traceability and reproducibility in your TensorFlow projects.

Understanding ML Metadata Core Concepts

Before diving into code, let's understand the key concepts in MLMD:

Core Entities in ML Metadata

  1. Artifacts: The data objects produced and consumed by components (datasets, models, statistics)
  2. Executions: Records of component runs in a workflow
  3. Contexts: Logical groupings of artifacts and executions (like pipeline runs)
  4. Types: Type definitions for artifacts, executions, and contexts

Here's how these entities relate to each other:

Artifact  <-- (event: input/output) -->  Execution
Artifact  <-- (attribution)         -->  Context
Execution <-- (association)         -->  Context

Setting Up ML Metadata

Let's start by setting up a basic ML Metadata store:

python
# Import required packages
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

# Create a connection config
connection_config = metadata_store_pb2.ConnectionConfig()

# For this example, we'll use a simple SQLite database
connection_config.sqlite.filename_uri = 'metadata.sqlite'

# Initialize the metadata store
store = metadata_store.MetadataStore(connection_config)

print("ML Metadata store initialized successfully!")

Output:

ML Metadata store initialized successfully!
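
For quick experiments or unit tests you may not want a file on disk at all. As a minimal sketch (assuming the standard MLMD Python client), you can point the same ConnectionConfig at an in-memory fake database instead:

python
# In-memory backend: nothing is persisted, data vanishes when the process exits
test_config = metadata_store_pb2.ConnectionConfig()
test_config.fake_database.SetInParent()  # select the fake_database backend
test_store = metadata_store.MetadataStore(test_config)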

Working with Artifact Types

Let's define artifact types for our ML workflow:

python
# Define a dataset type
dataset_type = metadata_store_pb2.ArtifactType()
dataset_type.name = "DataSet"
dataset_type.properties["size"] = metadata_store_pb2.INT
dataset_type.properties["format"] = metadata_store_pb2.STRING

# Register the type with the metadata store
dataset_type_id = store.put_artifact_type(dataset_type)

# Define a model type
model_type = metadata_store_pb2.ArtifactType()
model_type.name = "Model"
model_type.properties["framework"] = metadata_store_pb2.STRING
model_type.properties["accuracy"] = metadata_store_pb2.DOUBLE

# Register the model type
model_type_id = store.put_artifact_type(model_type)

print(f"Dataset Type ID: {dataset_type_id}")
print(f"Model Type ID: {model_type_id}")

Output:

Dataset Type ID: 1
Model Type ID: 2
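
Types are stored once per name, so you can fetch a registered type back to confirm its schema. A small sketch (the printed property values are PropertyType enum integers):

python
# Look up a registered type by name and inspect its schema
fetched_type = store.get_artifact_type("DataSet")
print(fetched_type.name)              # DataSet
print(dict(fetched_type.properties))  # property name -> PropertyType enum value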

Creating and Logging Artifacts

Now that we have our types registered, we can create artifacts:

python
# Create a dataset artifact
dataset = metadata_store_pb2.Artifact()
dataset.type_id = dataset_type_id
dataset.properties["size"].int_value = 10000
dataset.properties["format"].string_value = "tfrecord"
dataset.uri = "gs://my-bucket/datasets/mnist"

# Store the dataset in the metadata store
dataset_id = store.put_artifacts([dataset])[0]

# Create a model artifact
model = metadata_store_pb2.Artifact()
model.type_id = model_type_id
model.properties["framework"].string_value = "tensorflow"
model.properties["accuracy"].double_value = 0.95
model.uri = "gs://my-bucket/models/mnist-classifier"

# Store the model in the metadata store
model_id = store.put_artifacts([model])[0]

print(f"Dataset artifact created with ID: {dataset_id}")
print(f"Model artifact created with ID: {model_id}")

Output:

Dataset artifact created with ID: 1
Model artifact created with ID: 2
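
Besides IDs, artifacts can also be looked up by the URI they were registered with, which is handy when you know where an asset lives but not its metadata ID. A short sketch:

python
# Find every artifact registered under a given storage URI
matches = store.get_artifacts_by_uri("gs://my-bucket/models/mnist-classifier")
for artifact in matches:
    print(artifact.id, artifact.uri)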

Tracking Executions

Let's define and track an execution (like a training run):

python
# Define an execution type for training
training_type = metadata_store_pb2.ExecutionType()
training_type.name = "Training"
training_type.properties["framework"] = metadata_store_pb2.STRING
training_type.properties["iterations"] = metadata_store_pb2.INT

# Register the execution type
training_type_id = store.put_execution_type(training_type)

# Create a training execution
training_run = metadata_store_pb2.Execution()
training_run.type_id = training_type_id
training_run.properties["framework"].string_value = "tensorflow"
training_run.properties["iterations"].int_value = 1000

# Register the execution
execution_id = store.put_executions([training_run])[0]

print(f"Training execution recorded with ID: {execution_id}")

Output:

Training execution recorded with ID: 1
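
Executions also carry a state. A common pattern, sketched below, is to register the run early and update it in place once it finishes: re-putting an execution whose id field is set overwrites the stored record rather than creating a new one.

python
# Mark the run as complete by updating the stored execution in place
training_run.id = execution_id
training_run.last_known_state = metadata_store_pb2.Execution.COMPLETE
store.put_executions([training_run])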

Establishing Relationships

The real power of MLMD comes from establishing connections between artifacts and executions:

python
# Define the input/output relationships
input_event = metadata_store_pb2.Event()
input_event.artifact_id = dataset_id
input_event.execution_id = execution_id
input_event.type = metadata_store_pb2.Event.DECLARED_INPUT

output_event = metadata_store_pb2.Event()
output_event.artifact_id = model_id
output_event.execution_id = execution_id
output_event.type = metadata_store_pb2.Event.DECLARED_OUTPUT

# Record the events
store.put_events([input_event, output_event])

print("Relationship between dataset, training run, and model established!")

Output:

Relationship between dataset, training run, and model established!
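
The same relationships can be read back from the execution side. As a quick check, this sketch lists every event attached to the training run:

python
# List the declared inputs and outputs of the training execution
for event in store.get_events_by_execution_ids([execution_id]):
    kind = ("input" if event.type == metadata_store_pb2.Event.DECLARED_INPUT
            else "output")
    print(f"Artifact {event.artifact_id} is a declared {kind} "
          f"of execution {event.execution_id}")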

Grouping with Contexts

To group related artifacts and executions, we can use contexts:

python
# Create a context type
experiment_type = metadata_store_pb2.ContextType()
experiment_type.name = "Experiment"
experiment_type.properties["description"] = metadata_store_pb2.STRING

# Register the context type
experiment_type_id = store.put_context_type(experiment_type)

# Create an experiment context
experiment = metadata_store_pb2.Context()
experiment.type_id = experiment_type_id
experiment.name = "mnist_classification_experiment"
experiment.properties["description"].string_value = "Image classification on MNIST"

# Register the context
experiment_id = store.put_contexts([experiment])[0]

# Add the execution to the context
store.put_attributions_and_associations(
    [metadata_store_pb2.Attribution(context_id=experiment_id, artifact_id=dataset_id),
     metadata_store_pb2.Attribution(context_id=experiment_id, artifact_id=model_id)],
    [metadata_store_pb2.Association(context_id=experiment_id, execution_id=execution_id)]
)

print(f"Experiment context created with ID: {experiment_id}")

Output:

Experiment context created with ID: 1
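
Context names are unique per context type, which means you can retrieve a context directly by its (type, name) pair rather than scanning by ID. A minimal sketch:

python
# Fetch a context back by its type name and context name
found = store.get_context_by_type_and_name(
    "Experiment", "mnist_classification_experiment")
if found is not None:
    print(found.id, found.name)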

Querying the Metadata Store

One of the key benefits of MLMD is the ability to query the metadata for analysis and tracking:

python
# Get all artifacts of type "Model"
models = store.get_artifacts_by_type("Model")
print(f"Found {len(models)} models in the metadata store")

# Get artifacts associated with our experiment
artifacts_in_exp = store.get_artifacts_by_context(experiment_id)
print(f"Found {len(artifacts_in_exp)} artifacts in experiment {experiment.name}")

# Get the lineage of our model
model_events = store.get_events_by_artifact_ids([model_id])
for event in model_events:
    if event.type == metadata_store_pb2.Event.DECLARED_OUTPUT:
        execution = store.get_executions_by_id([event.execution_id])[0]
        execution_type = store.get_execution_types_by_id([execution.type_id])[0]
        print(f"Model was created by execution {event.execution_id} "
              f"of type {execution_type.name}")

Output:

Found 1 models in the metadata store
Found 2 artifacts in experiment mnist_classification_experiment
Model was created by execution 1 of type Training

Real-world Example: ML Pipeline with MLMD

Let's see how MLMD fits into a more realistic TFX pipeline. Here's a simplified example:

python
import tensorflow as tf
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2
import time

# Set up the metadata store
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = 'pipeline_metadata.sqlite'
store = metadata_store.MetadataStore(connection_config)

# --- Register types ---
# Dataset type
dataset_type = metadata_store_pb2.ArtifactType()
dataset_type.name = "Dataset"
dataset_type_id = store.put_artifact_type(dataset_type)

# Statistics type
stats_type = metadata_store_pb2.ArtifactType()
stats_type.name = "Statistics"
stats_type_id = store.put_artifact_type(stats_type)

# Model type
model_type = metadata_store_pb2.ArtifactType()
model_type.name = "Model"
model_type_id = store.put_artifact_type(model_type)

# --- Register execution types ---
# Data ingestion
ingest_type = metadata_store_pb2.ExecutionType()
ingest_type.name = "DataIngestion"
ingest_type_id = store.put_execution_type(ingest_type)

# Data validation
validate_type = metadata_store_pb2.ExecutionType()
validate_type.name = "DataValidation"
validate_type_id = store.put_execution_type(validate_type)

# Training
train_type = metadata_store_pb2.ExecutionType()
train_type.name = "Training"
train_type_id = store.put_execution_type(train_type)

# --- Create pipeline context ---
pipeline_type = metadata_store_pb2.ContextType()
pipeline_type.name = "Pipeline"
pipeline_type_id = store.put_context_type(pipeline_type)

pipeline = metadata_store_pb2.Context()
pipeline.type_id = pipeline_type_id
pipeline.name = f"image_classification_pipeline_{int(time.time())}"
pipeline_id = store.put_contexts([pipeline])[0]

# --- Create pipeline nodes and track metadata ---

# 1. Data Ingestion
dataset = metadata_store_pb2.Artifact()
dataset.type_id = dataset_type_id
dataset.uri = "gs://my-bucket/datasets/images"
dataset_id = store.put_artifacts([dataset])[0]

ingest_exec = metadata_store_pb2.Execution()
ingest_exec.type_id = ingest_type_id
ingest_exec.last_known_state = metadata_store_pb2.Execution.COMPLETE
ingest_id = store.put_executions([ingest_exec])[0]

ingest_output = metadata_store_pb2.Event()
ingest_output.artifact_id = dataset_id
ingest_output.execution_id = ingest_id
ingest_output.type = metadata_store_pb2.Event.DECLARED_OUTPUT
store.put_events([ingest_output])

# 2. Data Validation
stats = metadata_store_pb2.Artifact()
stats.type_id = stats_type_id
stats.uri = "gs://my-bucket/stats/image_stats"
stats_id = store.put_artifacts([stats])[0]

validate_exec = metadata_store_pb2.Execution()
validate_exec.type_id = validate_type_id
validate_exec.last_known_state = metadata_store_pb2.Execution.COMPLETE
validate_id = store.put_executions([validate_exec])[0]

validate_input = metadata_store_pb2.Event()
validate_input.artifact_id = dataset_id
validate_input.execution_id = validate_id
validate_input.type = metadata_store_pb2.Event.DECLARED_INPUT

validate_output = metadata_store_pb2.Event()
validate_output.artifact_id = stats_id
validate_output.execution_id = validate_id
validate_output.type = metadata_store_pb2.Event.DECLARED_OUTPUT
store.put_events([validate_input, validate_output])

# 3. Model Training
model = metadata_store_pb2.Artifact()
model.type_id = model_type_id
model.uri = "gs://my-bucket/models/image_classifier"
model_id = store.put_artifacts([model])[0]

train_exec = metadata_store_pb2.Execution()
train_exec.type_id = train_type_id
train_exec.last_known_state = metadata_store_pb2.Execution.COMPLETE
train_id = store.put_executions([train_exec])[0]

train_input = metadata_store_pb2.Event()
train_input.artifact_id = dataset_id
train_input.execution_id = train_id
train_input.type = metadata_store_pb2.Event.DECLARED_INPUT

train_output = metadata_store_pb2.Event()
train_output.artifact_id = model_id
train_output.execution_id = train_id
train_output.type = metadata_store_pb2.Event.DECLARED_OUTPUT
store.put_events([train_input, train_output])

# Add all components to the pipeline context
store.put_attributions_and_associations(
    [metadata_store_pb2.Attribution(context_id=pipeline_id, artifact_id=a_id)
     for a_id in [dataset_id, stats_id, model_id]],
    [metadata_store_pb2.Association(context_id=pipeline_id, execution_id=e_id)
     for e_id in [ingest_id, validate_id, train_id]]
)

print(f"Pipeline with ID {pipeline_id} recorded successfully")

# Query for result
executions_in_pipeline = store.get_executions_by_context(pipeline_id)
print(f"Pipeline has {len(executions_in_pipeline)} executions")

# Find the final model
final_artifacts = store.get_artifacts_by_context(pipeline_id)
models = [a for a in final_artifacts
          if store.get_artifact_types_by_id([a.type_id])[0].name == "Model"]
print(f"Pipeline produced {len(models)} models at: {models[0].uri}")

Output:

Pipeline with ID 1 recorded successfully
Pipeline has 3 executions
Pipeline produced 1 models at: gs://my-bucket/models/image_classifier
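
With the events in place, lineage questions become simple graph walks. The helper below is a sketch of one such walk (upstream_inputs is our own illustrative function, not an MLMD API): given an artifact, it finds the execution that produced it and returns the artifacts that execution consumed.

python
def upstream_inputs(store, artifact_id):
    """One lineage hop: from an artifact to its producer's declared inputs."""
    for event in store.get_events_by_artifact_ids([artifact_id]):
        if event.type == metadata_store_pb2.Event.DECLARED_OUTPUT:
            producer_id = event.execution_id
            return [e.artifact_id
                    for e in store.get_events_by_execution_ids([producer_id])
                    if e.type == metadata_store_pb2.Event.DECLARED_INPUT]
    return []

# The model traces back to the ingested dataset in a single hop
print(upstream_inputs(store, model_id))  # e.g. [1] — the dataset's artifact ID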

Integrating MLMD with TFX Components

In a complete TFX pipeline, ML Metadata is integrated automatically. Here's a simplified TFX pipeline that showcases how metadata is tracked:

python
# This is a conceptual example of how TFX uses MLMD under the hood
from tfx.components import CsvExampleGen, StatisticsGen, Trainer
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
from tfx.proto import trainer_pb2

# Define pipeline components
example_gen = CsvExampleGen(input_base='/data/csv_data')
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
trainer = Trainer(
    module_file='model.py',
    examples=example_gen.outputs['examples'],
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500)
)

# Define the pipeline
tfx_pipeline = pipeline.Pipeline(
    pipeline_name='metadata_example_pipeline',
    pipeline_root='/tmp/tfx_pipeline',
    components=[example_gen, statistics_gen, trainer],
    enable_cache=True  # Uses MLMD to decide whether cached outputs can be reused
)

# The LocalDagRunner uses MLMD to track all artifacts and executions
LocalDagRunner().run(tfx_pipeline)

When this pipeline runs, behind the scenes TFX is:

  1. Recording all inputs and outputs as artifacts in MLMD
  2. Tracking each component execution with timestamps and status
  3. Establishing the lineage between data and models
  4. Using metadata to enable intelligent caching
  5. Grouping everything under a pipeline context
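
After the run completes you can open the store TFX wrote to and browse what it recorded. A hedged sketch (the metadata path below is an assumption; use whatever metadata_connection_config your pipeline was actually configured with):

python
from ml_metadata import metadata_store
from tfx.orchestration import metadata

# Hypothetical path — depends on your pipeline's metadata configuration
conn_config = metadata.sqlite_metadata_connection_config(
    '/tmp/tfx_pipeline/metadata.sqlite')
tfx_store = metadata_store.MetadataStore(conn_config)

# List the artifact types TFX registered automatically
for artifact_type in tfx_store.get_artifact_types():
    print(artifact_type.name)  # e.g. Examples, ExampleStatistics, Model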

Visualizing Metadata

While not directly part of MLMD, TFX provides visualization tools for metadata:

python
# Conceptual sketch only — this would typically run in a notebook connected
# to your metadata store. plot_artifact_lineage below is a hypothetical
# helper used for illustration; it is not part of the MLMD or TFX API.

# Connect to your metadata store
store = ...

# Get the most recent pipeline execution context
pipeline_context = store.get_contexts_by_type("Pipeline")[-1]

# Visualize the lineage (hypothetical helper)
plot_artifact_lineage(store, pipeline_context.id)

This would generate an interactive graph showing how data flows through your pipeline components.

Best Practices

When working with ML Metadata, consider these best practices:

  1. Define clear type schemas: Have well-defined property schemas for your artifact types
  2. Use meaningful context groupings: Group related runs under semantic contexts (experiments, pipelines)
  3. Record detailed properties: The more metadata you record, the more useful your lineage tracking
  4. Query the metadata store regularly: Use the metadata for debugging, compliance, and analysis
  5. Consider backup strategies: For production, ensure your metadata store is backed up properly
  6. Version your types: As your ML system evolves, version your metadata types accordingly
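
As a sketch of the last point, MLMD lets a type schema evolve in place: re-registering a type with a new property succeeds only if you opt in with can_add_fields=True, which guards against accidental schema drift.

python
# Evolve the "DataSet" type from earlier by adding a new property
evolved = metadata_store_pb2.ArtifactType()
evolved.name = "DataSet"  # same name as the existing type
evolved.properties["size"] = metadata_store_pb2.INT
evolved.properties["format"] = metadata_store_pb2.STRING
evolved.properties["version"] = metadata_store_pb2.STRING  # the new field
store.put_artifact_type(evolved, can_add_fields=True)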

Summary

ML Metadata (MLMD) is a powerful component of TensorFlow Extended that helps you track artifacts, executions, and their relationships throughout your machine learning workflows. Key benefits include:

  • Lineage tracking: Understanding how models were created from data
  • Reproducibility: Having a record of all inputs that produced a model
  • Caching: Avoiding recomputation of pipeline steps with matching inputs
  • Governance: Supporting compliance and auditability requirements
  • Debugging: Understanding what happened in complex ML pipelines

By integrating MLMD into your TFX workflows, you can create more robust, reproducible, and explainable machine learning systems.

Exercises

  1. Create a metadata store and record a simple ML workflow with data preprocessing, training, and evaluation steps.
  2. Write queries to trace the lineage of a model back to its source data.
  3. Implement a versioning scheme for datasets using MLMD contexts.
  4. Build a simple dashboard that visualizes the relationships between artifacts in your ML workflow.
  5. Extend the metadata tracking to include information about the hardware used for training and the time taken.

