<!-- Tensorflow.js CNN (Vanilla JavaScript, MNIST digits) -->
<style>
  body {
    min-height: 600px;
    height: 100%;
    padding: 0;
    margin: 0;
    font-family: Tahoma, Verdana;
    font-size: 14px;
    display: flex;
    justify-content: center;
    background: #DDD;
  }
  section {
    background: #FDFDFD;
    margin: 0 4rem;
    padding: 0 2rem;
    flex-grow: 0;
  }
  .colum {
    display: flex;
    flex-direction: column;
  }
  .row {
    display: flex;
    flex-direction: row;
  }
  table {
    border-collapse: collapse;
    max-width: 800px;
  }
  td {
    border: 1px solid #CCC;
    padding: 1px 5px;
  }
</style>

<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/0.10.3/tf.min.js"></script>
<script src="https://d3js.org/d3.v5.min.js"></script>

<script>
  // data
  const IMAGE_SIZE = 784;
  const NUM_CLASSES = 10;
  const NUM_DATASET_ELEMENTS = 65000;

  const NUM_TRAIN_ELEMENTS = 55000;
  const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

  // const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
  // const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
  const MNIST_IMAGES_SPRITE_PATH = 'https://notebook.xbdev.net/var/resources/mnist_images.png';
  const MNIST_LABELS_PATH = 'https://notebook.xbdev.net/var/resources/mnist_labels_uint8';

  /**
   * A class that fetches the sprited MNIST dataset and returns shuffled batches.
   *
   * NOTE: This will get much easier. For now, we do data fetching and
   * manipulation manually.
   */
  class MnistData {
    constructor() {
      this.shuffledTrainIndex = 0;
      this.shuffledTestIndex = 0;
    }

    async load() {
      // Make a request for the MNIST sprited image.
      const img = new Image();
      const canvas = document.createElement('canvas');
      const ctx = canvas.getContext('2d');
      const imgRequest = new Promise((resolve, reject) => {
        img.crossOrigin = '';
        img.onload = () => {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          // 4 bytes per Float32 pixel value.
          const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

          // Decode the sprite in chunks of 5000 rows (one image per row).
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, IMAGE_SIZE * chunkSize);
            ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);

            const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is grayscale,
              // so just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
            console.log('Processed chunk', i);
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        };
        img.src = MNIST_IMAGES_SPRITE_PATH;
      });

      const labelsRequest = fetch(MNIST_LABELS_PATH);
      const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

      this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

      // Create shuffled indices into the train/test set for when we select a
      // random dataset element for training / validation.
      this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
      this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

      // Slice the images and labels into train and test sets.
      this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
      this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    }

    nextTrainBatch(batchSize) {
      return this.nextBatch(batchSize, [this.trainImages, this.trainLabels], () => {
        this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length;
        return this.trainIndices[this.shuffledTrainIndex];
      });
    }

    nextTestBatch(batchSize) {
      return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
        this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length;
        return this.testIndices[this.shuffledTestIndex];
      });
    }

    // data is [images, labels]; index() returns the next shuffled element index.
    nextBatch(batchSize, data, index) {
      const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
      const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

      for (let i = 0; i < batchSize; i++) {
        const idx = index();

        const image = data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
        batchImagesArray.set(image, i * IMAGE_SIZE);

        const label = data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
        batchLabelsArray.set(label, i * NUM_CLASSES);
      }

      const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
      const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

      return {xs, labels};
    }
  }
</script>

<script>
  // model
  class Model {
    constructor(learningRate) {
      this.learningRate = learningRate;
      // Start with the default architecture (5x5 kernels, 8 and 16 filters).
      this.build();
    }

    // Rebuild and compile the network. Each options object may carry
    // kernelSize and filters; missing values fall back to the defaults.
    build(layer1options = {}, layer2options = {}) {
      this.model = tf.sequential();

      // First layer
      this.model.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: layer1options.kernelSize || 5,
        filters: layer1options.filters || 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
      }));
      this.model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

      // Second layer
      this.model.add(tf.layers.conv2d({
        kernelSize: layer2options.kernelSize || 5,
        filters: layer2options.filters || 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
      }));
      this.model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

      // Output layer
      this.model.add(tf.layers.flatten());
      this.model.add(tf.layers.dense({
        units: 10,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
      }));

      const optimizer = tf.train.sgd(this.learningRate);
      this.model.compile({
        optimizer: optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy']
      });
    }

    predict(data) {
      return this.model.predict(data);
    }

    async train(data, numTrainBatches, callback = null) {
      const BATCH_SIZE = 64;
      // const TRAIN_BATCHES = 150;
      // const TRAIN_BATCHES = 40;
      const TRAIN_BATCHES = numTrainBatches;

      const TEST_BATCH_SIZE = 1000;
      const TEST_ITERATION_FREQUENCY = 5;

      for (let i = 0; i < TRAIN_BATCHES; i++) {
        const batch = data.nextTrainBatch(BATCH_SIZE);

        // Every few batches, and on the last batch, test the accuracy of the model.
        let testBatch;
        let validationData;
        if (i % TEST_ITERATION_FREQUENCY === 0 || i === TRAIN_BATCHES - 1) {
          testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
          validationData = [
            testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]),
            testBatch.labels
          ];
        }

        // The entire dataset doesn't fit into memory, so we call fit()
        // repeatedly with batches.
        const history = await this.model.fit(
          batch.xs.reshape([BATCH_SIZE, 28, 28, 1]),
          batch.labels,
          {batchSize: BATCH_SIZE, validationData, epochs: 1});

        const loss = history.history.loss[0];
        const accuracy = history.history.acc[0];

        if (testBatch != null) {
          console.log(`Batch ${i}, accuracy: ${accuracy}`);
        }

        // Invoke callback
        if (callback) {
          callback(history);
        }

        // Clean up
        batch.xs.dispose();
        batch.labels.dispose();
        if (testBatch != null) {
          testBatch.xs.dispose();
          testBatch.labels.dispose();
        }

        await tf.nextFrame();
      }
    } // End train
  }
</script>

<section>
  <h3>CNN - Tensorflow.JS</h3>
  <p>
    A 2-layer CNN with Tensorflow.JS to recognize hand-drawn digits (MNIST).
    Use the sliders to adjust the layer parameters below, then click "Train" to train the neural network.
    Once trained, draw a digit on the canvas to run live prediction.
  </p>
  <br/>
  <div class="row">
    <!-- Left Panel -->
    <div>
      <!-- Layer 1 -->
      <div class="row">
        <span>
          <strong>Layer 1</strong> Kernel:
          <input id="row-one-kernel" type="range" min="2" max="8" value="5" step="1" oninput="updateSliders()"/>
          <span id="row-one-kernel-value"></span>
        </span>
        &nbsp;&nbsp;&nbsp;
        <span>
          Filters:
          <input id="row-one-filters" type="range" min="4" max="24" value="8" step="1" oninput="updateSliders()"/>
          <span id="row-one-filters-value"></span>
        </span>
      </div>
      <br/>
      <!-- Layer 2 -->
      <div class="row">
        <span>
          <strong>Layer 2</strong> Kernel:
          <input id="row-two-kernel" type="range" min="2" max="8" value="5" step="1" oninput="updateSliders()"/>
          <span id="row-two-kernel-value"></span>
        </span>
        &nbsp;&nbsp;&nbsp;
        <span>
          Filters:
          <input id="row-two-filters" type="range" min="4" max="24" value="16" step="1" oninput="updateSliders()"/>
          <span id="row-two-filters-value"></span>
        </span>
      </div>
      <br/>
      <!-- Batch -->
      <div class="row">
        <span>
          Training batches:
          <input style="width: 15rem" id="num-batch" type="range" min="10" max="100" value="25" step="5" size="30" oninput="updateSliders()"/>
          <span id="num-batch-value"></span>
        </span>
      </div>
      <br/>
      <!-- Train -->
      <div>
        <button style="font-size: 120%" onclick="run()">Train</button>
        &nbsp;&nbsp;
        <span id="training">not trained</span>
      </div>
      <!-- Prediction -->
      <div class="row">
        <div style="padding-left: 30px" id="predictions-one"></div>
        <div style="padding-left: 30px" id="predictions-two"></div>
      </div>
    </div>
    <!-- Right Panel -->
    <div style="padding-left: 30px">
      <div class="row">
        <div>
          <canvas id="input"></canvas>
          <br/>
          <button style="font-size: 120%" onclick="clearInputCanvas()">Clear canvas</button>
        </div>
        <div style="padding-left: 10px; font-size: 150%">
          <div>Prediction Is:</div>
          <div style="font-size: 300%" id="user-prediction"></div>
        </div>
      </div>
    </div>
  </div>
  <br/>
