TensorFlow Object Detection
Introduction
Object detection is a computer vision technique that allows us to identify and locate objects in images or videos. Unlike image classification, which only tells us what objects are present, object detection provides both the class label and the spatial location (bounding box) of each object.
In this guide, we'll explore how to use TensorFlow's Object Detection API, a powerful framework built on top of TensorFlow that simplifies the process of training and deploying object detection models. We'll cover:
- Understanding the fundamentals of object detection
- Setting up the TensorFlow Object Detection API
- Using pre-trained models for inference
- Fine-tuning models on custom datasets
- Deploying object detection models in real-world applications
Understanding Object Detection
Before diving into code, let's understand the key concepts of object detection:
- Classification: Identifying what objects are in an image
- Localization: Finding where those objects are (via bounding boxes)
- Multiple object detection: Identifying and locating multiple objects simultaneously
Object detection models typically output:
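- Bounding box coordinates (usually as [x_min, y_min, x_max, y_max] or [x, y, width, height]; the TensorFlow Object Detection API returns [y_min, x_min, y_max, x_max], normalized to the range [0, 1])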
- Class labels for each detected object
- Confidence scores indicating the model's certainty about each detection
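Converting the API's normalized [y_min, x_min, y_max, x_max] boxes into the pixel formats listed above takes only a little arithmetic. A minimal sketch; to_pixel_corners and corners_to_xywh are illustrative names, not part of the API:

```python
def to_pixel_corners(box, image_width, image_height):
    """Convert a normalized [y_min, x_min, y_max, x_max] box to pixel [x_min, y_min, x_max, y_max]."""
    y_min, x_min, y_max, x_max = box
    return [
        round(x_min * image_width),
        round(y_min * image_height),
        round(x_max * image_width),
        round(y_max * image_height),
    ]


def corners_to_xywh(corners):
    """Convert pixel [x_min, y_min, x_max, y_max] to [x, y, width, height] (top-left origin)."""
    x_min, y_min, x_max, y_max = corners
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# A box covering the central region of a 640x480 image
box = [0.25, 0.25, 0.75, 0.75]
corners = to_pixel_corners(box, image_width=640, image_height=480)
print(corners)                   # [160, 120, 480, 360]
print(corners_to_xywh(corners))  # [160, 120, 320, 240]
```

Rounding to whole pixels suits drawing; keep the floats if you need sub-pixel precision.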
Setting Up TensorFlow Object Detection API
The TensorFlow Object Detection API is a collection of pre-trained models and utilities for training, evaluating, and deploying object detection models.
Installation
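```bash
# Install TensorFlow
pip install tensorflow
# Clone the TensorFlow Models repository
git clone https://github.com/tensorflow/models.git
# Compile the protobuf definitions (requires the protoc compiler)
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
# Install the Object Detection API using its TF2 setup file
cp object_detection/packages/tf2/setup.py .
pip install .
```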
Verify Installation
Let's verify the installation:
```python
import tensorflow as tf
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

print(tf.__version__)
```
Output:
```text
2.13.0
```
Using Pre-trained Models for Inference
TensorFlow's model zoo contains many pre-trained object detection models with different speed/accuracy trade-offs. Let's use a pre-trained model to detect objects in an image:
1. Load a Pre-trained Model
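```python
import os
import warnings

import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

warnings.filterwarnings('ignore')

# Download a model archive from the TF2 Detection Model Zoo.
# fname must match the archive's top-level folder so the returned path points at it.
model_name = 'ssd_mobilenet_v2_320x320_coco17_tpu-8'
model_url = "http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz"
model_dir = tf.keras.utils.get_file(fname=model_name, origin=model_url, untar=True)
model_dir = os.path.join(model_dir, 'saved_model')

# Load the model and its serving signature
model = tf.saved_model.load(model_dir)
detect_fn = model.signatures['serving_default']
```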
2. Prepare Input and Run Inference
```python
# Load an image as a uint8 RGB array (drops any alpha channel)
def load_image_into_numpy_array(path):
    image = Image.open(path).convert('RGB')
    return np.array(image)

image_path = 'sample_image.jpg'  # Replace with your image path
image_np = load_image_into_numpy_array(image_path)

# The model expects a uint8 batch of shape [1, height, width, 3]
input_tensor = tf.convert_to_tensor(image_np)
input_tensor = input_tensor[tf.newaxis, ...]

# Run inference; signature functions are called with keyword arguments
detections = detect_fn(input_tensor=input_tensor)
```
3. Process and Visualize Results
```python
# Load the COCO label map (path is relative to models/research)
PATH_TO_LABELS = 'object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS, use_display_name=True)

def visualize_detections(image_np, detections, category_index, threshold=0.5):
    # Draw on a copy: the visualization utility modifies the array in place
    image_with_detections = image_np.copy()

    # Take the first (and only) image in the batch
    boxes = detections['detection_boxes'][0].numpy()
    classes = detections['detection_classes'][0].numpy().astype(np.int32)
    scores = detections['detection_scores'][0].numpy()

    vis_util.visualize_boxes_and_labels_on_image_array(
        image_with_detections,
        boxes,
        classes,
        scores,
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=threshold,
        agnostic_mode=False)
    return image_with_detections

# Visualize detections
image_with_detections = visualize_detections(image_np, detections, category_index)

# Display the result
plt.figure(figsize=(12, 8))
plt.imshow(image_with_detections)
plt.axis('off')
plt.show()
```
This code will display your image with bounding boxes around detected objects, along with class names and confidence scores.
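To use the detections programmatically rather than only drawing them, you can filter by score and look up class names. The sketch below assumes plain Python sequences (for example `detections['detection_scores'][0].numpy().tolist()`) and a category_index shaped like the one label_map_util returns: a dict keyed by class ID whose values contain a 'name'. summarize_detections is a hypothetical helper:

```python
def summarize_detections(boxes, classes, scores, category_index, threshold=0.5):
    """Return (name, score, box) tuples for detections scoring above `threshold`.

    boxes, classes and scores are per-image sequences, e.g. the first batch
    entry of the model outputs converted with .numpy().tolist().
    """
    results = []
    for box, class_id, score in zip(boxes, classes, scores):
        if score <= threshold:
            continue
        class_id = int(class_id)  # detection_classes are returned as floats
        # Fall back to the raw ID if the class is missing from the label map
        name = category_index.get(class_id, {}).get('name', f'id {class_id}')
        results.append((name, float(score), list(box)))
    # Highest-confidence detections first
    results.sort(key=lambda item: item[1], reverse=True)
    return results


# Hypothetical values in the API's output layout
category_index = {1: {'id': 1, 'name': 'person'}, 3: {'id': 3, 'name': 'car'}}
boxes = [[0.1, 0.2, 0.5, 0.6], [0.3, 0.3, 0.9, 0.8], [0.0, 0.0, 0.1, 0.1]]
classes = [3.0, 1.0, 1.0]
scores = [0.62, 0.91, 0.20]

for name, score, box in summarize_detections(boxes, classes, scores, category_index):
    print(f"{name}: {score:.2f} at {box}")
# person: 0.91 at [0.3, 0.3, 0.9, 0.8]
# car: 0.62 at [0.1, 0.2, 0.5, 0.6]
```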
Fine-tuning on Custom Datasets
To detect custom objects, you need to fine-tune a pre-trained model on your dataset. Here's the general process:
1. Prepare Your Dataset
Your dataset should include:
- Images with your objects
- Annotation files (for example, Pascal VOC XML or COCO JSON) containing bounding box coordinates and class labels
Convert your dataset to TFRecord format, which the API's training pipeline reads. The function below builds one tf.train.Example per image; fill in the placeholders from your annotations:
```python
import tensorflow as tf
from object_detection.utils import dataset_util

def create_tf_example(example):
    # TODO: Convert your data sample to a TF Example
    height = ...  # Image height in pixels
    width = ...  # Image width in pixels
    filename = ...  # Image filename (str)
    encoded_image = ...  # Encoded image bytes
    image_format = ...  # Image format (e.g., 'jpeg' or 'png')

    xmins = []  # Normalized left x coordinates, one per box (pixel x / width)
    xmaxs = []  # Normalized right x coordinates, one per box
    ymins = []  # Normalized top y coordinates, one per box (pixel y / height)
    ymaxs = []  # Normalized bottom y coordinates, one per box
    classes_text = []  # Class names as UTF-8 bytes, e.g. b'dog'
    classes = []  # Integer class IDs matching label_map.pbtxt

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

# Serialize the examples into a TFRecord file
with tf.io.TFRecordWriter('train.record') as writer:
    for example in examples:  # `examples`: your parsed annotation records (placeholder)
        writer.write(create_tf_example(example).SerializeToString())
```
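If your annotations are Pascal VOC XML files (one common XML layout), the placeholders in create_tf_example can be filled from each file. The sketch below is a hypothetical helper using only the Python standard library; it assumes the usual VOC elements (size/width, size/height, and one object/bndbox per box in absolute pixels):

```python
import xml.etree.ElementTree as ET


def parse_voc_annotation(xml_text, label_to_id):
    """Extract create_tf_example's fields from one Pascal VOC XML annotation.

    label_to_id maps class names to the integer IDs used in label_map.pbtxt.
    """
    root = ET.fromstring(xml_text)
    width = int(float(root.findtext('size/width')))
    height = int(float(root.findtext('size/height')))
    fields = {
        'filename': root.findtext('filename'),
        'width': width,
        'height': height,
        'xmins': [], 'xmaxs': [], 'ymins': [], 'ymaxs': [],
        'classes_text': [], 'classes': [],
    }
    for obj in root.findall('object'):
        name = obj.findtext('name')
        bndbox = obj.find('bndbox')
        # VOC stores absolute pixels; the TFRecord fields expect values in [0, 1]
        fields['xmins'].append(float(bndbox.findtext('xmin')) / width)
        fields['xmaxs'].append(float(bndbox.findtext('xmax')) / width)
        fields['ymins'].append(float(bndbox.findtext('ymin')) / height)
        fields['ymaxs'].append(float(bndbox.findtext('ymax')) / height)
        fields['classes_text'].append(name.encode('utf8'))  # bytes_list_feature needs bytes
        fields['classes'].append(label_to_id[name])
    return fields


voc_xml = """<annotation>
  <filename>dog_001.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>dog</name>
    <bndbox><xmin>64</xmin><ymin>48</ymin><xmax>320</xmax><ymax>240</ymax></bndbox>
  </object>
</annotation>"""

fields = parse_voc_annotation(voc_xml, {'dog': 1, 'cat': 2, 'bird': 3})
print(fields['classes_text'], fields['classes'])  # [b'dog'] [1]
print(fields['xmins'], fields['ymaxs'])           # [0.1] [0.5]
```

The encoded image bytes and format still have to be read from disk, for example with open(image_path, 'rb').read(). Keep label_to_id consistent with label_map.pbtxt so names and integer IDs agree.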
2. Create a Label Map
Create a label_map.pbtxt file that maps class IDs to class names. IDs start at 1, because ID 0 is reserved for the background class:
```
item {
  id: 1
  name: 'dog'
}
item {
  id: 2
  name: 'cat'
}
item {
  id: 3
  name: 'bird'
}
```
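Writing the label map by hand is fine for three classes; for longer lists, a small script keeps the IDs consecutive and in sync with the IDs used when building TFRecords. make_label_map is a hypothetical helper that emits the same item format as above, numbering from 1:

```python
def make_label_map(class_names):
    """Return label_map.pbtxt text, numbering classes from 1 (0 is the background)."""
    items = []
    for class_id, name in enumerate(class_names, start=1):
        items.append(f"item {{\n  id: {class_id}\n  name: '{name}'\n}}\n")
    return "\n".join(items)


class_names = ['dog', 'cat', 'bird']
label_map_text = make_label_map(class_names)
# The same numbering as a dict, for converting annotation names to integer IDs
label_to_id = {name: i for i, name in enumerate(class_names, start=1)}

print(label_to_id)  # {'dog': 1, 'cat': 2, 'bird': 3}
print(label_map_text)
```

Save label_map_text to label_map.pbtxt with an ordinary file write. Names containing a single quote would need escaping, which this sketch does not handle.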
3. Configure Training Pipeline
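Create a pipeline configuration file for the model you're fine-tuning. In practice, start from the pipeline.config bundled with the downloaded model archive and change the fields shown here; without fine_tune_checkpoint, training starts from scratch instead of from the pre-trained weights:
```python
pipeline_config = """
model {
  ssd {
    num_classes: 3  # Set to your number of classes
    image_resizer {
      fixed_shape_resizer {
        height: 320
        width: 320
      }
    }
    # ... other model parameters ...
  }
}
train_config {
  batch_size: 8
  # Initialize from the pre-trained detection checkpoint
  fine_tune_checkpoint: "path/to/pretrained_model/checkpoint/ckpt-0"
  fine_tune_checkpoint_type: "detection"
  # ... training parameters ...
}
train_input_reader {
  tf_record_input_reader {
    input_path: "path/to/train.record"
  }
  label_map_path: "path/to/label_map.pbtxt"
}
eval_config {
  # ... evaluation parameters ...
}
eval_input_reader {
  tf_record_input_reader {
    input_path: "path/to/val.record"
  }
  label_map_path: "path/to/label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
"""

with open('pipeline.config', 'w') as f:
    f.write(pipeline_config)
```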
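Rather than writing the whole configuration yourself, you can patch the pipeline.config that ships with the downloaded model. A minimal sketch using a regular expression over the text format; set_config_field is a hypothetical helper, and the API also includes an object_detection.utils.config_util module for editing configs as protobuf objects:

```python
import re


def set_config_field(config_text, field, value):
    """Set every `field: value` line in a text-format pipeline config.

    Strings are written quoted, booleans as true/false, numbers as-is.
    Raises KeyError if the field is absent, so a typo cannot silently
    leave the config unchanged.
    """
    if isinstance(value, bool):
        rendered = 'true' if value else 'false'
    elif isinstance(value, str):
        rendered = f'"{value}"'
    else:
        rendered = str(value)
    # Match the whole line; group 1 keeps the indentation and "field: " prefix
    pattern = re.compile(rf'^([ \t]*{re.escape(field)}:[ \t]*).*$', re.MULTILINE)
    new_text, count = pattern.subn(lambda m: m.group(1) + rendered, config_text)
    if count == 0:
        raise KeyError(f'{field!r} not found in config')
    return new_text


config = """model {
  ssd {
    num_classes: 90
  }
}
train_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED"
}
eval_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED"
}
"""

config = set_config_field(config, 'num_classes', 3)
config = set_config_field(config, 'label_map_path', 'path/to/label_map.pbtxt')
print(config)
```

Because it rewrites every match, it suits fields that appear once (num_classes, fine_tune_checkpoint) or should share one value (label_map_path). input_path differs between the train and eval readers, so set it by hand.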
4. Train the Model
```bash
# Run from models/research
python object_detection/model_main_tf2.py \
    --pipeline_config_path=pipeline.config \
    --model_dir=training/ \
    --alsologtostderr
```
5. Export the Trained Model
After training, export your model:
```bash
# Writes the SavedModel to exported_model/saved_model
python object_detection/exporter_main_v2.py \
    --pipeline_config_path=pipeline.config \
    --trained_checkpoint_dir=training/ \
    --output_directory=exported_model/
```
Real-World Example: Traffic Monitoring System
Let's implement a simple traffic monitoring system. It relies on the COCO vehicle classes (car = 3, bus = 6, truck = 8 in mscoco_label_map.pbtxt), so use a COCO-trained model, such as the pre-trained SSD MobileNet from earlier, or change the class IDs to match your own label map:
```python
import cv2
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Load a COCO-trained SavedModel and the COCO label map
model_path = 'exported_model/saved_model'  # or the pre-trained model's saved_model directory
labels_path = 'object_detection/data/mscoco_label_map.pbtxt'
model = tf.saved_model.load(model_path)
detect_fn = model.signatures['serving_default']
category_index = label_map_util.create_category_index_from_labelmap(
    labels_path, use_display_name=True)

# Set up vehicle counting
vehicle_classes = [3, 6, 8]  # COCO class IDs for car, bus, truck
vehicle_count = 0
counted_vehicles = set()  # IDs of already-counted vehicles, keyed by rounded box edges

def process_frame(frame):
    global vehicle_count

    # OpenCV delivers BGR frames; the model expects RGB
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    input_tensor = tf.convert_to_tensor(rgb)[tf.newaxis, ...]

    # Detect objects
    detections = detect_fn(input_tensor=input_tensor)
    boxes = detections['detection_boxes'][0].numpy()
    classes = detections['detection_classes'][0].numpy().astype(np.int32)
    scores = detections['detection_scores'][0].numpy()

    # Count vehicles whose box center lies near the counting line
    for i in range(len(scores)):
        if scores[i] > 0.5 and classes[i] in vehicle_classes:
            box = tuple(boxes[i].tolist())  # (y_min, x_min, y_max, x_max), normalized
            y_pos = (box[0] + box[2]) / 2
            if 0.45 < y_pos < 0.55:
                box_id = f"{int(box[1]*1000)}-{int(box[3]*1000)}"
                if box_id not in counted_vehicles:
                    counted_vehicles.add(box_id)
                    vehicle_count += 1

    # Draw the counting line and the running count
    h, w = frame.shape[:2]
    cv2.line(frame, (0, int(h*0.5)), (w, int(h*0.5)), (0, 255, 0), 2)
    cv2.putText(frame, f"Vehicle Count: {vehicle_count}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    # Draw detections onto the frame in place
    vis_util.visualize_boxes_and_labels_on_image_array(
        frame, boxes, classes, scores, category_index,
        use_normalized_coordinates=True, max_boxes_to_draw=200,
        min_score_thresh=0.5, agnostic_mode=False)
    return frame

# Process video
video_path = 'traffic_video.mp4'  # Replace with your video
cap = cv2.VideoCapture(video_path)

# Optional: save the annotated output at the source frame rate
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output.mp4', fourcc, fps, (width, height))

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    processed_frame = process_frame(frame)
    out.write(processed_frame)  # Write to output video
    cv2.imshow('Traffic Monitoring', processed_frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
out.release()
cv2.destroyAllWindows()
```
This traffic monitoring system:
- Detects vehicles in video frames
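- Counts vehicles whose box center falls within a band around a virtual line
- Visualizes detections and keeps a running count

Because vehicles are identified only by their rounded horizontal box edges, the same vehicle can be counted more than once if its box shifts between frames, so treat the count as approximate.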
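To count each vehicle once, a common alternative (swapped in here; it is not part of the Object Detection API) is to track box centers across frames and count a track the first time it crosses the line. This is a minimal greedy nearest-centroid sketch; LineCrossingCounter is a hypothetical class, and max_distance=0.1 (in normalized coordinates) is an arbitrary assumption that depends on frame rate and vehicle speed:

```python
import math


class LineCrossingCounter:
    """Count objects whose tracked center crosses a horizontal line once.

    Each detection is matched greedily to the nearest unmatched track from
    the previous frame within max_distance (normalized units).
    """

    def __init__(self, line_y=0.5, max_distance=0.1):
        self.line_y = line_y
        self.max_distance = max_distance
        self.tracks = {}      # track_id -> (x, y) from the previous frame
        self.counted = set()  # track IDs that have already been counted
        self.next_id = 0
        self.count = 0

    def update(self, centroids):
        """centroids: list of (x, y) centers of this frame's vehicle boxes."""
        new_tracks = {}
        unmatched = dict(self.tracks)
        for x, y in centroids:
            best_id, best_dist = None, self.max_distance
            for track_id, (px, py) in unmatched.items():
                dist = math.hypot(x - px, y - py)
                if dist <= best_dist:
                    best_id, best_dist = track_id, dist
            if best_id is None:
                # No nearby track: start a new one
                best_id = self.next_id
                self.next_id += 1
            else:
                prev_y = unmatched.pop(best_id)[1]
                # Crossed if the center changed sides of the line since last frame
                crossed = (prev_y < self.line_y) != (y < self.line_y)
                if crossed and best_id not in self.counted:
                    self.counted.add(best_id)
                    self.count += 1
            new_tracks[best_id] = (x, y)
        self.tracks = new_tracks  # Tracks not seen this frame are dropped
        return self.count


# One vehicle moving down the frame, then a second vehicle appearing
counter = LineCrossingCounter(line_y=0.5, max_distance=0.1)
for frame_centroids in [[(0.5, 0.40)], [(0.5, 0.45)], [(0.5, 0.52)], [(0.5, 0.60), (0.1, 0.30)]]:
    print(counter.update(frame_centroids))  # prints 0, 0, 1, 1 on separate lines
```

Inside process_frame, the centers would come from the filtered vehicle boxes as ((box[1] + box[3]) / 2, (box[0] + box[2]) / 2), given the API's [y_min, x_min, y_max, x_max] order, and vehicle_count = counter.update(centers) would replace the box_id logic. Greedy matching is the simplest option; in dense traffic, an IoU-based or Kalman-filter tracker is more robust.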
Optimizing Performance
Object detection can be computationally intensive. Here are some tips to improve performance:
1. Choose the Right Model
TensorFlow provides models with different speed/accuracy trade-offs:
- Faster models: SSD MobileNet, EfficientDet-D0
- More accurate models: Faster R-CNN, EfficientDet-D7
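Published speed figures may not match your hardware, so it is worth timing candidate models yourself. A minimal sketch; measure_latency is a hypothetical helper, and the warm-up calls are excluded because the first calls typically include one-time setup such as graph tracing and memory allocation:

```python
import time


def measure_latency(fn, warmup=3, runs=20):
    """Mean wall-clock milliseconds per call of the zero-argument callable `fn`."""
    for _ in range(warmup):
        fn()  # Not timed: first calls include one-time setup costs
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) * 1000 / runs


# Stand-in workload; with a real model, wrap the inference call in a lambda
print(f"{measure_latency(lambda: sum(range(100_000))):.2f} ms per call")
```

With the model loaded earlier, something like measure_latency(lambda: {k: v.numpy() for k, v in detect_fn(input_tensor=input_tensor).items()}) times inference including the copy back to NumPy, so the outputs are actually available when the clock stops.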
2. Resize Input Images
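Smaller inputs mean less data to decode, copy, and resize. Keep in mind that an exported model already resizes images internally to the resolution set by image_resizer in its pipeline config (320x320 for the SSD MobileNet used above), so downscaling large frames mainly saves preprocessing work; the network's own cost is set by that config. Because output boxes are normalized, they still map onto the original frame:
```python
import cv2

def preprocess_image(image, target_size=(300, 300)):
    # cv2.resize expects (width, height); INTER_AREA suits downscaling.
    # A fixed square size stretches non-square images.
    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
    return resized_image
```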
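A fixed square target_size stretches non-square frames. If you downscale at all, one option is to cap the longer side and keep the aspect ratio; scaled_size is a hypothetical helper and max_side=640 an arbitrary example value. It returns (width, height), the order cv2.resize expects:

```python
def scaled_size(width, height, max_side=640):
    """(new_width, new_height) fitting within max_side, keeping the aspect ratio.

    Images already within max_side keep their original size (no upscaling).
    """
    scale = min(1.0, max_side / max(width, height))
    return max(1, round(width * scale)), max(1, round(height * scale))


print(scaled_size(1920, 1080))  # (640, 360)
print(scaled_size(640, 480))    # (640, 480)
```

For an OpenCV frame, the call would be cv2.resize(frame, scaled_size(frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_AREA), since frame.shape is (height, width, channels).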
3. Use TensorRT for Inference
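On NVIDIA GPUs, TensorFlow-TensorRT (TF-TRT) can speed up inference. The gain depends on the model and hardware, and operations TensorRT cannot convert keep running in regular TensorFlow:
```python
import tensorflow as tf
# Requires an NVIDIA GPU and a TensorFlow build with TensorRT support
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Convert the SavedModel loaded earlier (model_dir points to its saved_model directory)
converter = trt.TrtGraphConverterV2(input_saved_model_dir=model_dir)
converter.convert()
converter.save('tensorrt_model')

# Load the optimized model
trt_model = tf.saved_model.load('tensorrt_model')
detect_fn = trt_model.signatures['serving_default']
```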
4. Deploy on Edge Devices
For edge deployment, consider:
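- TensorFlow Lite for mobile devices (SSD models can first be exported with object_detection/export_tflite_graph_tf2.py, then converted)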
- Edge TPU for Coral devices
- NVIDIA Jetson for GPU acceleration
Summary
In this guide, we've covered the fundamentals of object detection using TensorFlow's Object Detection API. You've learned:
- How to set up the TensorFlow Object Detection API
- Using pre-trained models for inference
- Fine-tuning models on custom datasets
- Creating a practical traffic monitoring application
- Optimizing performance for real-world deployment
Object detection has numerous applications across industries, including:
- Autonomous vehicles
- Surveillance systems
- Retail analytics
- Medical imaging
- Industrial quality control
By mastering these techniques, you can build powerful computer vision applications that can detect and locate objects in the real world.
Additional Resources
- TensorFlow Object Detection API GitHub
- TensorFlow Model Garden
- TensorFlow Object Detection API Tutorial
- Google's AI Blog on Object Detection
Exercises
- Basic: Use a pre-trained model to detect objects in 5 different images and visualize the results.
- Intermediate: Create a small custom dataset (10-20 images) with annotations and fine-tune a model to detect a specific object.
- Advanced: Build a real-time object detection system using your webcam that can detect and track objects, displaying their trajectories over time.
- Challenge: Implement a multi-camera object detection system that can track the same objects across different camera views.
Happy object detecting!