<!-- Simple MNIST GAN using TensorflowJS -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.2.2/tf.js"></script>

<script>
// image-utils.js
class ImageUtil {
  // Flattens RGBA channel data into a greyscale float array in [0, 1]
  static flatten(data, options = {}) {
    const w = options.width || 0;
    const h = options.height || 0;
    const flat = [];
    for (let i = 0; i < w * h; i++) {
      const j = i * 4;
      // Average the three colour channels; alpha is ignored
      const newVal = (data[j + 0] + data[j + 1] + data[j + 2]) / 3.0;
      flat.push(newVal / 255.0);
    }
    return flat;
  }

  // Unflattens single-channel data in [0, 1] to opaque RGBA bytes
  static unflatten(data, options = {}) {
    const w = options.width || 0;
    const h = options.height || 0;
    const unflat = [];
    for (let i = 0; i < w * h; i++) {
      const val = data[i] * 255;
      unflat.push(val, val, val, 255);
    }
    return unflat;
  }

  // Draws an image onto an off-screen canvas and resolves with its ImageData
  static async loadImage(url, options = {}) {
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = options.width || img.naturalWidth;
        img.height = options.height || img.naturalHeight;
        canvas.width = img.width;
        canvas.height = img.height;
        ctx.drawImage(img, 0, 0, img.width, img.height);
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        resolve(imageData);
      };
      img.onerror = reject;
      img.src = url;
    });
    return imgRequest;
  }
}
</script>

<script>
// data.js
// import * as tf from '@tensorflow/tfjs';  // not needed here: tf is the global from the tf.js script tag above

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// Original Google-hosted copies of the dataset:
// const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
// const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
const MNIST_IMAGES_SPRITE_PATH = 'https://notebook.xbdev.net/var/resources/mnist_images.png';
const MNIST_LABELS_PATH = 'https://notebook.xbdev.net/var/resources/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale,
            // so just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
          console.log('Processed chunk', i);
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.onerror = reject;
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
</script>

<script>
// gan.js

// Input params
const BATCH = 200;
const SIZE = 28;
const INPUT_SIZE = SIZE * SIZE;
const SEED_SIZE = 40;
const SEED_STD = 3.5;
const ONES = tf.ones([BATCH, 1]);
// Smoothed "real" label (0.98 instead of 1) used when training the discriminator
const ONES_PRIME = tf.ones([BATCH, 1]).mul(tf.scalar(0.98));
const ZEROS = tf.zeros([BATCH, 1]);

// Generator and discriminator params
const DISCRIMINATOR_LEARNING_RATE = 0.025;
const GENERATOR_LEARNING_RATE = 0.025;
const dOptimizer = tf.train.sgd(DISCRIMINATOR_LEARNING_RATE);
const gOptimizer = tf.train.sgd(GENERATOR_LEARNING_RATE);

// Helper functions
const varInitNormal = (shape, mean = 0, std = 0.1) =>
    tf.variable(tf.randomNormal(shape, mean, std));
// Creates a trainable variable from saved weights (tf.tensor takes values first, then shape)
const varLoad = (shape, data) => tf.variable(tf.tensor(data, shape));
// Batch of random latent vectors fed to the generator
const seed = (s = BATCH) => tf.randomNormal([s, SEED_SIZE], 0, SEED_STD);

// Network arch for generator: SEED_SIZE -> 140 -> 80 -> INPUT_SIZE
let G1w = varInitNormal([SEED_SIZE, 140]);
let G1b = varInitNormal([140]);
let G2w = varInitNormal([140, 80]);
let G2b = varInitNormal([80]);
let G3w = varInitNormal([80, INPUT_SIZE]);
let G3b = varInitNormal([INPUT_SIZE]);

// Network arch for discriminator: INPUT_SIZE -> 200 -> 90 -> 1
let D1w = varInitNormal([INPUT_SIZE, 200]);
let D1b = varInitNormal([200]);
let D2w = varInitNormal([200, 90]);
let D2b = varInitNormal([90]);
let D3w = varInitNormal([90, 1]);
let D3b = varInitNormal([1]);

// GAN functions
function gen(xs) {
  const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
  const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
  const l3 = tf.tanh(l2.matMul(G3w).add(G3b));  // pixels in [-1, 1]
  return l3;
}

function disReal(xs) {
  const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
  const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
  const logits = l2.matMul(D3w).add(D3b);
  const output = tf.sigmoid(logits);
  return [logits, output];
}

function disFake(xs) {
  return disReal(gen(xs));
}

// Copied from tensorflow core.
// Stable form of the loss: max(x, 0) - x * z + log(1 + exp(-|x|)),
// where `output` holds the logits x and `target` the labels z.
function sigmoidCrossEntropyWithLogits(target, output) {
  return tf.tidy(function () {
    let maxOutput = tf.maximum(output, tf.zerosLike(output));
    let outputXTarget = tf.mul(output, target);
    let sigmoidOutput = tf.log(
        tf.add(tf.scalar(1.0), tf.exp(tf.neg(tf.abs(output)))));
    let result = tf.add(tf.sub(maxOutput, outputXTarget), sigmoidOutput);
    return result;
  });
}

// Single batch training: one discriminator step, then one generator step
async function trainBatch(realBatch, fakeBatch) {
  const dcost = dOptimizer.minimize(() => {
    const [logitsReal] = disReal(realBatch);
    const [logitsFake] = disFake(fakeBatch);

    const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
    const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
    return lossReal.add(lossFake).mean();
  }, true, [D1w, D1b, D2w, D2b, D3w, D3b]);
  await tf.nextFrame();

  const gcost = gOptimizer.minimize(() => {
    const [logitsFake] = disFake(fakeBatch);

    const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
    return lossFake.mean();
  }, true, [G1w, G1b, G2w, G2b, G3w, G3b]);
  await tf.nextFrame();

  return [dcost, gcost];
}
</script>

<style>
  body {
    height: 100%;
    padding: 0;
    margin: 0;
    font-family: Tahoma, Verdana;
    font-size: 14px;
    display: flex;
    justify-content: center;
    background: #DDD;
  }
  section {
    background: #FDFDFD;
    margin: 0 4rem;
    padding: 0 2rem;
    flex-grow: 0;
  }
  button {
    font-size: 100%;
    margin: 5px;
  }
</style>

<section>
  <h4>Simple MNIST GAN using TensorflowJS</h4>
  <p>Handwritten digit generation using a Generative Adversarial Network (GAN). TensorflowJS implementation and vanilla JavaScript, all here.</p>
  <table style="margin-left: 20px">
    <tr>
      <td>Early stages</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample_early.png" height="30" alt="Sample output, early stages"></td>
    </tr>
    <tr>
      <td>Getting better</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample_mid.png" height="30" alt="Sample output, getting better"></td>
    </tr>
    <tr>
      <td>Later still</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample_late.png" height="30" alt="Sample output, later still"></td>
    </tr>
  </table>
  <p>Click <strong>Train</strong> to train for about 5 more epochs (1500 batches of 200 images, roughly 5.5 passes over the 55,000 training images). Click <strong>Sample image</strong> to generate a sample output using the current weights. The network should start to converge after 15 to 20 epochs.</p>
  <button id="train" onclick="train(1500)">Train</button>
  <button onclick="sampleImage()">Sample image</button>
  <br>
  <p id="load-status"><br>Loading resources, this may take a few seconds.<br></p>
  <br><br>
  <div id="samples-container"></div>
  <br>
</section>

<script>
const mnistData = new MnistData();

async function loadMnist() {
  console.log('Start loading...');
  document.querySelectorAll('button').forEach(d => d.disabled = true);
  await mnistData.load();
  console.log('Done loading...');
  document.querySelectorAll('button').forEach(d => d.disabled = false);
  document.querySelector('#load-status').style.display = 'none';
}

async function train(num = 1000) {
  console.log('starting...');
  const button = document.querySelector('#train');
  button.disabled = true;
  for (let i = 0; i < num; i++) {
    button.innerHTML = `${i}/${num}`;
    const real = mnistData.nextTrainBatch(BATCH);
    // MnistData returns pixels in [0, 1]; rescale them to [-1, 1] so real
    // images share the range of the generator's tanh output.
    const realXs = tf.tidy(() => real.xs.mul(2).sub(1));
    const fake = seed();
    const [dcost, gcost] = await trainBatch(realXs, fake);
    if (i % 50 === 0 || i === num - 1) {
      console.log('i', i);
      console.log('discriminator cost', dcost.dataSync());
      console.log('generator cost', gcost.dataSync());
    }
    // Free this batch's tensors; otherwise memory grows on every iteration
    tf.dispose([real.xs, real.labels, realXs, fake, dcost, gcost]);
  }
  button.innerHTML = 'Train';
  button.disabled = false;
  console.log('done...');
}

async function sampleImage() {
  await tf.nextFrame();
  const options = {width: SIZE, height: SIZE};
  const canvas = document.createElement('canvas');
  canvas.width = options.width;
  canvas.height = options.height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(options.width, options.height);

  const out = tf.tidy(() => gen(seed(1)));
  const data = Array.from(out.dataSync());
  out.dispose();

  // Undo tanh: map [-1, 1] back to [0, 1]
  for (let i = 0; i < data.length; i++) {
    data[i] = 0.5 * (data[i] + 1.0);
  }

  const unflat = ImageUtil.unflatten(data, options);
  for (let i = 0; i < unflat.length; i++) {
    imageData.data[i] = unflat[i];
  }
  ctx.putImageData(imageData, 0, 0);

  document.body.querySelector('#samples-container').appendChild(canvas);
}

async function start() {
  await loadMnist();
}

start();
console.log('ready');
</script>
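<!--
Note on sigmoidCrossEntropyWithLogits: for a logit x and target z it computes the loss as max(x, 0) - x*z + log(1 + exp(-|x|)). It does not use the textbook -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)). The two forms are algebraically equal. The textbook form overflows to Infinity once sigmoid(x) rounds to exactly 0 or 1 in floating point. The sketch below is a minimal plain-JavaScript check on single numbers and does not need TensorFlow.js. The names stableSigmoidXent and naiveSigmoidXent are hypothetical and exist only for illustration.

```javascript
// Numerically stable sigmoid cross-entropy for one logit x and target z,
// using the same formula as sigmoidCrossEntropyWithLogits.
function stableSigmoidXent(x, z) {
  return Math.max(x, 0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
}

// Textbook form: -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)).
function naiveSigmoidXent(x, z) {
  const s = 1 / (1 + Math.exp(-x));
  return -z * Math.log(s) - (1 - z) * Math.log(1 - s);
}

// Moderate logits, paired with the smoothed real label 0.98 and the fake label 0.
const cases = [
  [-3, 0.98], [0, 0.98], [2.5, 0.98],
  [-3, 0], [0, 0], [2.5, 0],
];
for (const [x, z] of cases) {
  const a = stableSigmoidXent(x, z);
  const b = naiveSigmoidXent(x, z);
  console.log(`x=${x} z=${z} stable=${a.toFixed(6)} naive=${b.toFixed(6)}`);
}

// For a large logit, 1 - sigmoid(x) rounds to 0, so the naive form returns
// Infinity. The stable form stays finite and equals x when z = 0.
console.log(naiveSigmoidXent(40, 0));  // Infinity
console.log(stableSigmoidXent(40, 0)); // 40
```

The targets 0.98 and 0 mirror the ONES_PRIME and ZEROS tensors used in the discriminator step of trainBatch.
-->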
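<!--
Note on sampleImage: it turns one generator output into canvas pixels in two steps. First, the "Undo tanh" loop maps each value from [-1, 1] to [0, 1]. Second, ImageUtil.unflatten expands each grey value into an opaque RGBA pixel scaled to the 0..255 range. The standalone sketch below combines both steps so the arithmetic can be checked without a browser. The name tanhToRgba is hypothetical. The sketch writes into a Uint8ClampedArray, the same array type as ImageData.data, which makes the rounding and clamping explicit.

```javascript
// Hypothetical helper that mirrors the pixel path in sampleImage:
// a generator value in [-1, 1] becomes a grey level in [0, 1], then RGBA bytes.
function tanhToRgba(values, width, height) {
  const rgba = new Uint8ClampedArray(width * height * 4);
  for (let i = 0; i < width * height; i++) {
    const grey = 0.5 * (values[i] + 1.0); // undo tanh: [-1, 1] to [0, 1]
    const byte = grey * 255;              // Uint8ClampedArray rounds and clamps to 0..255
    rgba[i * 4 + 0] = byte;
    rgba[i * 4 + 1] = byte;
    rgba[i * 4 + 2] = byte;
    rgba[i * 4 + 3] = 255;                // fully opaque
  }
  return rgba;
}

// 2x2 example: black, mid-grey, white, and an out-of-range value of 1.5.
const rgba = tanhToRgba([-1, 0, 1, 1.5], 2, 2);
console.log(Array.from(rgba));
// [0, 0, 0, 255, 128, 128, 128, 255, 255, 255, 255, 255, 255, 255, 255, 255]
```

A tanh output of 0 lands exactly on 127.5. The clamped array rounds that tie to the even value 128, and values above 1 saturate at 255 rather than wrapping around.
-->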