TensorFlow.js
Introduction
TensorFlow.js is an open-source library that brings machine learning to JavaScript, allowing developers to train and deploy ML models in the browser and in Node.js. As part of the broader TensorFlow ecosystem, TensorFlow.js lets you run existing models or build and train new ones; in the browser, users need nothing installed beyond the browser itself.
In this tutorial, you'll learn:
- What TensorFlow.js is and why it's useful
- How to set up and use TensorFlow.js in both browser and Node.js environments
- How to load and run pre-trained models
- How to train models directly in JavaScript
- Real-world applications and use cases
Why Use TensorFlow.js?
TensorFlow.js offers several unique advantages:
- Client-side ML: Run machine learning directly in the browser without server round trips
- Privacy: Process sensitive data locally without sending it to servers
- Accessibility: Lower the barrier to ML by leveraging the ubiquity of web browsers
- Offline capability: Once the model files have been downloaded and cached, applications can run ML models without internet connectivity
- Reduced server costs: Offload computation to the client's device
Getting Started with TensorFlow.js
Setting Up TensorFlow.js
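You can include TensorFlow.js in your project in several ways:
- Browser: load the library from a CDN with a script tag, which exposes a global tf object (in production, pin a specific version instead of @latest):
<!-- Direct script tag from CDN -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
- NPM: install the package for bundled browser projects or for Node.js:
# Install via npm for browser or Node.js projects
npm install @tensorflow/tfjs
# For Node.js with native acceleration (bindings to the TensorFlow C library)
npm install @tensorflow/tfjs-node
# For GPU support in Node.js (requires an NVIDIA GPU with CUDA and cuDNN)
npm install @tensorflow/tfjs-node-gpu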
Basic Concepts
Before we dive into examples, let's understand some key components of TensorFlow.js:
- Tensors: Multi-dimensional arrays that hold your data
- Models: The neural network architecture that processes the data
- Layers: Building blocks that compose a model
- Operations: Functions that manipulate tensors
Working with Tensors
Tensors are the core data structure in TensorFlow.js. Here's how to create and manipulate them:
// Create a 2x3 tensor
const tensor1 = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
console.log('Tensor shape:', tensor1.shape);
console.log('Tensor data type:', tensor1.dtype);
// Print tensor values
tensor1.print();
// Perform operations
const tensor2 = tensor1.add(1);
tensor2.print();
// Clean up memory
tensor1.dispose();
tensor2.dispose();
// Alternatively, use tf.tidy for automatic memory management
tf.tidy(() => {
  const x = tf.tensor2d([[1, 2], [3, 4]]);
  const y = x.square();
  y.print();
  // No need to call dispose() within tidy
});
Output:
Tensor shape: [2,3]
Tensor data type: float32
Tensor
[[1, 2, 3],
[4, 5, 6]]
Tensor
[[2, 3, 4],
[5, 6, 7]]
Tensor
[[1, 4],
[9, 16]]
Loading Pre-trained Models
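TensorFlow.js can use pre-trained models from several sources, including:
- TensorFlow.js model format (loaded with tf.loadLayersModel() or tf.loadGraphModel(), depending on how the model was saved)
- TensorFlow SavedModel format (converted with tensorflowjs_converter, then loaded with tf.loadGraphModel())
- Keras models (converted with tensorflowjs_converter, then loaded with tf.loadLayersModel())
- TensorFlow Hub models (loaded with tf.loadGraphModel() and the fromTFHub option)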
Using a Pre-trained Model from the TensorFlow.js Model Repository
Here's an example of using the MobileNet image classification model:
<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
  </head>
  <body>
    <!-- crossorigin lets TensorFlow.js read pixels of an image hosted on another origin (the host must allow CORS) -->
    <img id="img" src="https://i.imgur.com/JlUvsxa.jpg" crossorigin="anonymous" width="227" height="227"/>
    <div id="prediction-result"></div>
    <script>
      async function predict() {
        const img = document.getElementById('img');
        const resultDiv = document.getElementById('prediction-result');
        // Load MobileNet model (downloads the model weights)
        const model = await mobilenet.load();
        // Classify the image (returns the top 3 classes by default)
        const predictions = await model.classify(img);
        // Display results
        resultDiv.innerHTML = predictions.map(p =>
          `${p.className}: ${(p.probability * 100).toFixed(2)}%`
        ).join('<br>');
      }
      // Run prediction once the page, including the image, has loaded
      window.addEventListener('load', predict);
    </script>
  </body>
</html>
Output (will vary based on the image):
Golden retriever: 87.25%
Labrador retriever: 5.18%
Tennis ball: 2.54%
Creating and Training a Model
Let's build a simple model to recognize handwritten digits using the MNIST dataset:
// Create a sequential model
const model = tf.sequential();
// Add layers: two convolution + pooling stages, then a dense classifier
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 3,
  filters: 16,
  activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 64, activation: 'relu'}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
// Compile the model (categoricalCrossentropy expects one-hot labels)
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
// Load and preprocess the MNIST dataset.
// TensorFlow.js has no built-in MNIST loader: loadMnistArrays() is a
// hypothetical helper you must supply (for example, one that parses MNIST
// files you host). It is assumed to resolve to plain arrays in which each
// image is 784 pixel values (0-255) and each label is a one-hot array of 10.
async function loadData() {
  const {trainImages, trainLabels, testImages, testLabels} = await loadMnistArrays();
  // Pair images with labels, reshape to [28, 28, 1], scale to [0, 1], batch by 32
  const toDataset = (images, labels) => tf.data.zip({
    xs: tf.data.array(images).map(pixels => tf.tensor3d(pixels, [28, 28, 1]).div(255)),
    ys: tf.data.array(labels).map(label => tf.tensor1d(label))
  }).batch(32);
  const trainData = toDataset(trainImages, trainLabels);
  const testData = toDataset(testImages, testLabels);
  return {trainData, testData};
}
// Train the model
async function train() {
  const {trainData, testData} = await loadData();
  await model.fitDataset(trainData, {
    epochs: 5,
    validationData: testData,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        // The 'accuracy' metric is reported under the key acc
        console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, accuracy = ${logs.acc.toFixed(4)}`);
      }
    }
  });
  console.log('Training complete!');
  return model;
}
// Start training, then save the trained model to the browser's local storage
train().then(trainedModel => trainedModel.save('localstorage://my-mnist-model'));
Output (illustrative; exact values vary from run to run):
Epoch 1: loss = 0.2134, accuracy = 0.9342
Epoch 2: loss = 0.0865, accuracy = 0.9725
Epoch 3: loss = 0.0623, accuracy = 0.9801
Epoch 4: loss = 0.0488, accuracy = 0.9847
Epoch 5: loss = 0.0398, accuracy = 0.9873
Training complete!
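Each layer's output shape and parameter count in the network above follows from simple arithmetic: a 3x3 'valid' convolution shrinks each spatial side by 2, a 2x2 pool with stride 2 halves it (rounding down), and a layer's parameters are its weights plus one bias per filter or unit. The plain JavaScript sketch below recomputes these figures, assuming the tf.layers defaults of 'valid' padding and stride 1 for conv2d; calling model.summary() on the real model should print the same numbers, with a leading null batch dimension.

```javascript
// Recompute output shapes and parameter counts of the MNIST CNN above.
// Assumes 'valid' padding and stride 1 for conv2d (the tf.layers defaults)
// and poolSize 2 / strides 2 for maxPooling2d, as in the model definition.

// Output size of a 'valid' window of size `kernel` sliding with `stride`.
function windowOut(size, kernel, stride = 1) {
  return Math.floor((size - kernel) / stride) + 1;
}

function describeMnistCnn() {
  const layers = [];
  let h = 28, w = 28, c = 1; // input: 28x28 grayscale

  // conv2d: kernel*kernel*inChannels*filters weights, plus one bias per filter
  const conv = (filters, kernel = 3) => {
    const params = kernel * kernel * c * filters + filters;
    h = windowOut(h, kernel);
    w = windowOut(w, kernel);
    c = filters;
    layers.push({name: 'conv2d', shape: [h, w, c], params});
  };
  // maxPooling2d has no trainable parameters
  const pool = () => {
    h = windowOut(h, 2, 2);
    w = windowOut(w, 2, 2);
    layers.push({name: 'maxPooling2d', shape: [h, w, c], params: 0});
  };

  conv(16); pool(); conv(32); pool();

  let units = h * w * c; // flatten
  layers.push({name: 'flatten', shape: [units], params: 0});
  // dense: inputs*units weights, plus one bias per unit
  for (const next of [64, 10]) {
    layers.push({name: 'dense', shape: [next], params: units * next + next});
    units = next;
  }
  return layers;
}

const layers = describeMnistCnn();
for (const layer of layers) {
  console.log(layer.name.padEnd(13), JSON.stringify(layer.shape).padEnd(12), layer.params);
}
const total = layers.reduce((sum, layer) => sum + layer.params, 0);
console.log('Total params:', total); // 56714
```

Most of the parameters (51,264 of them) sit in the first dense layer, because flattening the 5x5x32 feature map produces 800 inputs.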
Saving and Loading Models
TensorFlow.js can save models to several destinations, each selected by a URL scheme:
// Save model to the browser's local storage (typically limited to a few MB per origin)
await model.save('localstorage://my-model');
// Save model to IndexedDB (better suited to larger models)
await model.save('indexeddb://my-model');
// Trigger browser downloads of my-model.json and my-model.weights.bin
const saveResults = await model.save('downloads://my-model');
// Send the model to an HTTP endpoint as a multipart/form-data POST request
// (requires a server that accepts it)
await model.save('http://my-server.com/model');
To load a previously saved model:
// Load model from local storage
const localModel = await tf.loadLayersModel('localstorage://my-model');
// Load model from IndexedDB
const idbModel = await tf.loadLayersModel('indexeddb://my-model');
// Load model from a URL (weight files are fetched relative to model.json)
const remoteModel = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
Real-world Applications
1. Real-time Webcam Object Detection
Here's a simplified example of using the COCO-SSD model for real-time object detection:
<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
  </head>
  <body>
    <video id="webcam" autoplay muted width="640" height="480"></video>
    <canvas id="output" width="640" height="480"></canvas>
    <script>
      let model;
      async function setupWebcam() {
        const video = document.getElementById('webcam');
        const stream = await navigator.mediaDevices.getUserMedia({
          video: { width: 640, height: 480 },
          audio: false
        });
        video.srcObject = stream;
        return new Promise(resolve => {
          video.onloadedmetadata = () => {
            resolve(video);
          };
        });
      }
      async function detectObjects() {
        const video = await setupWebcam();
        const canvas = document.getElementById('output');
        const ctx = canvas.getContext('2d');
        // Load the COCO-SSD model
        model = await cocoSsd.load();
        // Detection loop
        const detectFrame = async () => {
          // Detect objects
          const predictions = await model.detect(video);
          // Draw video frame
          ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
          // Draw bounding boxes and labels
          predictions.forEach(prediction => {
            const [x, y, width, height] = prediction.bbox;
            ctx.strokeStyle = '#00FF00';
            ctx.lineWidth = 2;
            ctx.strokeRect(x, y, width, height);
            ctx.fillStyle = '#00FF00';
            ctx.font = '18px Arial';
            ctx.fillText(
              `${prediction.class}: ${Math.round(prediction.score * 100)}%`,
              x, y - 5
            );
          });
          // Continue detection
          requestAnimationFrame(detectFrame);
        };
        detectFrame();
      }
      // Start detection
      detectObjects();
    </script>
  </body>
</html>
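The example draws boxes at the coordinates COCO-SSD returns, which are in the video's native pixel space. getUserMedia treats the requested width and height as preferences, so a camera may deliver, for instance, 1280x720 frames while the canvas stays 640x480, and the boxes would then be misplaced. A small helper that rescales a bbox to the canvas and keeps the label on-screen might look like the plain JavaScript sketch below; scaleBox is a hypothetical name, not part of COCO-SSD.

```javascript
// Hypothetical helper: rescale a COCO-SSD bbox [x, y, width, height], given in
// the video's native pixel space, to canvas coordinates, and choose a label
// baseline that stays inside the canvas. Plain JavaScript.
function scaleBox(bbox, videoWidth, videoHeight, canvasWidth, canvasHeight, fontSize = 18) {
  const [x, y, width, height] = bbox;
  const box = {
    x: x * canvasWidth / videoWidth,
    y: y * canvasHeight / videoHeight,
    width: width * canvasWidth / videoWidth,
    height: height * canvasHeight / videoHeight
  };
  // Put the label just above the box, or just inside it if the box
  // is too close to the top edge for the text to fit above.
  box.labelY = box.y - 5 < fontSize ? box.y + fontSize : box.y - 5;
  return box;
}

// A 1280x720 camera frame drawn onto the 640x480 canvas from the example
console.log(scaleBox([640, 360, 320, 180], 1280, 720, 640, 480));
// x: 320, y: 240, width: 160, height: 120, labelY: 235
```

Inside the detection loop you would call scaleBox(prediction.bbox, video.videoWidth, video.videoHeight, canvas.width, canvas.height) and pass the result to strokeRect() and fillText().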
2. Sentiment Analysis for Customer Feedback
This example pairs the Universal Sentence Encoder (USE), which converts each sentence into a 512-dimensional embedding, with a simple keyword rule that assigns the sentiment label:
import * as tf from '@tensorflow/tfjs'; // registers the backends that USE runs on
import * as use from '@tensorflow-models/universal-sentence-encoder';
async function analyzeSentiment() {
  // Load the Universal Sentence Encoder model
  const model = await use.load();
  // Sample customer feedback
  const feedback = [
    "I absolutely love this product! It's amazing.",
    "This service was terrible, very disappointed.",
    "The quality is okay, but the price is too high.",
    "Best purchase I've made all year!"
  ];
  // Encode the text: one 512-dimensional embedding per sentence (shape [4, 512]).
  // These embeddings are the features a trained classifier would consume;
  // the keyword rule below does not use them.
  const embeddings = await model.embed(feedback);
  // Simple keyword-based sentiment rule (for demonstration only).
  // In a real application, you'd train a classifier on the embeddings instead.
  const sentimentWords = {
    positive: ["love", "amazing", "excellent", "great", "best", "fantastic"],
    negative: ["terrible", "bad", "disappointed", "poor", "worst", "awful"]
  };
  // Label each piece of feedback by counting keyword matches
  for (let i = 0; i < feedback.length; i++) {
    const text = feedback[i].toLowerCase();
    let sentiment = "neutral";
    const posWords = sentimentWords.positive.filter(word => text.includes(word));
    const negWords = sentimentWords.negative.filter(word => text.includes(word));
    if (posWords.length > negWords.length) sentiment = "positive";
    else if (negWords.length > posWords.length) sentiment = "negative";
    console.log(`Feedback: "${feedback[i]}"`);
    console.log(`Sentiment: ${sentiment}`);
    console.log('---');
  }
  // Free the memory held by the embeddings tensor
  embeddings.dispose();
}
analyzeSentiment();
Output:
Feedback: "I absolutely love this product! It's amazing."
Sentiment: positive
---
Feedback: "This service was terrible, very disappointed."
Sentiment: negative
---
Feedback: "The quality is okay, but the price is too high."
Sentiment: neutral
---
Feedback: "Best purchase I've made all year!"
Sentiment: positive
---
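The keyword rule uses text.includes(), which matches substrings, so "love" also matches inside "glove" and a review such as "The glove arrived torn. Awful." comes out neutral. Here is a sketch of the same rule matching whole words instead, in plain JavaScript with the word lists copied from the example; keywordSentiment is a hypothetical name.

```javascript
// Whole-word version of the keyword rule (plain JavaScript, no TensorFlow.js).
const sentimentWords = {
  positive: ['love', 'amazing', 'excellent', 'great', 'best', 'fantastic'],
  negative: ['terrible', 'bad', 'disappointed', 'poor', 'worst', 'awful']
};

function keywordSentiment(text) {
  // Lowercase, split on anything that is not a letter or apostrophe,
  // and drop empty strings left by leading or trailing punctuation.
  const tokens = new Set(text.toLowerCase().split(/[^a-z']+/).filter(Boolean));
  const pos = sentimentWords.positive.filter(word => tokens.has(word)).length;
  const neg = sentimentWords.negative.filter(word => tokens.has(word)).length;
  if (pos > neg) return 'positive';
  if (neg > pos) return 'negative';
  return 'neutral';
}

console.log(keywordSentiment('The glove arrived torn. Awful.')); // negative
console.log(keywordSentiment("I absolutely love this product! It's amazing.")); // positive
```

The four sample reviews keep the same labels as before; only accidental matches inside longer words disappear.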
Performance Optimization Tips
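- Use the WebGL backend for GPU acceleration in the browser (it is usually selected automatically when available):
await tf.setBackend('webgl');
- Manage memory with tf.tidy(), which disposes every tensor created inside the callback except the one it returns:
const result = tf.tidy(() => {
  const x = tf.tensor([1, 2, 3]);
  const y = tf.tensor([4, 5, 6]);
  return x.add(y); // x and y are disposed; the sum is kept
});
- Process data in batches instead of one example at a time:
// Assumes data is an array of feature vectors and model is a loaded tf.LayersModel
const batchSize = 32;
for (let i = 0; i < data.length; i += batchSize) {
  const batch = data.slice(i, i + batchSize);
  const input = tf.tensor(batch);
  const predictions = model.predict(input);
  const values = await predictions.array();
  // Process values here
  tf.dispose([input, predictions]); // free this batch's tensors
}
- Use smaller or quantized models when possible. Quantization is applied when converting a model (for example with the tensorflowjs_converter options --quantize_float16 or --quantize_uint8), and the loading code stays the same:
// mobilenet_v1_0.25_224 is a reduced-width MobileNet (width multiplier 0.25):
// much smaller and faster than the full model, at some cost in accuracy
const smallModel = await tf.loadLayersModel(
  'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
);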
Summary
TensorFlow.js is a powerful library that brings machine learning capabilities directly to the browser and Node.js environments. In this tutorial, we've covered:
- Setting up TensorFlow.js in both browser and Node.js
- Working with tensors and basic operations
- Loading pre-trained models
- Creating and training custom models
- Saving and loading models
- Real-world applications like object detection and sentiment analysis
- Performance optimization tips
TensorFlow.js makes machine learning more accessible by allowing models to run directly in the browser, providing privacy benefits and reducing the need for server-side processing. It opens up possibilities for interactive ML applications, from image recognition to natural language processing, all running on the client side.
Additional Resources
- TensorFlow.js Official Website
- TensorFlow.js API Reference
- TensorFlow.js GitHub Repository
- TensorFlow.js Models and Examples
- TensorFlow.js Tutorials
Exercises
- Create a simple image classifier using a pre-trained MobileNet model that allows users to upload their own images.
- Build a handwritten digit recognition app using the MNIST dataset that works with mouse or touch input.
- Implement a text toxicity detector using the Toxicity model from TensorFlow.js.
- Create a body pose estimation application using the PoseNet model.
- Build a simple recommender system using TensorFlow.js that suggests products based on user preferences.