ML in the Browser

with TensorFlow.js




Ed Atrero

Weedmaps Tech Meetup

Machine Learning

Computer Vision

Image Classification

Demo

Machine Learning

Computer Vision

Image Classification

Convolutional Neural Network (CNN)

Convolutional Neural Networks

[Figure: CNN]

What is a Convolution?

[Figure: CNN]

[Figure: weights]

[Figure: CNN]
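A convolution slides a small grid of weights (the kernel) over the image and takes a weighted sum at each position. In a CNN, the kernel values are the weights learned during training. Below is a minimal plain-JavaScript sketch of one convolution. It assumes a single input channel, stride 1, and no padding ("valid"). The `conv2d` function, the toy image, and the edge-detecting kernel are illustrative and are not taken from the demo. Like `tf.conv2d`, it does not flip the kernel, so strictly speaking it computes a cross-correlation.

```javascript
// Slide `kernel` over `input` (single channel, stride 1, "valid" padding)
// and write the weighted sum at each position into the output grid.
function conv2d(input, kernel) {
  const kh = kernel.length;
  const kw = kernel[0].length;
  const outH = input.length - kh + 1;
  const outW = input[0].length - kw + 1;
  const output = [];
  for (let i = 0; i < outH; i++) {
    const row = [];
    for (let j = 0; j < outW; j++) {
      let sum = 0;
      for (let m = 0; m < kh; m++) {
        for (let n = 0; n < kw; n++) {
          sum += input[i + m][j + n] * kernel[m][n];
        }
      }
      row.push(sum);
    }
    output.push(row);
  }
  return output;
}

// Toy image: dark on the left, bright on the right.
const image = [
  [0, 0, 1, 1],
  [0, 0, 1, 1],
  [0, 0, 1, 1],
  [0, 0, 1, 1],
];
// Toy kernel that responds to a dark-to-bright vertical edge.
const kernel = [
  [-1, 1],
  [-1, 1],
];
console.log(conv2d(image, kernel)); // [[0, 2, 0], [0, 2, 0], [0, 2, 0]]
```

The output is large only in the middle column, which is where the edge sits. A CNN layer applies many such kernels in parallel, one per output channel.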

Machine Learning Steps:

1. Model

2. Train

3. Predict

Transfer Learning

[Figure: CNN]

TensorFlow

Open-source library released by the Google Brain team in 2015.

Computation on multidimensional arrays (tensors)

Runs on CPUs, GPUs, and TPUs
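One way to picture a tensor is as a flat buffer of numbers plus a shape. The plain-JavaScript sketch below uses the hypothetical helpers `size`, `strides`, and `offset`. It assumes row-major layout, where the last axis varies fastest. That is the order in which TensorFlow.js `dataSync()` returns a tensor's values. The shape `[7, 7, 256]` is the MobileNet activation that the demo's model consumes.

```javascript
// Total number of elements in a tensor of the given shape.
function size(shape) {
  return shape.reduce((acc, dim) => acc * dim, 1);
}

// Row-major strides: how far to move in the flat buffer per step along each axis.
function strides(shape) {
  const result = new Array(shape.length).fill(1);
  for (let axis = shape.length - 2; axis >= 0; axis--) {
    result[axis] = result[axis + 1] * shape[axis + 1];
  }
  return result;
}

// Position of a multidimensional index inside the flat buffer.
function offset(shape, index) {
  const s = strides(shape);
  return index.reduce((acc, i, axis) => acc + i * s[axis], 0);
}

const shape = [7, 7, 256];
console.log(size(shape)); // 12544
console.log(strides(shape)); // 1792, 256, 1
console.log(offset(shape, [1, 2, 3])); // 1792 + 2 * 256 + 3 = 2307
```

Under this layout, flattening a `[7, 7, 256]` tensor into a 12544-element vector only changes the shape. The element ordering stays the same.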

TensorFlow.js

The JavaScript version of TensorFlow

Uses the GPU via WebGL shaders

The ML portion of the demo is under 100 lines

Model

```javascript
import * as tf from "@tensorflow/tfjs";

// Load MobileNet v1 (width 0.25, 224x224 input) and truncate it at an internal
// layer so that it outputs an activation ("features") instead of ImageNet classes.
export async function loadTruncatedMobilenet() {
  const mobilenet = await tf.loadLayersModel(
    "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json",
  );

  // Return a model that outputs an internal activation.
  const layer = mobilenet.getLayer("conv_pw_13_relu");
  return tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });
}

// Hyperparams
const batchSizeRatio = 0.4; // fraction of the collected examples per batch
const epochs = 20;
const hiddenUnits = 100;
const learningRate = 0.0001;

export function model(numClasses) {
  // Creates a 2-layer fully connected model. By creating a separate model,
  // rather than adding layers to the mobilenet model, we "freeze" the weights
  // of the mobilenet model, and only train weights from the new model.
  const model = tf.sequential({
    layers: [
      // Flattens the [7, 7, 256] activation into a vector so we can use it in
      // a dense layer. While technically a layer, this only performs a reshape
      // (and has no training parameters).
      tf.layers.flatten({ inputShape: [7, 7, 256] }),
      // Layer 1
      tf.layers.dense({
        units: hiddenUnits,
        activation: "relu",
        kernelInitializer: "varianceScaling",
        useBias: true,
      }),
      // Layer 2. The number of units of the last layer should correspond
      // to the number of classes we want to predict.
      tf.layers.dense({
        units: numClasses,
        kernelInitializer: "varianceScaling",
        useBias: false,
        activation: "softmax",
      }),
    ],
  });

  // Creates the optimizer that drives training of the model.
  const optimizer = tf.train.adam(learningRate);
  // categoricalCrossentropy is the loss function for categorical classification.
  // It measures the error between our predicted probability distribution over
  // classes (the probability that an input belongs to each class) and the label
  // (100% probability on the true class).
  model.compile({ optimizer, loss: "categoricalCrossentropy" });

  return model;
}
```
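The last dense layer uses a softmax activation, and the model is compiled with categorical cross-entropy. The plain-JavaScript sketch below shows what those two functions compute for a single example. The names `softmax` and `categoricalCrossentropy` are chosen here for illustration and are not the TensorFlow.js implementations. The epsilon clipping is an assumption added so that `Math.log` never sees 0.

```javascript
// Softmax: turns raw scores (logits) into probabilities that sum to 1.
// Subtracting the max logit first avoids overflow in Math.exp.
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Categorical cross-entropy: -sum_i label_i * log(p_i).
// With a one-hot label this is just -log(probability of the true class).
function categoricalCrossentropy(oneHotLabel, probs, epsilon = 1e-7) {
  return -oneHotLabel.reduce(
    (acc, y, i) => acc + y * Math.log(Math.max(probs[i], epsilon)),
    0,
  );
}

const probs = softmax([2, 1, 0]);
const lossCorrect = categoricalCrossentropy([1, 0, 0], probs); // true class scored highest
const lossWrong = categoricalCrossentropy([0, 0, 1], probs); // true class scored lowest
console.log(probs, lossCorrect, lossWrong);
```

The loss is small when the model puts high probability on the true class and grows as that probability shrinks. With the hyperparameters in the model code, the flatten layer feeds 7 × 7 × 256 = 12,544 values into the first dense layer. That layer therefore has 12,544 × 100 + 100 = 1,254,500 trainable parameters. The output layer adds 100 × numClasses more, since it has no bias.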

Training

```javascript
/**
 * Sets up and trains the classifier.
 * Uses tf, batchSizeRatio and epochs from the model code.
 */
export async function train(model, controllerDataset, trainCallback) {
  if (controllerDataset.xs == null) {
    throw new Error("Add some examples before training!");
  }

  // Each batch is a fixed fraction of the collected examples.
  const batchSize = Math.floor(controllerDataset.xs.shape[0] * batchSizeRatio);
  if (!(batchSize > 0)) {
    throw new Error("Batch size is 0; add more examples before training.");
  }

  // Train the model! model.fit() shuffles xs & ys by default, so we don't have to.
  await model.fit(controllerDataset.xs, controllerDataset.ys, {
    batchSize,
    epochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        trainCallback(logs);
        // Yield to the browser so the UI can update between batches.
        await tf.nextFrame();
      },
    },
  });
}
```
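`batchSizeRatio` sets the batch size as a fraction of the collected examples, and `onBatchEnd` fires once per batch. The plain-JavaScript sketch below uses a hypothetical helper, `batchPlan`, to work out how many times `trainCallback` runs. It assumes that `fit()` keeps the final, smaller batch of each epoch, as Keras does.

```javascript
// Hypothetical helper: mirrors how train() derives batchSize from batchSizeRatio,
// and counts batches per epoch assuming the last partial batch is kept.
function batchPlan(numExamples, batchSizeRatio = 0.4, epochs = 20) {
  const batchSize = Math.floor(numExamples * batchSizeRatio);
  if (!(batchSize > 0)) {
    throw new Error("Batch size is 0; add more examples before training.");
  }
  const batchesPerEpoch = Math.ceil(numExamples / batchSize);
  return { batchSize, batchesPerEpoch, totalBatches: batchesPerEpoch * epochs };
}

console.log(batchPlan(50)); // { batchSize: 20, batchesPerEpoch: 3, totalBatches: 60 }
```

With 50 examples, each epoch runs batches of 20, 20, and 10. With only 2 examples, `Math.floor(2 * 0.4)` is 0, which is why a guard against a zero batch size is useful.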

Stochastic Gradient Descent

[Figure: gradient descent]

[Figure: gradient descent]
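Stochastic gradient descent repeatedly takes a small step against the gradient of the loss, computed on a random mini-batch rather than on the whole dataset. The plain-JavaScript sketch below fits a line y = w·x + b to toy data with a mean-squared-error loss. The function `sgdLinearFit`, the data, and the settings are illustrative. The demo's own code uses `tf.train.adam`, an adaptive variant of this idea, rather than plain SGD.

```javascript
// Mini-batch SGD for y = w * x + b with mean-squared-error loss.
function sgdLinearFit(xs, ys, { learningRate = 0.1, epochs = 1000, batchSize = 2 } = {}) {
  let w = 0;
  let b = 0;
  const indices = xs.map((_, i) => i);
  for (let epoch = 0; epoch < epochs; epoch++) {
    // Shuffle the example order every epoch (Fisher-Yates): the "stochastic" part.
    for (let i = indices.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [indices[i], indices[j]] = [indices[j], indices[i]];
    }
    for (let start = 0; start < indices.length; start += batchSize) {
      const batch = indices.slice(start, start + batchSize);
      // Gradient of mean((w * x + b - y)^2) over this mini-batch.
      let gradW = 0;
      let gradB = 0;
      for (const i of batch) {
        const error = w * xs[i] + b - ys[i];
        gradW += (2 * error * xs[i]) / batch.length;
        gradB += (2 * error) / batch.length;
      }
      // Step downhill: move each parameter against its gradient.
      w -= learningRate * gradW;
      b -= learningRate * gradB;
    }
  }
  return { w, b };
}

// Toy data on the line y = 2x + 1.
const xs = Array.from({ length: 10 }, (_, i) => i / 10);
const ys = xs.map((x) => 2 * x + 1);
const { w, b } = sgdLinearFit(xs, ys);
console.log(w.toFixed(3), b.toFixed(3)); // close to 2 and 1
```

The learning rate controls the step size. If it is too large the updates overshoot, and if it is too small training crawls. This is the same trade-off behind the `learningRate` hyperparameter in the model code.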

Predict

```javascript
// Shared flag so stopPredicting() can end the loop in predict().
let isPredicting = false;

// `mobilenet` is the truncated MobileNet; `model` is the trained classifier.
export async function predict(webcam, mobilenet, model, callback) {
  isPredicting = true;
  while (isPredicting) {
    // tf.tidy() disposes every intermediate tensor created inside it,
    // except the tensor it returns.
    const predictedClass = tf.tidy(() => {
      // Capture the frame from the webcam.
      const img = webcam.capture();

      // Make a prediction through mobilenet, getting the internal activation of
      // the mobilenet model.
      const activation = mobilenet.predict(img);

      // Make a prediction through our newly trained model using the activation
      // from mobilenet as input.
      const predictions = model.predict(activation);

      // Returns the index with the maximum probability. This number corresponds
      // to the class the model thinks is the most probable given the input.
      return predictions.as1D().argMax();
    });

    const classId = (await predictedClass.data())[0];
    predictedClass.dispose();

    callback(classId);
    // Wait for the next animation frame before predicting again.
    await tf.nextFrame();
  }
}

export function stopPredicting() {
  isPredicting = false;
}
```
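During prediction, `model.predict(activation)` runs the two dense layers on the flattened MobileNet features, and `argMax()` picks the most probable class. The plain-JavaScript sketch below traces that forward pass with tiny, hypothetical weights: 2 features instead of 12,544, 3 hidden units instead of 100, and 2 classes. The helpers `dense`, `relu`, `softmax`, and `argMax` are illustrative stand-ins, not TensorFlow.js functions.

```javascript
// Fully connected layer: unit j outputs activation(sum_i weights[j][i] * input[i] + bias[j]).
function dense(input, weights, bias, activation) {
  return weights.map((row, j) => {
    const preActivation = row.reduce((acc, w, i) => acc + w * input[i], bias ? bias[j] : 0);
    return activation(preActivation);
  });
}

const relu = (z) => Math.max(0, z);
const identity = (z) => z;

function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Index of the largest value (the first one wins on ties).
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Hypothetical flattened activation and weights.
const features = [1, 2];
const hiddenWeights = [[1, 0], [0, 1], [1, -1]];
const hiddenBias = [0, 0, 0];
const outputWeights = [[1, 0, 0], [0, 1, 0]]; // no bias, like useBias: false

const hidden = dense(features, hiddenWeights, hiddenBias, relu); // [1, 2, 0]
const probabilities = softmax(dense(hidden, outputWeights, null, identity));
const classId = argMax(probabilities);
console.log(probabilities, classId); // classId is 1
```

In the demo, `classId` is the integer handed to `callback`, and the app maps it to one of the classes the user trained.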

Thanks!

References

https://playground.tensorflow.org/

slides => slides

https://github.com/eatrero/transfer-learning