TensorFlow.js

Introduction

TensorFlow.js is an open-source library that brings machine learning capabilities directly to JavaScript, allowing developers to train and deploy ML models in the browser and Node.js environments. As a part of the broader TensorFlow ecosystem, TensorFlow.js enables you to run existing models or build and train new ones without requiring users to install additional software beyond their web browser.

In this tutorial, you'll learn:

  • What TensorFlow.js is and why it's useful
  • How to set up and use TensorFlow.js in both browser and Node.js environments
  • How to load and run pre-trained models
  • How to train models directly in JavaScript
  • Real-world applications and use cases

Why Use TensorFlow.js?

TensorFlow.js offers several unique advantages:

  1. Client-side ML: Run machine learning directly in the browser without server round trips
  2. Privacy: Process sensitive data locally without sending it to servers
  3. Accessibility: Lower the barrier to ML by leveraging the ubiquity of web browsers
  4. Offline capability: Once the model files are loaded or cached, applications can run ML models without internet connectivity
  5. Reduced server costs: Offload computation to the client's device

Getting Started with TensorFlow.js

Setting Up TensorFlow.js

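You can include TensorFlow.js in your project in several ways. The quickest is a script tag that loads it from a CDN and exposes a global tf object; the browser examples in this tutorial rely on this. For bundled web apps, install the npm package @tensorflow/tfjs and import it (import * as tf from '@tensorflow/tfjs'). In Node.js, the @tensorflow/tfjs-node package adds native TensorFlow bindings for faster execution.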

html
<!-- Direct script tag from CDN -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

Basic Concepts

Before we dive into examples, let's understand some key components of TensorFlow.js:

  1. Tensors: Multi-dimensional arrays that hold your data
  2. Models: The neural network architecture that processes the data
  3. Layers: Building blocks that compose a model
  4. Operations: Functions that manipulate tensors
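Conceptually, a tensor is a flat buffer of numbers plus a shape. The plain-JavaScript sketch below shows how a shape such as [2, 3] maps [row, column] indices onto a flat, row-major array. It is a mental model, not TensorFlow.js's internal code, and stridesFor and flatIndex are hypothetical helper names. Row-major is also the order in which tensor.dataSync() and tensor.data() return values.

```javascript
// Conceptual sketch: a tensor as a flat array of values plus a shape, stored in
// row-major order (the last axis varies fastest). Not TensorFlow.js's internal code.
function stridesFor(shape) {
  const strides = new Array(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}

// Convert multi-dimensional indices (e.g. [row, col]) into a position in the flat array.
function flatIndex(shape, indices) {
  const strides = stridesFor(shape);
  return indices.reduce((offset, idx, axis) => offset + idx * strides[axis], 0);
}

// The 2x3 tensor [[1, 2, 3], [4, 5, 6]] used in the next section, as flat data + shape.
const shape = [2, 3];
const values = [1, 2, 3, 4, 5, 6];

console.log(stridesFor(shape).join(', '));     // prints "3, 1"
console.log(values[flatIndex(shape, [1, 2])]); // prints 6 (row 1, column 2)
```

The same arithmetic explains why reshaping is cheap: reshape([28, 28, 1]) on a 784-element tensor only changes the shape and strides, not the order of the underlying values.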

Working with Tensors

Tensors are the core data structure in TensorFlow.js. Here's how to create and manipulate them:

javascript
// Create a 2x3 tensor
const tensor1 = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
console.log('Tensor shape:', tensor1.shape);
console.log('Tensor data type:', tensor1.dtype);

// Print tensor values
tensor1.print();

// Perform operations
const tensor2 = tensor1.add(1);
tensor2.print();

// Clean up memory
tensor1.dispose();
tensor2.dispose();

// Alternatively, use tf.tidy for automatic memory management
tf.tidy(() => {
  const x = tf.tensor2d([[1, 2], [3, 4]]);
  const y = x.square();
  y.print();
  // No need to call dispose() within tidy
});

Output:

Tensor shape: [2,3]
Tensor data type: float32
Tensor
[[1, 2, 3],
[4, 5, 6]]
Tensor
[[2, 3, 4],
[5, 6, 7]]
Tensor
[[1, 4],
[9, 16]]
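tf.tidy works by tracking every tensor created while its callback runs. When the callback finishes, it disposes all of them except whatever the callback returns. The toy scope below models that bookkeeping in plain JavaScript. FakeTensor and toyTidy are hypothetical names, and this is not the real TensorFlow.js engine, which frees backend memory and also handles tensors returned inside arrays or objects.

```javascript
// Toy model of tf.tidy's bookkeeping (hypothetical FakeTensor/toyTidy, illustration only).
// The real engine frees CPU/GPU memory instead of setting a flag.
const scopeStack = [];

class FakeTensor {
  constructor(values) {
    this.values = values;
    this.disposed = false;
    // Register the new tensor with the innermost active scope, if any.
    if (scopeStack.length > 0) scopeStack[scopeStack.length - 1].push(this);
  }

  square() {
    return new FakeTensor(this.values.map(v => v * v));
  }

  add(other) {
    return new FakeTensor(this.values.map((v, i) => v + other.values[i]));
  }

  dispose() {
    this.disposed = true;
  }
}

function toyTidy(fn) {
  scopeStack.push([]);
  let result;
  try {
    result = fn();
  } finally {
    const created = scopeStack.pop();
    // Free every tensor created in this scope except the one being returned.
    for (const t of created) {
      if (t !== result) t.dispose();
    }
  }
  // Hand the surviving result to the enclosing scope (if nested) so it is tracked there.
  if (result instanceof FakeTensor && scopeStack.length > 0) {
    scopeStack[scopeStack.length - 1].push(result);
  }
  return result;
}

let x, intermediate;
const sum = toyTidy(() => {
  x = new FakeTensor([1, 2, 3]);
  intermediate = x.square();  // [1, 4, 9]: freed when the scope ends
  return intermediate.add(x); // [2, 6, 12]: returned, so it survives
});

console.log(sum.values.join(','), sum.disposed, intermediate.disposed); // prints "2,6,12 false true"
```

The returned tensor is the only one you still own, so it is the only one you need to dispose() later.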

Loading Pre-trained Models

TensorFlow.js allows you to use pre-trained models from various sources, including:

  1. TensorFlow.js model format
  2. TensorFlow SavedModel format (converted)
  3. Keras models (converted)
  4. TensorFlow Hub models

Using a Pre-trained Model from the TensorFlow.js Model Repository

Here's an example of using the MobileNet image classification model:

html
<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@latest"></script>
</head>
<body>
  <!-- crossorigin="anonymous" lets TensorFlow.js read the pixels of an image hosted
       on another domain (the image server must send CORS headers) -->
  <img id="img" src="https://i.imgur.com/JlUvsxa.jpg" crossorigin="anonymous" width="227" height="227"/>
  <div id="prediction-result"></div>

  <script>
    async function predict() {
      const img = document.getElementById('img');
      const resultDiv = document.getElementById('prediction-result');

      // Load MobileNet model
      const model = await mobilenet.load();

      // Make prediction
      const predictions = await model.classify(img);

      // Display results
      resultDiv.innerHTML = predictions.map(p =>
        `${p.className}: ${(p.probability * 100).toFixed(2)}%`
      ).join('<br>');
    }

    // Run prediction when the page loads
    window.addEventListener('load', predict);
  </script>
</body>
</html>

Output (will vary based on the image):

Golden retriever: 87.25%
Labrador retriever: 5.18%
Tennis ball: 2.54%
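The classify() call hides a simple post-processing step. The network outputs one probability per ImageNet class, and classify() returns only the most likely few. Below is a plain-JavaScript sketch of that top-k step, with formatting copied from the display code above. topK and formatPredictions are hypothetical names, and the labels and probabilities are invented for illustration.

```javascript
// Sketch of a top-k step: pair each probability with its label, sort from most to
// least likely, and keep the first k. Labels and probabilities are made-up values.
function topK(probabilities, labels, k) {
  return probabilities
    .map((probability, index) => ({className: labels[index], probability}))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);
}

// Same formatting as the display code in the MobileNet page above.
function formatPredictions(predictions) {
  return predictions
    .map(p => `${p.className}: ${(p.probability * 100).toFixed(2)}%`)
    .join('<br>');
}

const labels = ['tabby cat', 'golden retriever', 'tennis ball', 'toaster'];
const probabilities = [0.05, 0.8, 0.12, 0.03];

console.log(formatPredictions(topK(probabilities, labels, 2)));
// prints "golden retriever: 80.00%<br>tennis ball: 12.00%"
```

Sorting the whole array is fine for a thousand classes; for much larger outputs a partial selection would avoid the full sort.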

Creating and Training a Model

Let's build a simple convolutional network to recognize handwritten digits from the MNIST dataset. TensorFlow.js does not ship a built-in MNIST loader, so the data-loading step below relies on a helper that you provide:

javascript
// Create a sequential model
const model = tf.sequential();

// Add layers
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 3,
  filters: 16,
  activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));
model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 64, activation: 'relu'}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

// Compile the model (categorical cross-entropy expects one-hot labels)
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// Load and preprocess the MNIST dataset.
// NOTE: TensorFlow.js has no built-in MNIST loader. loadMnistDatasets() is a
// hypothetical helper you supply (for example, adapted from the MNIST example
// in the tfjs-examples repository). It is assumed to resolve to {train, test},
// two tf.data.Dataset objects whose elements look like
// {xs: flat 784-pixel tensor with values 0-255, ys: one-hot label tensor of length 10}.
async function loadData() {
  const mnist = await loadMnistDatasets();

  // Reshape each flat image to 28x28x1 and scale pixel values to [0, 1]
  const preprocess = item => ({
    xs: item.xs.reshape([28, 28, 1]).div(255),
    ys: item.ys
  });

  const trainData = mnist.train.map(preprocess).shuffle(1000).batch(32);
  const testData = mnist.test.map(preprocess).batch(32);

  return {trainData, testData};
}

// Train the model
async function train() {
  const {trainData, testData} = await loadData();

  await model.fitDataset(trainData, {
    epochs: 5,
    validationData: testData,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, accuracy = ${logs.acc.toFixed(4)}`);
      }
    }
  });

  console.log('Training complete!');
  return model;
}

// Start training, then save the trained model to the browser's local storage
train()
  .then(trainedModel => trainedModel.save('localstorage://my-mnist-model'))
  .catch(err => console.error('Training failed:', err));

Output (exact numbers will vary from run to run):

Epoch 1: loss = 0.2134, accuracy = 0.9342
Epoch 2: loss = 0.0865, accuracy = 0.9725
Epoch 3: loss = 0.0623, accuracy = 0.9801
Epoch 4: loss = 0.0488, accuracy = 0.9847
Epoch 5: loss = 0.0398, accuracy = 0.9873
Training complete!
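It helps to check the shapes this architecture produces. With 'valid' padding, which we assume is the default for tf.layers.conv2d and tf.layers.maxPooling2d, a 3x3 convolution trims one pixel from each border. A 2x2 pooling window with stride 2 halves the size, rounding down. The plain-JavaScript sketch below walks through that arithmetic by hand and counts trainable weights plus biases. The totals should agree with what model.summary() prints for this model. convOutput is a hypothetical helper, not a TensorFlow.js API.

```javascript
// Hand calculation of output sizes and parameter counts for the MNIST CNN above,
// assuming 'valid' padding for both conv2d and maxPooling2d layers.
function convOutput(size, kernel, stride = 1) {
  return Math.floor((size - kernel) / stride) + 1;
}

let size = 28;    // input: 28x28x1
let channels = 1;
let params = 0;

// conv2d: 16 filters of 3x3 over `channels` inputs, plus one bias per filter
params += 3 * 3 * channels * 16 + 16;
size = convOutput(size, 3);     // 26
channels = 16;
size = convOutput(size, 2, 2);  // maxPooling2d -> 13

// conv2d: 32 filters of 3x3
params += 3 * 3 * channels * 32 + 32;
size = convOutput(size, 3);     // 11
channels = 32;
size = convOutput(size, 2, 2);  // maxPooling2d -> 5

const flattened = size * size * channels; // 5 * 5 * 32 = 800 features
params += flattened * 64 + 64;            // dense, 64 units
params += 64 * 10 + 10;                   // dense, 10 units (softmax)

console.log(`flattened features: ${flattened}, trainable parameters: ${params}`);
// prints "flattened features: 800, trainable parameters: 56714"
```

Most of the parameters (51,264) sit in the first dense layer, which is typical: the convolutional layers share small kernels across the whole image.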

Saving and Loading Models

TensorFlow.js can save a model to several destinations, each selected by a URL-like scheme:

javascript
// These calls use await, so run them inside an async function (or an ES module)

// Save model to the browser's local storage (small models only; storage space is limited)
await model.save('localstorage://my-model');

// Save model to IndexedDB
await model.save('indexeddb://my-model');

// Trigger browser downloads of the model files (a JSON topology file and a binary weights file)
const saveResults = await model.save('downloads://my-model');

// Send the model to an HTTP endpoint (requires a server that accepts the upload)
await model.save('http://my-server.com/model');

To load a previously saved model (each source gets its own variable name so the snippet runs as written):

javascript
// Load model from local storage
const localModel = await tf.loadLayersModel('localstorage://my-model');

// Load model from IndexedDB
const idbModel = await tf.loadLayersModel('indexeddb://my-model');

// Load model from a URL that points at a model.json file
const remoteModel = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');

Real-world Applications

1. Real-time Webcam Object Detection

Here's a simplified example of using the COCO-SSD model for real-time object detection. Browsers only grant webcam access to pages served over HTTPS or from localhost:

html
<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
</head>
<body>
  <video id="webcam" autoplay muted width="640" height="480"></video>
  <canvas id="output" width="640" height="480"></canvas>

  <script>
    let model;

    async function setupWebcam() {
      const video = document.getElementById('webcam');
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { width: 640, height: 480 },
        audio: false
      });
      video.srcObject = stream;

      return new Promise(resolve => {
        video.onloadedmetadata = () => {
          resolve(video);
        };
      });
    }

    async function detectObjects() {
      const video = await setupWebcam();
      const canvas = document.getElementById('output');
      const ctx = canvas.getContext('2d');

      // Load the COCO-SSD model
      model = await cocoSsd.load();

      // Detection loop
      const detectFrame = async () => {
        // Detect objects
        const predictions = await model.detect(video);

        // Draw video frame
        ctx.drawImage(video, 0, 0, canvas.width, canvas.height);

        // Draw bounding boxes and labels
        predictions.forEach(prediction => {
          const [x, y, width, height] = prediction.bbox;

          ctx.strokeStyle = '#00FF00';
          ctx.lineWidth = 2;
          ctx.strokeRect(x, y, width, height);

          ctx.fillStyle = '#00FF00';
          ctx.font = '18px Arial';
          ctx.fillText(
            `${prediction.class}: ${Math.round(prediction.score * 100)}%`,
            x, y - 5
          );
        });

        // Continue detection
        requestAnimationFrame(detectFrame);
      };

      detectFrame();
    }

    // Start detection
    detectObjects();
  </script>
</body>
</html>
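The drawing loop above draws every detection and places each label at y - 5, which falls off the canvas for boxes that touch the top edge. A small pure function can separate that logic from the canvas calls so it can be tested on its own. The sketch below is one way to do it. toDrawCommands, its minScore threshold and the 18-pixel label height are hypothetical choices, and the predictions are made up. The function only assumes the { bbox: [x, y, width, height], class, score } shape used by the example.

```javascript
// Hypothetical helper: turn COCO-SSD style predictions ({bbox: [x, y, width, height], class, score})
// into drawing instructions, dropping low-confidence boxes and keeping labels on the canvas.
function toDrawCommands(predictions, minScore, labelHeight = 18) {
  return predictions
    .filter(p => p.score >= minScore)
    .map(p => {
      const [x, y, width, height] = p.bbox;
      // Put the label above the box, or just inside it when the box touches the top edge.
      const labelY = y - 5 >= labelHeight ? y - 5 : y + labelHeight;
      return {
        rect: {x, y, width, height},
        label: `${p.class}: ${Math.round(p.score * 100)}%`,
        labelX: x,
        labelY
      };
    });
}

// Made-up predictions for illustration.
const sample = [
  {bbox: [10, 2, 100, 200], class: 'person', score: 0.91},
  {bbox: [300, 120, 80, 60], class: 'cup', score: 0.35}
];
const commands = toDrawCommands(sample, 0.5);
console.log(commands.length, commands[0].label, commands[0].labelY); // prints "1 person: 91% 20"
```

Each command maps directly onto the strokeRect(rect.x, rect.y, rect.width, rect.height) and fillText(label, labelX, labelY) calls in the detection loop.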

2. Sentiment Analysis for Customer Feedback

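This example uses the Universal Sentence Encoder to turn each piece of customer feedback into an embedding vector. To keep the demo short, the sentiment label itself comes from a simple keyword rule; a real application would train a classifier on those embeddings: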

javascript
import * as tf from '@tensorflow/tfjs';
import * as use from '@tensorflow-models/universal-sentence-encoder';

async function analyzeSentiment() {
  // Load the Universal Sentence Encoder model
  const model = await use.load();

  // Sample customer feedback
  const feedback = [
    "I absolutely love this product! It's amazing.",
    "This service was terrible, very disappointed.",
    "The quality is okay, but the price is too high.",
    "Best purchase I've made all year!"
  ];

  // Encode the text to get embeddings (one 512-number vector per sentence)
  const embeddings = await model.embed(feedback);

  // Simple keyword-based sentiment rule (for demonstration only).
  // In a real application, you'd train a classifier on the embeddings instead.
  const sentimentWords = {
    positive: ["love", "amazing", "excellent", "great", "best", "fantastic"],
    negative: ["terrible", "bad", "disappointed", "poor", "worst", "awful"]
  };

  // Convert embeddings to a nested JavaScript array: embeddingArray[i] is the
  // vector for feedback[i]. The keyword rule below does not use it.
  const embeddingArray = await embeddings.array();

  // Analyze sentiment for each feedback
  for (let i = 0; i < feedback.length; i++) {
    const text = feedback[i].toLowerCase();
    let sentiment = "neutral";

    const posWords = sentimentWords.positive.filter(word => text.includes(word));
    const negWords = sentimentWords.negative.filter(word => text.includes(word));

    if (posWords.length > negWords.length) sentiment = "positive";
    else if (negWords.length > posWords.length) sentiment = "negative";

    console.log(`Feedback: "${feedback[i]}"`);
    console.log(`Sentiment: ${sentiment}`);
    console.log('---');
  }

  // Cleanup
  embeddings.dispose();
}

analyzeSentiment();

Output:

Feedback: "I absolutely love this product! It's amazing."
Sentiment: positive
---
Feedback: "This service was terrible, very disappointed."
Sentiment: negative
---
Feedback: "The quality is okay, but the price is too high."
Sentiment: neutral
---
Feedback: "Best purchase I've made all year!"
Sentiment: positive
---
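The keyword rule ignores the embeddings entirely. One simple way to actually use them is to label new text by cosine similarity to a few labeled reference embeddings, a nearest-neighbor approach. The sketch below illustrates that idea under our own assumptions; it is not a recommended production method. cosineSimilarity and nearestLabel are hypothetical helpers. The three-number vectors are made-up stand-ins for the 512-number vectors that embeddings.array() returns.

```javascript
// Cosine similarity: dot product divided by the product of the vector lengths.
function cosineSimilarity(a, b) {
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Label a vector with the sentiment of its most similar labeled example.
function nearestLabel(vector, examples) {
  let best = null;
  for (const ex of examples) {
    const score = cosineSimilarity(vector, ex.embedding);
    if (best === null || score > best.score) best = {label: ex.label, score};
  }
  return best;
}

// Tiny made-up "embeddings"; real Universal Sentence Encoder vectors have 512 numbers.
const labeled = [
  {label: 'positive', embedding: [0.9, 0.1, 0.0]},
  {label: 'negative', embedding: [-0.8, 0.2, 0.1]}
];
const result = nearestLabel([0.7, 0.3, 0.1], labeled);
console.log(result.label); // prints "positive"
```

With real data you would embed a handful of labeled sentences once, keep their vectors, and compare each new embedding against them.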

Performance Optimization Tips

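  1. Use the WebGL backend for GPU acceleration:

    javascript
    await tf.setBackend('webgl');

  2. Manage memory with tf.tidy():

    javascript
    const result = tf.tidy(() => {
      const x = tf.tensor([1, 2, 3]);
      const y = tf.tensor([4, 5, 6]);
      return x.add(y); // x and y are disposed; the returned sum is kept
    });

  3. Process data in batches for efficiency:

    javascript
    // Process data in batches instead of one by one.
    // Assumes data is an array of input samples and model is a loaded model.
    const batchSize = 32;
    for (let i = 0; i < data.length; i += batchSize) {
      const batch = data.slice(i, i + batchSize);
      const input = tf.tensor(batch);
      const predictions = model.predict(input);
      const values = await predictions.data(); // read results without blocking the UI thread
      input.dispose();                         // free tensors explicitly to avoid memory leaks
      predictions.dispose();
      // Process values
    }

  4. Use quantized models when possible (smaller downloads, faster loading):

    javascript
    // Quantization is applied when converting a model (for example with the
    // tensorflowjs_converter --quantize_uint8 or --quantize_float16 options);
    // the result loads like any other model. Compact architectures help too:
    // this MobileNet v1 uses a 0.25 width multiplier to keep its weights small.
    const model = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
    );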
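Quantization makes model files smaller by storing each 4-byte float32 weight in fewer bits. The plain-JavaScript sketch below shows the affine uint8 idea. Each weight becomes one byte, and a shared min and scale restore an approximation as q * scale + min. quantizeUint8 and dequantize are hypothetical helpers. This is a conceptual illustration, not the converter's actual code.

```javascript
// Sketch of affine uint8 weight quantization: store each float as one byte plus a
// shared (min, scale) pair, then dequantize as q * scale + min.
function quantizeUint8(weights) {
  const min = Math.min(...weights);
  const max = Math.max(...weights);
  const scale = (max - min) / 255 || 1; // avoid dividing by zero when all weights are equal
  const q = Uint8Array.from(weights, w => Math.round((w - min) / scale));
  return {q, min, scale};
}

function dequantize({q, min, scale}) {
  return Array.from(q, v => v * scale + min);
}

const weights = [-1.0, -0.25, 0.0, 0.5, 1.0];
const packed = quantizeUint8(weights);
const restored = dequantize(packed);
const maxError = Math.max(...weights.map((w, i) => Math.abs(w - restored[i])));

console.log(packed.q.byteLength, 'bytes instead of', weights.length * 4); // prints "5 bytes instead of 20"
console.log(maxError <= packed.scale / 2 + 1e-9);                      // prints "true"
```

The rounding error per weight is at most half a quantization step (scale / 2), which is why 8-bit weights usually cost little accuracy while cutting the download to a quarter.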

Summary

TensorFlow.js is a powerful library that brings machine learning capabilities directly to the browser and Node.js environments. In this tutorial, we've covered:

  • Setting up TensorFlow.js in both browser and Node.js
  • Working with tensors and basic operations
  • Loading pre-trained models
  • Creating and training custom models
  • Saving and loading models
  • Real-world applications like object detection and sentiment analysis
  • Performance optimization tips

TensorFlow.js makes machine learning more accessible by allowing models to run directly in the browser, providing privacy benefits and reducing the need for server-side processing. It opens up possibilities for interactive ML applications, from image recognition to natural language processing, all running on the client side.

Exercises

  1. Create a simple image classifier using a pre-trained MobileNet model that allows users to upload their own images.
  2. Build a handwritten digit recognition app using the MNIST dataset that works with mouse or touch input.
  3. Implement a text toxicity detector using the Toxicity model from TensorFlow.js.
  4. Create a body pose estimation application using the PoseNet model.
  5. Build a simple recommender system using TensorFlow.js that suggests products based on user preferences.

