PyTorch and WebNN

Train and save the neural network using PyTorch (for this example, a simple XOR neural network). Save the state dictionary and a JSON version of the network. The JSON version is a flexible, future-proof copy of the data that we can easily pass across to WebNN in JavaScript.

Python Code

```python
import torch
import torch.nn as nn
import torch.optim as optim
import json

# Define the neural network
class XORNet(nn.Module):
    def __init__(self):
        super(XORNet, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

# Training data (targets shaped [4, 1] to match the model output)
inputs = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
targets = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# Initialize the network, criterion and optimizer
model = XORNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Training loop with progress updates
num_epochs = 5000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()

    # Print progress every 500 epochs (and on the final epoch)
    if (epoch + 1) % 500 == 0 or epoch == num_epochs - 1:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, "
              f"Progress: {(epoch + 1) / num_epochs * 100:.2f}%")

# Save the trained model to a file
torch.save(model.state_dict(), "xor_net.pth")

# Convert the state dictionary to JSON format
state_dict = model.state_dict()
state_dict_json = {k: v.tolist() for k, v in state_dict.items()}
with open("xor_net.json", "w") as f:
    json.dump(state_dict_json, f)

# Print the weights and biases
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data.numpy()}")
```

Prints (your values will differ from run to run, since the weights are randomly initialized):

```
fc1.weight: [[0.3043185 0.6067796]
 [1.8907964 1.8609867]]
fc1.bias: [0.4322342 0.1948249]
fc2.weight: [[0.01738905 1.469708]]
fc2.bias: [1.16437]
```

If you need to run a local server (localhost) to access the files, i.e. the JSON file, you can use the following script. CORS headers have been enabled, so pages on other domains, not just localhost, can fetch the file.

To run a server directly, without a .py file, use this line (it serves on port 8000 by default but does not add the CORS headers):

```
python -m http.server
```

Or use the file version, `local.py`, which you run with `python local.py`:

```python
import http.server
import socketserver

# Configuration
PORT = 8000  # Port number to serve on (e.g. http://localhost:8000)

class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    def end_headers(self):
        # Allow cross-origin requests to the served files
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "x-api-key, Content-Type")
        http.server.SimpleHTTPRequestHandler.end_headers(self)

# Create and start the server
with socketserver.TCPServer(("", PORT), CORSRequestHandler) as httpd:
    print(f"Serving at http://localhost:{PORT}")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down server.")
        httpd.server_close()
```

The code below loads and uses the trained network with WebNN. First, load the JSON data (trained in PyTorch) so it can be used by WebNN:

```javascript
// Fetch the model file
const response = await fetch('http://localhost:8000/xor_net.json');
const json = await response.json();
console.log('json:', json);

// Flatten the nested arrays into plain number arrays
let weights1 = json['fc1.weight'].flat(2);
let bias1 = json['fc1.bias'].flat(2);
let weights2 = json['fc2.weight'].flat(2);
let bias2 = json['fc2.bias'].flat(2);

console.log('weights1:', weights1);
console.log('bias1:', bias1);
console.log('weights2:', weights2);
console.log('bias2:', bias2);
```

```
json: {
  "fc1.weight": [[0.6804810762405396, 0.2847576439380646], [0.2714926302433014, 0.18927520513534546]],
  "fc1.bias": [0.24895617365837097, 0.48802873492240906],
  "fc2.weight": [[0.3371845781803131, 0.3348560333251953]],
  "fc2.bias": [0.2576610743999481]
}
weights1: [0.6804810762405396, 0.2847576439380646, 0.2714926302433014, 0.18927520513534546]
bias1: [0.24895617365837097, 0.48802873492240906]
weights2: [0.3371845781803131, 0.3348560333251953]
bias2: [0.2576610743999481]
```

The full WebNN example, with these values copied in as constants:

```javascript
const weights1 = [0.6804810762405396, 0.2847576439380646, 0.2714926302433014, 0.18927520513534546];
const bias1 = [0.24895617365837097, 0.48802873492240906];
const weights2 = [0.3371845781803131, 0.3348560333251953];
const bias2 = [0.2576610743999481];

const context = await navigator.ml.createContext();
const builder = new MLGraphBuilder(context);

// Define the WebNN network
// Define input placeholders
const inputShape = [1, 2];  // One sample with 2 input neurons
const inputType = 'float32';
const outputShape = [1, 1]; // One sample with a single output
const input = builder.input('input', { dataType: inputType, shape: inputShape });

// Create weights and biases as constants.
// PyTorch's nn.Linear stores weights as [out_features, in_features] and computes x @ W^T + b,
// so fc1.weight is transposed in the graph before the matmul.
const W1 = builder.constant({ dataType: inputType, shape: [2, 2] }, new Float32Array(weights1));
const b1 = builder.constant({ dataType: inputType, shape: [1, 2] }, new Float32Array(bias1));
// fc2.weight is [1, 2]; read as [2, 1] it is already its own transpose.
const W2 = builder.constant({ dataType: inputType, shape: [2, 1] }, new Float32Array(weights2));
const b2 = builder.constant({ dataType: inputType, shape: [1, 1] }, new Float32Array(bias2));

// Create the network
const hiddenLayer = builder.sigmoid(builder.add(builder.matmul(input, builder.transpose(W1)), b1));
const outputLayer = builder.sigmoid(builder.add(builder.matmul(hiddenLayer, W2), b2));

// Build the graph
const graph = await builder.build({ output: outputLayer });

// Create reusable tensors for inputs and output
const inputTensor = await context.createTensor({ dataType: inputType, shape: inputShape, writable: true });
const outputTensor = await context.createTensor({ dataType: inputType, shape: outputShape, readable: true });

// Define different input values for testing
const inputValuesList = [
  new Float32Array([0.0, 0.0]),
  new Float32Array([0.0, 1.0]),
  new Float32Array([1.0, 0.0]),
  new Float32Array([1.0, 1.0]),
];

// Execute the graph with different input values
for (const inputValues of inputValuesList) {
  // Write the input values to the tensor
  await context.writeTensor(inputTensor, inputValues);

  // Execute the graph
  const inputs = { input: inputTensor };
  const outputs = { output: outputTensor };
  await context.dispatch(graph, inputs, outputs);

  // Read and print the result
  const result = await context.readTensor(outputTensor);
  const out = new Float32Array(result);
  console.log('Input:', Object.values(inputValues), 'Output:', Object.values(out));
}
```
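To check the numbers the WebNN loop prints, it can help to recompute the same forward pass outside the browser. The sketch below is an illustrative plain-Python reference (no PyTorch or WebNN needed) that reads the exported `xor_net.json` and applies the `nn.Linear` convention y = x·Wᵀ + b, with each weight stored as [out_features, in_features]. The helper names `linear`, `xor_forward` and `load_state` are hypothetical, not part of any library.

```python
import json
import math
from typing import Dict, List


def sigmoid(z: float) -> float:
    """Logistic sigmoid, the same function as torch.sigmoid."""
    return 1.0 / (1.0 + math.exp(-z))


def linear(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """nn.Linear convention: weight is [out_features][in_features], y = x @ W^T + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def xor_forward(state: Dict[str, list], x: List[float]) -> float:
    """Hypothetical reference forward pass for XORNet: two sigmoid layers."""
    hidden = [sigmoid(v) for v in linear(x, state["fc1.weight"], state["fc1.bias"])]
    return sigmoid(linear(hidden, state["fc2.weight"], state["fc2.bias"])[0])


def load_state(path: str = "xor_net.json") -> Dict[str, list]:
    """Load the JSON written by the training script (keys like "fc1.weight")."""
    with open(path) as f:
        return json.load(f)


# Usage (assumes xor_net.json is in the working directory):
# state = load_state()
# for x in ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]):
#     print(f"Input: {x} Output: {xor_forward(state, x):.6f}")
```

If the WebNN outputs disagree with these values, a likely cause is weight layout: feeding the row-major `fc1.weight` straight into `matmul(input, W1)` computes x·W rather than x·Wᵀ.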
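Alternatively, the transposition can be done once at export time, so the JavaScript side only copies flat arrays into `builder.constant`. This is a minimal sketch, assuming the nested-list JSON written by the training script; the `to_webnn_layout` helper and its `{"shape", "data"}` record format are hypothetical, not a WebNN or PyTorch format.

```python
from typing import Dict, List


def transpose(matrix: List[List[float]]) -> List[List[float]]:
    """Swap rows and columns of a nested-list matrix."""
    return [list(col) for col in zip(*matrix)]


def to_webnn_layout(state: Dict[str, list]) -> Dict[str, dict]:
    """Hypothetical helper: re-pack a PyTorch-style state dict (nested lists) into
    flat row-major arrays plus shapes for a graph computing matmul(input, W) + b."""
    packed = {}
    for name, value in state.items():
        if name.endswith(".weight"):
            # nn.Linear stores [out_features, in_features]; matmul(input, W) needs [in, out].
            w_t = transpose(value)
            packed[name] = {
                "shape": [len(w_t), len(w_t[0])],
                "data": [v for row in w_t for v in row],
            }
        else:
            # Biases become [1, out_features] so they broadcast over a [1, out] result.
            packed[name] = {"shape": [1, len(value)], "data": list(value)}
    return packed
```

Applied to the dictionary loaded from `xor_net.json`, each entry's `shape` and `data` could be passed to `builder.constant({ dataType: 'float32', shape }, new Float32Array(data))`, and both layers would then use a plain `matmul(input, W)` with no `transpose` in the graph.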