<title>TensorFlow.js CNN, Vanilla JavaScript, MNIST digits</title>
<style>
  body {
    min-height: 600px;
    height: 100%;
    padding: 0;
    margin: 0;
    font-family: Tahoma, Verdana;
    font-size: 14px;
    display: flex;
    justify-content: center;
    background: #DDD;
  }
  section {
    background: #FDFDFD;
    margin: 0 4rem;
    padding: 0 2rem;
    flex-grow: 0;
  }
  .column {
    display: flex;
    flex-direction: column;
  }
  .row {
    display: flex;
    flex-direction: row;
  }
  table {
    border-collapse: collapse;
    max-width: 800px;
  }
  td {
    border: 1px solid #CCC;
    padding: 1px 5px;
  }
</style>

<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/0.10.3/tf.min.js"></script>
<script src="https://d3js.org/d3.v5.min.js"></script>

<script>
  // data
  const IMAGE_SIZE = 784;
  const NUM_CLASSES = 10;
  const NUM_DATASET_ELEMENTS = 65000;

  const NUM_TRAIN_ELEMENTS = 55000;
  const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

  // const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
  // const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
  const MNIST_IMAGES_SPRITE_PATH = 'https://notebook.xbdev.net/var/resources/mnist_images.png';
  const MNIST_LABELS_PATH = 'https://notebook.xbdev.net/var/resources/mnist_labels_uint8';

  /**
   * A class that fetches the sprited MNIST dataset and returns shuffled batches.
   *
   * NOTE: This will get much easier. For now, we do data fetching and
   * manipulation manually.
   */
  class MnistData {
    constructor() {
      this.shuffledTrainIndex = 0;
      this.shuffledTestIndex = 0;
    }

    async load() {
      // Make a request for the MNIST sprited image.
      const img = new Image();
      const canvas = document.createElement('canvas');
      const ctx = canvas.getContext('2d');
      const imgRequest = new Promise((resolve, reject) => {
        img.crossOrigin = '';
        img.onload = () => {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          // One 32-bit float (4 bytes) per pixel for the whole dataset.
          const datasetBytesBuffer =
              new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

          // Decode the sprite in chunks of 5000 images to limit canvas size.
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
                datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                IMAGE_SIZE * chunkSize);
            ctx.drawImage(
                img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                chunkSize);

            const imageData =
                ctx.getImageData(0, 0, canvas.width, canvas.height);

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is
              // grayscale, so just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
            console.log('Processed chunk', i);
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        };
        img.src = MNIST_IMAGES_SPRITE_PATH;
      });

      const labelsRequest = fetch(MNIST_LABELS_PATH);
      const [imgResponse, labelsResponse] =
          await Promise.all([imgRequest, labelsRequest]);

      this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

      // Create shuffled indices into the train/test set for when we select a
      // random dataset element for training / validation.
      this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
      this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

      // Slice the images and labels into train and test sets.
      this.trainImages =
          this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.testImages =
          this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.trainLabels =
          this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
      this.testLabels =
          this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    }

    nextTrainBatch(batchSize) {
      return this.nextBatch(
          batchSize, [this.trainImages, this.trainLabels], () => {
            this.shuffledTrainIndex =
                (this.shuffledTrainIndex + 1) % this.trainIndices.length;
            return this.trainIndices[this.shuffledTrainIndex];
          });
    }

    nextTestBatch(batchSize) {
      return this.nextBatch(
          batchSize, [this.testImages, this.testLabels], () => {
            this.shuffledTestIndex =
                (this.shuffledTestIndex + 1) % this.testIndices.length;
            return this.testIndices[this.shuffledTestIndex];
          });
    }

    // data is [images, labels]; index() returns the next shuffled element index.
    nextBatch(batchSize, data, index) {
      const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
      const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

      for (let i = 0; i < batchSize; i++) {
        const idx = index();

        const image =
            data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
        batchImagesArray.set(image, i * IMAGE_SIZE);

        const label =
            data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
        batchLabelsArray.set(label, i * NUM_CLASSES);
      }

      const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
      const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

      return {xs, labels};
    }
  }
</script>

