()
| 39 | constructor() {} |
| 40 | |
| 41 | async load() { |
| 42 | // Make a request for the MNIST sprited image. |
| 43 | const img = new Image(); |
| 44 | const canvas = document.createElement('canvas'); |
| 45 | const ctx = canvas.getContext('2d'); |
| 46 | const imgRequest = new Promise((resolve, reject) => { |
| 47 | img.crossOrigin = ''; |
| 48 | img.onload = () => { |
| 49 | img.width = img.naturalWidth; |
| 50 | img.height = img.naturalHeight; |
| 51 | |
| 52 | const datasetBytesBuffer = |
| 53 | new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); |
| 54 | |
| 55 | const chunkSize = 5000; |
| 56 | canvas.width = img.width; |
| 57 | canvas.height = chunkSize; |
| 58 | |
| 59 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { |
| 60 | const datasetBytesView = new Float32Array( |
| 61 | datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, |
| 62 | IMAGE_SIZE * chunkSize); |
| 63 | ctx.drawImage( |
| 64 | img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, |
| 65 | chunkSize); |
| 66 | |
| 67 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); |
| 68 | |
| 69 | for (let j = 0; j < imageData.data.length / 4; j++) { |
| 70 | // All channels hold an equal value since the image is grayscale, so |
| 71 | // just read the red channel. |
| 72 | datasetBytesView[j] = imageData.data[j * 4] / 255; |
| 73 | } |
| 74 | } |
| 75 | this.datasetImages = new Float32Array(datasetBytesBuffer); |
| 76 | |
| 77 | resolve(); |
| 78 | }; |
| 79 | img.src = MNIST_IMAGES_SPRITE_PATH; |
| 80 | }); |
| 81 | |
| 82 | const labelsRequest = fetch(MNIST_LABELS_PATH); |
| 83 | const [imgResponse, labelsResponse] = |
| 84 | await Promise.all([imgRequest, labelsRequest]); |
| 85 | |
| 86 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); |
| 87 | |
| 88 | // Slice the the images and labels into train and test sets. |
| 89 | this.trainImages = |
| 90 | this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); |
| 91 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); |
| 92 | this.trainLabels = |
| 93 | this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); |
| 94 | this.testLabels = |
| 95 | this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); |
| 96 | } |
| 97 | |
| 98 | /** |
no outgoing calls