</section>

<script>
  let mnistData = new MnistData();
  let cnnModel = new Model(0.15);
  cnnModel.build(); // Just init a default model so stuff doesn't error out

  function updateTraining(history) {
    d3.select('#training').text('Accuracy: ' + history.history.acc[0].toFixed(2));
  }

  // Mirror the current slider positions into their value labels.
  function updateSliders() {
    let r1kernel = d3.select('#row-one-kernel').node().value;
    let r1filters = d3.select('#row-one-filters').node().value;
    let r2kernel = d3.select('#row-two-kernel').node().value;
    let r2filters = d3.select('#row-two-filters').node().value;
    let numBatch = d3.select('#num-batch').node().value;
    console.log(r1kernel, r1filters, r2kernel, r2filters);
    d3.select('#row-one-kernel-value').text(r1kernel);
    d3.select('#row-one-filters-value').text(r1filters);
    d3.select('#row-two-kernel-value').text(r2kernel);
    d3.select('#row-two-filters-value').text(r2filters);
    d3.select('#num-batch-value').text(numBatch);
  }

  // Copied from https://github.com/tensorflow/tfjs-examples/blob/master/mnist/ui.js
  function draw(image, canvas) {
    const [width, height] = [28, 28];
    canvas.width = width;
    canvas.height = height;
    const ctx = canvas.getContext('2d');
    const imageData = new ImageData(width, height);
    const data = image.dataSync();
    for (let i = 0; i < height * width; ++i) {
      const j = i * 4;
      imageData.data[j + 0] = data[i] * 255;
      imageData.data[j + 1] = data[i] * 255;
      imageData.data[j + 2] = data[i] * 255;
      imageData.data[j + 3] = 255;
    }
    ctx.putImageData(imageData, 0, 0);
  }

  function showPredictions(batch, predictions, labels) {
    const len = batch.xs.shape[0];
    const c1 = d3.select('#predictions-one');
    const c2 = d3.select('#predictions-two');

    for (let i = 0; i < len; i++) {
      const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);

      const canvas = document.createElement('canvas');
      canvas.className = 'prediction-canvas';
      draw(image.flatten(), canvas);

      // Alternate rows between the two prediction columns.
      const row = i % 2 === 0 ? c1.append('div') : c2.append('div');
      row.style('height', '30px');
      row.append('span').text(`Prediction: ${predictions[i]}, Actual: ${labels[i]} `);
      row.node().appendChild(canvas);
    }
  }

  let inputCanvas = null;
  let inputCtx = null;
  let isMouseDown = false;
  let prev = null;
  let curr = null;

  function clearInputCanvas() {
    d3.select('#user-prediction').text('');
    if (inputCtx !== null) {
      inputCtx.fillStyle = '#000000';
      inputCtx.fillRect(0, 0, 140, 140);
    }
  }

  function setInputCanvas() {
    const line = (ctx, x1, y1, x2, y2) => {
      ctx.beginPath();
      ctx.moveTo(x1, y1);
      ctx.lineTo(x2, y2);
      ctx.closePath();
      ctx.stroke();
    };

    const getPosition = (canvas, clientX, clientY) => {
      const rect = canvas.getBoundingClientRect();
      return {
        x: clientX - rect.left,
        y: clientY - rect.top
      };
    };

    const extract = () => {
      // Pull
      const imgData = inputCtx.getImageData(0, 0, 140, 140);
      // Grab single channel
      const singleChannel = imgData.data.filter((d, i) => i % 4 === 0);
      // Down-sample 140x140 to 28x28 by averaging each 5x5 block
      const sample = [];
      for (let i = 0; i < 28; i++) {
        for (let j = 0; j < 28; j++) {
          let c = 0;
          for (let x = 0; x < 5; x++) {
            for (let y = 0; y < 5; y++) {
              c += singleChannel[140 * (i * 5 + x) + (j * 5 + y)];
            }
          }
          c /= 25;
          // Scale to [0, 1] to match the training data, which is divided by 255.
          sample.push(c / 255);
        }
      }
      console.log('sample', sample);

      const tensor = tf.tensor(sample);
      const userInput = tensor.reshape([1, 28, 28, 1]);
      const output = cnnModel.predict(userInput);
      const axis = 1;
      const predictions = Array.from(output.argMax(axis).dataSync());
      d3.select('#user-prediction').text(predictions);
      console.log('prediction', predictions);
    };

    const onMouseDown = (e) => {
      isMouseDown = true;
      prev = getPosition(inputCanvas, e.clientX, e.clientY);
    };

    const onMouseUp = (e) => {
      isMouseDown = false;
      extract();
    };

    const onMouseMove = (e) => {
      if (isMouseDown === false) {
        return;
      }
      curr = getPosition(inputCanvas, e.clientX, e.clientY);
      inputCtx.lineWidth = 6;
      inputCtx.lineCap = 'round';
      inputCtx.fillStyle = '#FFFFFF';
      inputCtx.strokeStyle = '#FFFFFF';
      line(inputCtx, prev.x, prev.y, curr.x, curr.y);
      prev = curr;
    };

    inputCanvas = document.getElementById('input');
    inputCanvas.width = 140;
    inputCanvas.height = 140;
    inputCtx = inputCanvas.getContext('2d');
    inputCtx.fillStyle = '#000000';
    inputCtx.fillRect(0, 0, 140, 140);
    inputCanvas.addEventListener('mousedown', onMouseDown);
    inputCanvas.addEventListener('mouseup', onMouseUp);
    inputCanvas.addEventListener('mousemove', onMouseMove);
    window.inputCtx = inputCtx;
    window.inputCanvas = inputCanvas;
  }

  async function load() {
    console.log('Start loading');
    await mnistData.load();
    console.log('Done loading');
  }

  async function train() {
    console.log('Start training');
    // Slider values are strings; convert them to numbers before use.
    let numBatch = +d3.select('#num-batch').node().value;
    let r1kernel = +d3.select('#row-one-kernel').node().value;
    let r1filters = +d3.select('#row-one-filters').node().value;
    let r2kernel = +d3.select('#row-two-kernel').node().value;
    let r2filters = +d3.select('#row-two-filters').node().value;
    cnnModel.build(
      {kernelSize: r1kernel, filters: r1filters},
      {kernelSize: r2kernel, filters: r2filters});
    await cnnModel.train(mnistData, numBatch, updateTraining);
    console.log('Done training');
  }

  async function run() {
    d3.select('#predictions-one').selectAll('*').remove();
    d3.select('#predictions-two').selectAll('*').remove();
    d3.select('#training').text('loading...');

    await load();
    await train();

    const batch = mnistData.nextTestBatch(10);
    const output = cnnModel.predict(batch.xs.reshape([-1, 28, 28, 1]));
    const axis = 1;
    const labels = Array.from(batch.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());
    showPredictions(batch, predictions, labels);
  }

  // run();
  updateSliders();
  setInputCanvas();
</script>