<script>
  // model
  class Model {
    constructor(learningRate) {
      this.learningRate = learningRate;
      // Start with the default architecture: 5x5 kernels, 8 and 16 filters.
      this.build();
    }

    // Each options object may carry kernelSize and filters. Slider values
    // arrive as strings, so they are converted to numbers here.
    build(layer1options = {}, layer2options = {}) {
      this.model = tf.sequential();

      // First layer
      this.model.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: Number(layer1options.kernelSize) || 5,
        filters: Number(layer1options.filters) || 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
      }));
      this.model.add(
          tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

      // Second layer
      this.model.add(tf.layers.conv2d({
        kernelSize: Number(layer2options.kernelSize) || 5,
        filters: Number(layer2options.filters) || 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
      }));
      this.model.add(
          tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

      // Output layer
      this.model.add(tf.layers.flatten());
      this.model.add(tf.layers.dense({
        units: 10,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
      }));

      const optimizer = tf.train.sgd(this.learningRate);
      this.model.compile({
        optimizer: optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy']
      });
    }

    predict(data) {
      return this.model.predict(data);
    }

    async train(data, numTrainBatches, callback = null) {
      const BATCH_SIZE = 64;
      // const TRAIN_BATCHES = 150;
      // const TRAIN_BATCHES = 40;
      const TRAIN_BATCHES = Number(numTrainBatches);
      const TEST_BATCH_SIZE = 1000;
      const TEST_ITERATION_FREQUENCY = 5;

      for (let i = 0; i < TRAIN_BATCHES; i++) {
        const batch = data.nextTrainBatch(BATCH_SIZE);

        // Every few batches, and on the last batch, test the accuracy of
        // the model.
        let testBatch;
        let validationData;
        if (i % TEST_ITERATION_FREQUENCY === 0 || i === TRAIN_BATCHES - 1) {
          testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
          validationData = [
            testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]),
            testBatch.labels
          ];
        }

        // The entire dataset doesn't fit into memory, so we call fit
        // repeatedly with batches.
        const history = await this.model.fit(
            batch.xs.reshape([BATCH_SIZE, 28, 28, 1]), batch.labels,
            {batchSize: BATCH_SIZE, validationData, epochs: 1});

        const loss = history.history.loss[0];
        const accuracy = history.history.acc[0];

        if (testBatch != null) {
          console.log('Batch', i, 'accuracy', accuracy);
        }

        // Invoke callback
        if (callback) {
          callback(history);
        }

        // Clean up
        batch.xs.dispose();
        batch.labels.dispose();
        if (testBatch != null) {
          testBatch.xs.dispose();
          testBatch.labels.dispose();
        }

        await tf.nextFrame();
      }
    } // End train
  }
</script>

<section>
  <h3>CNN: TensorFlow.js</h3>
  <p>
    A 2-layer CNN with TensorFlow.js to recognize hand-drawn digits (MNIST).
    Use the sliders to adjust the layer parameters below, and then click
    "Train" to train the neural network. Once trained, draw a digit on the
    canvas to run live prediction.
  </p>
  <br>
  <div class="row">
    <!-- Left Panel -->
    <div>
      <!-- Layer 1 -->
      <div class="row">
        <span><strong>Layer 1</strong> Kernel: <input id="row-one-kernel" type="range" min="2" max="8" value="5" step="1" oninput="updateSliders()"><span id="row-one-kernel-value"></span></span>
        &nbsp;&nbsp;&nbsp;
        <span>Filters: <input id="row-one-filters" type="range" min="4" max="24" value="8" step="1" oninput="updateSliders()"><span id="row-one-filters-value"></span></span>
      </div>
      <br>
      <!-- Layer 2 -->
      <div class="row">
        <span><strong>Layer 2</strong> Kernel: <input id="row-two-kernel" type="range" min="2" max="8" value="5" step="1" oninput="updateSliders()"><span id="row-two-kernel-value"></span></span>
        &nbsp;&nbsp;&nbsp;
        <span>Filters: <input id="row-two-filters" type="range" min="4" max="24" value="16" step="1" oninput="updateSliders()"><span id="row-two-filters-value"></span></span>
      </div>
      <br>
      <!-- Batch -->
      <div class="row">
        <span>Training batches: <input style="width: 15rem" id="num-batch" type="range" min="10" max="100" value="25" step="5" size="30" oninput="updateSliders()"><span id="num-batch-value"></span></span>
      </div>
      <br>
      <!-- Train -->
      <div>
        <button style="font-size: 120%" onclick="run()">Train</button>
        &nbsp;&nbsp;
        <span id="training">not trained</span>
      </div>
      <!-- Prediction -->
      <div class="row">
        <div style="padding-left: 30px" id="predictions-one"></div>
        <div style="padding-left: 30px" id="predictions-two"></div>
      </div>
    </div>
    <!-- Right Panel -->
    <div style="padding-left: 30px">
      <div class="row">
        <div>
          <canvas id="input"></canvas>
          <br>
          <button style="font-size: 120%" onclick="clearInputCanvas()">Clear canvas</button>
        </div>
        <div style="padding-left: 10px; font-size: 150%">
          <div>Prediction Is</div>
          <div style="font-size: 300%" id="user-prediction"></div>
        </div>
      </div>
    </div>
  </div>
  <br>
