<!-- Simple MNIST GAN using TensorflowJS -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.2.2/tf.js"></script>

<script>
// image-utils.js
class ImageUtil {
  // Flattens RGBA channel data into a grey-scale float array in [0, 1].
  static flatten(data, options) {
    const w = options.width || 0;
    const h = options.height || 0;
    const flat = [];
    for (let i = 0; i < w * h; i++) {
      const j = i * 4;
      const newVal = (data[j + 0] + data[j + 1] + data[j + 2] + data[j + 3]) / 4.0;
      flat.push(newVal / 255.0);
    }
    return flat;
  }

  // Unflattens a single channel in [0, 1] to opaque RGBA bytes.
  static unflatten(data, options) {
    const w = options.width || 0;
    const h = options.height || 0;
    const unflat = [];
    for (let i = 0; i < w * h; i++) {
      const val = data[i] * 255;
      unflat.push(val);
      unflat.push(val);
      unflat.push(val);
      unflat.push(255);
    }
    return unflat;
  }

  // Loads an image and resolves with its ImageData.
  static async loadImage(url, options) {
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    // window.ctx = ctx;
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = options.width || img.naturalWidth;
        img.height = options.height || img.naturalHeight;
        canvas.width = img.width;
        canvas.height = img.height;
        ctx.drawImage(img, 0, 0, img.width, img.height);
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        // ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);
        resolve(imageData);
      };
      img.onerror = reject;
      img.src = url;
    });
    return imgRequest;
  }
}
</script>

<script>
// data.js
// import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
// const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
const MNIST_IMAGES_SPRITE_PATH = 'https://notebook.xbdev.net/var/resources/mnist_images.png';
const MNIST_LABELS_PATH = 'https://notebook.xbdev.net/var/resources/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        // 4 bytes per float32 pixel.
        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        // Draw the sprite in chunks of 5000 rows (one image per row).
        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale,
            // so just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
          console.log('Processed chunk', i);
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.onerror = reject;
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training/validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  // data is [images, labels]; index() returns the next shuffled element index.
  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
</script>

<script>
// gan.js
// Input params
const BATCH = 200;
const SIZE = 28;
const INPUT_SIZE = SIZE * SIZE;
const SEED_SIZE = 40;
const SEED_STD = 3.5;
const ONES = tf.ones([BATCH, 1]);
// Smoothed "real" labels (0.98 instead of 1) for the discriminator.
const ONES_PRIME = tf.ones([BATCH, 1]).mul(tf.scalar(0.98));
const ZEROS = tf.zeros([BATCH, 1]);

// Generator and discriminator params
const DISCRIMINATOR_LEARNING_RATE = 0.025;
const GENERATOR_LEARNING_RATE = 0.025;
const dOptimizer = tf.train.sgd(DISCRIMINATOR_LEARNING_RATE);
const gOptimizer = tf.train.sgd(GENERATOR_LEARNING_RATE);

// Helper functions
const varInitNormal = (shape, mean = 0, std = 0.1) =>
    tf.variable(tf.randomNormal(shape, mean, std));
// tf.tensor takes (values, shape), so the data goes first.
const varLoad = (shape, data) => tf.variable(tf.tensor(data, shape));
const seed = (s = BATCH) => tf.randomNormal([s, SEED_SIZE], 0, SEED_STD);

// Network arch for generator: SEED_SIZE -> 140 -> 80 -> INPUT_SIZE
let G1w = varInitNormal([SEED_SIZE, 140]);
let G1b = varInitNormal([140]);
let G2w = varInitNormal([140, 80]);
let G2b = varInitNormal([80]);
let G3w = varInitNormal([80, INPUT_SIZE]);
let G3b = varInitNormal([INPUT_SIZE]);

// Network arch for discriminator: INPUT_SIZE -> 200 -> 90 -> 1
let D1w = varInitNormal([INPUT_SIZE, 200]);
let D1b = varInitNormal([200]);
let D2w = varInitNormal([200, 90]);
let D2b = varInitNormal([90]);
let D3w = varInitNormal([90, 1]);
let D3b = varInitNormal([1]);

// GAN functions
function gen(xs) {
  const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
  const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
  const l3 = tf.tanh(l2.matMul(G3w).add(G3b));
  return l3;
}

function disReal(xs) {
  const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
  const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
  const logits = l2.matMul(D3w).add(D3b);
  const output = tf.sigmoid(logits);
  return [logits, output];
}

function disFake(xs) {
  return disReal(gen(xs));
}

// Copied from tensorflow core.
// Computes max(x, 0) - x * z + log(1 + exp(-|x|)) for logits x and targets z.
function sigmoidCrossEntropyWithLogits(target, output) {
  return tf.tidy(function () {
    let maxOutput = tf.maximum(output, tf.zerosLike(output));
    let outputXTarget = tf.mul(output, target);
    let sigmoidOutput = tf.log(tf.add(tf.scalar(1.0), tf.exp(tf.neg(tf.abs(output)))));
    let result = tf.add(tf.sub(maxOutput, outputXTarget), sigmoidOutput);
    return result;
  });
}

// Single batch training: one discriminator step, then one generator step.
async function trainBatch(realBatch, fakeBatch) {
  const dcost = dOptimizer.minimize(() => {
    const [logitsReal, outputReal] = disReal(realBatch);
    const [logitsFake, outputFake] = disFake(fakeBatch);

    const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
    const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
    return lossReal.add(lossFake).mean();
  }, true, [D1w, D1b, D2w, D2b, D3w, D3b]);
  await tf.nextFrame();

  const gcost = gOptimizer.minimize(() => {
    const [logitsFake, outputFake] = disFake(fakeBatch);

    const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
    return lossFake.mean();
  }, true, [G1w, G1b, G2w, G2b, G3w, G3b]);
  await tf.nextFrame();

  return [dcost, gcost];
}
</script>

<style>
body {
  height: 100%;
  padding: 0;
  margin: 0;
  font-family: Tahoma, Verdana;
  font-size: 14px;
  display: flex;
  justify-content: center;
  background: #DDD;
}

section {
  background: #FDFDFD;
  margin: 0 4rem;
  padding: 0 2rem;
  flex-grow: 0;
}

button {
  font-size: 100%;
  margin: 5px;
}
</style>

<section>
  <h4>Simple MNIST GAN using TensorflowJS</h4>
  <p>
    Handwritten digit generation using a Generative Adversarial Network (GAN).
    TensorflowJS implementation and vanilla Javascript, all here.
  </p>
  <table style="margin-left:20px">
    <tr>
      <td>Early stages</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample-early.png" height="30px"></td>
    </tr>
    <tr>
      <td>Getting better</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample-mid.png" height="30px"></td>
    </tr>
    <tr>
      <td>Later still</td>
      <td><img src="https://notebook.xbdev.net/var/images/sample-late.png" height="30px"></td>
    </tr>
  </table>
  <p>
    Click <strong>Train</strong> to train for 1500 more batches of 200 images (roughly 5 additional epochs).
    Click <strong>Sample image</strong> to generate a sample output using the current weights.
    The network should start to converge after 15 to 20 epochs.
  </p>
  <button id="train" onclick="train(1500)">Train1500</button>
  <button onclick="sampleImage()">Sample image</button>
  <br>
  <p id="load-status"><br>Loading resources, this may take a few seconds...<br></p>
  <br><br>
  <div id="samples-container"></div>
  <br>
</section>

<script>
const mnistData = new MnistData();

async function loadMnist() {
  console.log('Start loading...');
  document.querySelectorAll('button').forEach(d => d.disabled = true);
  await mnistData.load();
  console.log('Done loading...');
  document.querySelectorAll('button').forEach(d => d.disabled = false);
  document.querySelector('#load-status').style.display = 'none';
}

async function train(num = 1000) {
  console.log('starting...');
  document.querySelector('#train').disabled = true;
  for (let i = 0; i < num; i++) {
    document.querySelector('#train').innerHTML = i + '/' + num;
    const real = mnistData.nextTrainBatch(BATCH);
    const fake = seed();
    const [dcost, gcost] = await trainBatch(real.xs, fake);
    if (i % 50 === 0 || i === (num - 1)) {
      console.log('i', i);
      console.log('discriminator cost', dcost.dataSync());
      console.log('generator cost', gcost.dataSync());
    }
    // Free this iteration's tensors so memory does not grow with each batch.
    tf.dispose([real.xs, real.labels, fake, dcost, gcost]);
  }
  document.querySelector('#train').innerHTML = 'Train';
  document.querySelector('#train').disabled = false;
  console.log('done...');
}

async function sampleImage() {
  await tf.nextFrame();
  const options = {
    width: SIZE,
    height: SIZE
  };
  const canvas = document.createElement('canvas');
  canvas.width = options.width;
  canvas.height = options.height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(options.width, options.height);

  // Generate one sample; tidy frees the intermediate tensors.
  const sample = tf.tidy(() => gen(seed(1)));
  const data = sample.dataSync();
  sample.dispose();

  // Undo tanh: map [-1, 1] to [0, 1].
  for (let i = 0; i < data.length; i++) {
    data[i] = 0.5 * (data[i] + 1.0);
  }

  const unflat = ImageUtil.unflatten(data, options);
  for (let i = 0; i < unflat.length; i++) {
    imageData.data[i] = unflat[i];
  }
  ctx.putImageData(imageData, 0, 0);
  document.body.querySelector('#samples-container').appendChild(canvas);
}

async function start() {
  await loadMnist();
}

start();
console.log('ready...');
</script>
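The `sigmoidCrossEntropyWithLogits` helper in the GAN script computes max(x, 0) - x * z + log(1 + exp(-|x|)) for logits x and targets z. The following is a minimal plain-number sketch of why that form is used, not part of the page itself; the names `stableSigmoidCrossEntropy` and `naiveSigmoidCrossEntropy` are hypothetical, and `Math.log1p` stands in for the `tf.log(tf.add(...))` call.

```javascript
// Hypothetical scalar version of the stable loss used in the GAN script:
// max(x, 0) - x * z + log(1 + exp(-|x|)), with x the logit and z the target.
function stableSigmoidCrossEntropy(target, logit) {
  return Math.max(logit, 0) - logit * target +
      Math.log1p(Math.exp(-Math.abs(logit)));
}

// Textbook form, -(z * log(p) + (1 - z) * log(1 - p)) with p = sigmoid(x).
// Shown only for comparison; it overflows to Infinity for large logits.
function naiveSigmoidCrossEntropy(target, logit) {
  const p = 1 / (1 + Math.exp(-logit));
  return -(target * Math.log(p) + (1 - target) * Math.log(1 - p));
}

// For moderate logits both forms agree up to rounding.
console.log(stableSigmoidCrossEntropy(0.98, 2), naiveSigmoidCrossEntropy(0.98, 2));

// For a very large logit scored against target 0, the naive form hits log(0),
// while the stable form stays finite.
console.log(stableSigmoidCrossEntropy(0, 1000), naiveSigmoidCrossEntropy(0, 1000));
```

The 0.98 target mirrors the `ONES_PRIME` constant, which the discriminator step uses as a slightly smoothed "real" label.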