MCPcopy
hub / github.com/tensorflow/tfjs / load

Method load

tfjs-vis/demos/mnist/data.js:44–104  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

42 }
43
44 async load() {
45 // Make a request for the MNIST sprited image.
46 const img = new Image();
47 const canvas = document.createElement('canvas');
48 const ctx = canvas.getContext('2d');
49 const imgRequest = new Promise((resolve, reject) => {
50 img.crossOrigin = '';
51 img.onload = () => {
52 img.width = img.naturalWidth;
53 img.height = img.naturalHeight;
54
55 const datasetBytesBuffer =
56 new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
57
58 const chunkSize = 5000;
59 canvas.width = img.width;
60 canvas.height = chunkSize;
61
62 for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
63 const datasetBytesView = new Float32Array(
64 datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
65 IMAGE_SIZE * chunkSize);
66 ctx.drawImage(
67 img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
68 chunkSize);
69
70 const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
71
72 for (let j = 0; j < imageData.data.length / 4; j++) {
73 // All channels hold an equal value since the image is grayscale, so
74 // just read the red channel.
75 datasetBytesView[j] = imageData.data[j * 4] / 255;
76 }
77 }
78 this.datasetImages = new Float32Array(datasetBytesBuffer);
79
80 resolve();
81 };
82 img.src = MNIST_IMAGES_SPRITE_PATH;
83 });
84
85 const labelsRequest = fetch(MNIST_LABELS_PATH);
86 const [imgResponse, labelsResponse] =
87 await Promise.all([imgRequest, labelsRequest]);
88
89 this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
90
91 // Create shuffled indices into the train/test set for when we select a
92 // random dataset element for training / validation.
93 this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
94 this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
95
96 // Slice the the images and labels into train and test sets.
97 this.trainImages =
98 this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
99 this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
100 this.trainLabels =
101 this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);

Callers (1)

loadData (function) · 0.95

Calls (5)

all (method) · 0.80
slice (method) · 0.65
fetch (function) · 0.50
getContext (method) · 0.45
getImageData (method) · 0.45

Tested by

no test coverage detected