</section>

<script>
  let mnistData = new MnistData();
  let cnnModel = new Model(0.15);
  cnnModel.build(); // Just init a default model so stuff doesn't error out

  function updateTraining(history) {
    d3.select('#training')
        .text('Accuracy: ' + history.history.acc[0].toFixed(2));
  }

  // Mirror each slider's current value into the label next to it.
  function updateSliders() {
    let r1kernel = d3.select('#row-one-kernel').node().value;
    let r1filters = d3.select('#row-one-filters').node().value;
    let r2kernel = d3.select('#row-two-kernel').node().value;
    let r2filters = d3.select('#row-two-filters').node().value;
    let numBatch = d3.select('#num-batch').node().value;
    console.log(r1kernel, r1filters, r2kernel, r2filters);
    d3.select('#row-one-kernel-value').text(r1kernel);
    d3.select('#row-one-filters-value').text(r1filters);
    d3.select('#row-two-kernel-value').text(r2kernel);
    d3.select('#row-two-filters-value').text(r2filters);
    d3.select('#num-batch-value').text(numBatch);
  }

  // Copied from https://github.com/tensorflow/tfjs-examples/blob/master/mnist/ui.js
  function draw(image, canvas) {
    const [width, height] = [28, 28];
    canvas.width = width;
    canvas.height = height;
    const ctx = canvas.getContext('2d');
    const imageData = new ImageData(width, height);
    const data = image.dataSync();
    for (let i = 0; i < height * width; ++i) {
      const j = i * 4;
      imageData.data[j + 0] = data[i] * 255;
      imageData.data[j + 1] = data[i] * 255;
      imageData.data[j + 2] = data[i] * 255;
      imageData.data[j + 3] = 255;
    }
    ctx.putImageData(imageData, 0, 0);
  }

  // Render each test image with its predicted and actual label, alternating
  // between the two prediction columns.
  function showPredictions(batch, predictions, labels) {
    const len = batch.xs.shape[0];
    const c1 = d3.select('#predictions-one');
    const c2 = d3.select('#predictions-two');
    for (let i = 0; i < len; i++) {
      const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);
      const canvas = document.createElement('canvas');
      canvas.className = 'prediction-canvas';
      draw(image.flatten(), canvas);
      const row = i % 2 === 0 ? c1.append('div') : c2.append('div');
      row.style('height', '30px');
      row.append('span')
          .text(`Prediction: ${predictions[i]} Actual: ${labels[i]} `);
      row.node().appendChild(canvas);
    }
  }

  let inputCanvas = null;
  let inputCtx = null;
  let isMouseDown = false;
  let prev = null;
  let curr = null;

  function clearInputCanvas() {
    d3.select('#user-prediction').text('');
    if (inputCtx != null) {
      inputCtx.fillStyle = '#000000';
      inputCtx.fillRect(0, 0, 140, 140);
    }
  }

  function setInputCanvas() {
    const line = (ctx, x1, y1, x2, y2) => {
      ctx.beginPath();
      ctx.moveTo(x1, y1);
      ctx.lineTo(x2, y2);
      ctx.closePath();
      ctx.stroke();
    };

    const getPosition = (canvas, clientX, clientY) => {
      const rect = canvas.getBoundingClientRect();
      return {x: clientX - rect.left, y: clientY - rect.top};
    };

    const extract = () => {
      // Pull
      const imgData = inputCtx.getImageData(0, 0, 140, 140);
      // Grab single channel
      const singleChannel = imgData.data.filter((d, i) => i % 4 === 0);
      // Down sample: average each 5x5 block of the 140x140 drawing into one
      // pixel of the 28x28 input.
      const sample = [];
      for (let i = 0; i < 28; i++) {
        for (let j = 0; j < 28; j++) {
          let c = 0;
          for (let x = 0; x < 5; x++) {
            for (let y = 0; y < 5; y++) {
              c += singleChannel[140 * (i * 5 + x) + (j * 5 + y)];
            }
          }
          // Divide by 25 for the block average and by 255 so the values lie
          // in [0, 1], the same range as the training images.
          sample.push(c / (25 * 255));
        }
      }
      console.log('sample', sample);
      const tensor = tf.tensor(sample);
      const userInput = tensor.reshape([1, 28, 28, 1]);
      const output = cnnModel.predict(userInput);
      const axis = 1;
      const predictions = Array.from(output.argMax(axis).dataSync());
      d3.select('#user-prediction').text(predictions);
      console.log('prediction', predictions);
    };

    const onMouseDown = (e) => {
      isMouseDown = true;
      prev = getPosition(inputCanvas, e.clientX, e.clientY);
    };

    const onMouseUp = (e) => {
      isMouseDown = false;
      extract();
    };

    const onMouseMove = (e) => {
      if (isMouseDown === false) {
        return;
      }
      curr = getPosition(inputCanvas, e.clientX, e.clientY);
      inputCtx.lineWidth = 6;
      inputCtx.lineCap = 'round';
      inputCtx.fillStyle = '#FFFFFF';
      inputCtx.strokeStyle = '#FFFFFF';
      line(inputCtx, prev.x, prev.y, curr.x, curr.y);
      prev = curr;
    };

    inputCanvas = document.getElementById('input');
    inputCanvas.width = 140;
    inputCanvas.height = 140;
    inputCtx = inputCanvas.getContext('2d');
    inputCtx.fillStyle = '#000000';
    inputCtx.fillRect(0, 0, 140, 140);
    inputCanvas.addEventListener('mousedown', onMouseDown);
    inputCanvas.addEventListener('mouseup', onMouseUp);
    inputCanvas.addEventListener('mousemove', onMouseMove);
    window.inputCtx = inputCtx;
    window.inputCanvas = inputCanvas;
  }

  async function load() {
    console.log('Start loading');
    await mnistData.load();
    console.log('Done loading');
  }

  async function train() {
    console.log('Start training');
    let numBatch = d3.select('#num-batch').node().value;
    let r1kernel = d3.select('#row-one-kernel').node().value;
    let r1filters = d3.select('#row-one-filters').node().value;
    let r2kernel = d3.select('#row-two-kernel').node().value;
    let r2filters = d3.select('#row-two-filters').node().value;
    cnnModel.build(
        {kernelSize: r1kernel, filters: r1filters},
        {kernelSize: r2kernel, filters: r2filters});
    await cnnModel.train(mnistData, numBatch, updateTraining);
    console.log('Done training');
  }

  async function run() {
    d3.select('#predictions-one').selectAll('*').remove();
    d3.select('#predictions-two').selectAll('*').remove();
    d3.select('#training').text('loading...');
    await load();
    await train();
    const batch = mnistData.nextTestBatch(10);
    const output = cnnModel.predict(batch.xs.reshape([-1, 28, 28, 1]));
    const axis = 1;
    const labels = Array.from(batch.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());
    showPredictions(batch, predictions, labels);
  }

  // run(); // Training starts from the "Train" button instead.
  updateSliders();
  setInputCanvas();
</script>