()
| 42 | } |
| 43 | |
| 44 | async load() { |
| 45 | // Make a request for the MNIST sprited image. |
| 46 | const img = new Image(); |
| 47 | const canvas = document.createElement('canvas'); |
| 48 | const ctx = canvas.getContext('2d'); |
| 49 | const imgRequest = new Promise((resolve, reject) => { |
| 50 | img.crossOrigin = ''; |
| 51 | img.onload = () => { |
| 52 | img.width = img.naturalWidth; |
| 53 | img.height = img.naturalHeight; |
| 54 | |
| 55 | const datasetBytesBuffer = |
| 56 | new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); |
| 57 | |
| 58 | const chunkSize = 5000; |
| 59 | canvas.width = img.width; |
| 60 | canvas.height = chunkSize; |
| 61 | |
| 62 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { |
| 63 | const datasetBytesView = new Float32Array( |
| 64 | datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, |
| 65 | IMAGE_SIZE * chunkSize); |
| 66 | ctx.drawImage( |
| 67 | img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, |
| 68 | chunkSize); |
| 69 | |
| 70 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); |
| 71 | |
| 72 | for (let j = 0; j < imageData.data.length / 4; j++) { |
| 73 | // All channels hold an equal value since the image is grayscale, so |
| 74 | // just read the red channel. |
| 75 | datasetBytesView[j] = imageData.data[j * 4] / 255; |
| 76 | } |
| 77 | } |
| 78 | this.datasetImages = new Float32Array(datasetBytesBuffer); |
| 79 | |
| 80 | resolve(); |
| 81 | }; |
| 82 | img.src = MNIST_IMAGES_SPRITE_PATH; |
| 83 | }); |
| 84 | |
| 85 | const labelsRequest = fetch(MNIST_LABELS_PATH); |
| 86 | const [imgResponse, labelsResponse] = |
| 87 | await Promise.all([imgRequest, labelsRequest]); |
| 88 | |
| 89 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); |
| 90 | |
| 91 | // Create shuffled indices into the train/test set for when we select a |
| 92 | // random dataset element for training / validation. |
| 93 | this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS); |
| 94 | this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS); |
| 95 | |
| 96 | // Slice the the images and labels into train and test sets. |
| 97 | this.trainImages = |
| 98 | this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); |
| 99 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); |
| 100 | this.trainLabels = |
| 101 | this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); |
no test coverage detected