MCPcopy
hub / github.com/tensorflow/tfjs / next

Method next

tfjs-layers/src/engine/dataset_fakes.ts:115–185  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

113 }
114
115 async next(): Promise<IteratorResult<FitDatasetElement>> {
116 const done = ++this.batchCount > this.numBatches;
117 if (done) {
118 return {done, value: null};
119 }
120 if (this.xTensorsFunc == null) {
121 // Generate data randomly.
122 return {
123 done,
124 value: done ? null : {
125 xs: generateRandomTensorContainer(this.xBatchShape),
126 ys: generateRandomTensorContainer(this.yBatchShape)
127 }
128 };
129 } else {
130 // Use preset tensors.
131 if ((this.batchCount - 1) % this.numBatches === 0) {
132 this.xTensorValues = this.xTensorsFunc();
133 this.yTensorValues = this.yTensorsFunc();
134 this.tensorIndex = 0;
135 }
136 const index = this.tensorIndex++;
137
138 let xs: tfc.Tensor|{[name: string]: tfc.Tensor};
139 if (Array.isArray(this.xTensorValues)) {
140 xs = this.xTensorValues[index];
141 tfc.util.assert(
142 tfc.util.arraysEqual(xs.shape, this.xBatchShape as Shape),
143 () => `Shape mismatch: expected: ${
144 JSON.stringify(this.xBatchShape)}; ` +
145 `actual: ${JSON.stringify((xs as tfc.Tensor).shape)}`);
146 } else {
147 xs = {};
148 for (const key in this.xTensorValues) {
149 xs[key] = this.xTensorValues[key][index];
150 tfc.util.assert(
151 tfc.util.arraysEqual(xs[key].shape, this.xBatchShape as Shape),
152 () => `Shape mismatch: expected: ${
153 JSON.stringify(this.xBatchShape)}; ` +
154 `actual: ${JSON.stringify((xs as tfc.Tensor).shape)}`);
155 }
156 }
157
158 let ys: tfc.Tensor|{[name: string]: tfc.Tensor};
159 if (Array.isArray(this.yTensorValues)) {
160 // Get preset ys tensors for single-output models.
161 ys = this.yTensorValues[index];
162 tfc.util.assert(
163 tfc.util.arraysEqual(ys.shape, this.yBatchShape as Shape),
164 () => `Shape mismatch: expected: ${
165 JSON.stringify(this.yBatchShape)}; ` +
166 `actual: ${JSON.stringify((ys as tfc.Tensor).shape)}`);
167 } else {
168 // Get preset ys tensors for multi-output models.
169 ys = {};
170 this.yBatchShape = this.yBatchShape as {[name: string]: Shape};
171 for (const key in this.yTensorValues) {
172 ys[key] = this.yTensorValues[key][index];

Callers 4

nextMethod · 0.45
fitDatasetFunction · 0.45
evaluateDatasetFunction · 0.45

Calls 1

Tested by

no test coverage detected