MCPcopy
hub / github.com/tensorflow/tfjs-examples / getNextBatchFunction

Method getNextBatchFunction

jena-weather/data.js:272–342  ·  view source on GitHub ↗

Get a data iterator function. @param {boolean} shuffle Whether the data is to be shuffled. If `false`, the examples generated by repeated calls to the returned iterator function will scan through the range specified by `minIndex` and `maxIndex` (or the entire range of …

(
      shuffle, lookBack, delay, batchSize, step, minIndex, maxIndex, normalize,
      includeDateTime)

Source from the content-addressed store, hash-verified

270 * `[batchSize, 1]`.
271 */
272 getNextBatchFunction(
273 shuffle, lookBack, delay, batchSize, step, minIndex, maxIndex, normalize,
274 includeDateTime) {
275 let startIndex = minIndex + lookBack;
276 const lookBackSlices = Math.floor(lookBack / step);
277
278 return {
279 next: () => {
280 const rowIndices = [];
281 let done = false; // Indicates whether the dataset has ended.
282 if (shuffle) {
283 // If `shuffle` is `true`, start from randomly chosen rows.
284 const range = maxIndex - (minIndex + lookBack);
285 for (let i = 0; i < batchSize; ++i) {
286 const row = minIndex + lookBack + Math.floor(Math.random() * range);
287 rowIndices.push(row);
288 }
289 } else {
290 // If `shuffle` is `false`, the starting row indices will be sequential.
291 let r = startIndex;
292 for (; r < startIndex + batchSize && r < maxIndex; ++r) {
293 rowIndices.push(r);
294 }
295 if (r >= maxIndex) {
296 done = true;
297 }
298 }
299
300 const numExamples = rowIndices.length;
301 startIndex += numExamples;
302
303 const featureLength =
304 includeDateTime ? this.numColumns + 2 : this.numColumns;
305 const samples = tf.buffer([numExamples, lookBackSlices, featureLength]);
306 const targets = tf.buffer([numExamples, 1]);
307 // Iterate over examples. Each example contains a number of rows.
308 for (let j = 0; j < numExamples; ++j) {
309 const rowIndex = rowIndices[j];
310 let exampleRow = 0;
311 // Iterate over rows in the example.
312 for (let r = rowIndex - lookBack; r < rowIndex; r += step) {
313 let exampleCol = 0;
314 // Iterate over features in the row.
315 for (let n = 0; n < featureLength; ++n) {
316 let value;
317 if (n < this.numColumns) {
318 value = normalize ? this.normalizedData[r][n] : this.data[r][n];
319 } else if (n === this.numColumns) {
320 // Normalized day-of-the-year feature.
321 value = this.normalizedDayOfYear[r];
322 } else {
323 // Normalized time-of-the-day feature.
324 value = this.normalizedTimeOfDay[r];
325 }
326 samples.set(value, j, exampleRow, exampleCol++);
327 }
328
329 const value = normalize ?

Callers 3

trainModelFunction · 0.80
data_test.jsFile · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected