MCPcopy
hub / github.com/tensorflow/tfjs-examples / trainDiscriminatorOneStep

Function trainDiscriminatorOneStep

mnist-acgan/gan.js:298–329  ·  view source on GitHub ↗

Train the discriminator for one step.

In this step, only the weights of the discriminator are updated; the generator is not involved.

The following steps are involved:

- Slice the training features to get a batch of real data.
- Generate a random latent-space vector and a random label vector …

(
    xTrain, yTrain, batchStart, batchSize, latentSize, generator,
    discriminator)

Source from the content-addressed store, hash-verified

/**
 * Train the discriminator for one step.
 *
 * Only the discriminator's weights are updated by this step; the generator is
 * used purely for inference, to produce the fake half of the batch.
 *
 * Steps:
 *   - Slice `batchSize` real images and labels out of `xTrain` / `yTrain`,
 *     starting at `batchStart` along the first axis.
 *   - Draw random latent vectors in [-1, 1] and random class labels, and feed
 *     them to the generator to obtain a batch of generated (fake) images.
 *   - Concatenate real and fake images along axis 0 (real first), together
 *     with matching real/fake targets (`SOFT_ONE` for real, 0 for fake) and
 *     auxiliary class labels (real labels, then the sampled labels).
 *   - Run one `trainOnBatch` step on the discriminator and return its losses.
 *
 * @param {tf.Tensor} xTrain Features of all training examples; sliced along
 *   the first axis.
 * @param {tf.Tensor} yTrain Labels of all training examples; sliced along the
 *   first axis and cast to float32.
 * @param {number} batchStart Index of the first real example in the batch.
 * @param {number} batchSize Number of real examples drawn; the discriminator
 *   is trained on 2 * batchSize examples (real followed by generated).
 * @param {number} latentSize Length of each latent-space vector.
 * @param {tf.LayersModel} generator Model called as
 *   `predict([latentVectors, labels])` to produce images.
 * @param {tf.LayersModel} discriminator Model trained against the targets
 *   `[realFakeTarget, classLabel]`.
 * @returns {number[]} The loss values from the one-step training as numbers.
 */
async function trainDiscriminatorOneStep(
    xTrain, yTrain, batchStart, batchSize, latentSize, generator,
    discriminator) {
  // TODO(cais): Remove tidy() once the current memory leak issue in tfjs-node
  // and tfjs-node-gpu is fixed.
  const [x, y, auxY] = tf.tidy(() => {
    const imageBatch = xTrain.slice(batchStart, batchSize);
    const labelBatch = yTrain.slice(batchStart, batchSize).asType('float32');

    // Latent vectors and random class labels for the generated half.
    const zVectors = tf.randomUniform([batchSize, latentSize], -1, 1);
    const sampledLabels =
        tf.randomUniform([batchSize, 1], 0, NUM_CLASSES, 'int32')
            .asType('float32');

    const generatedImages =
        generator.predict([zVectors, sampledLabels], {batchSize});

    // Real images first, then generated ones; both targets below follow the
    // same order.
    const x = tf.concat([imageBatch, generatedImages], 0);
    // Soft "real" target for the real half, 0 for the generated half. This
    // already runs inside the outer tidy(), so intermediates (ones/zeros/mul)
    // are cleaned up without a nested tidy().
    const y = tf.concat(
        [tf.ones([batchSize, 1]).mul(SOFT_ONE), tf.zeros([batchSize, 1])], 0);
    const auxY = tf.concat([labelBatch, sampledLabels], 0);
    return [x, y, auxY];
  });

  try {
    return await discriminator.trainOnBatch(x, [y, auxY]);
  } finally {
    // tidy() keeps the returned tensors alive, so they must be freed here --
    // including when trainOnBatch throws or rejects, which previously leaked
    // them.
    tf.dispose([x, y, auxY]);
  }
}
330
331/**
332 * Train the combined ACGAN for one step.

Callers 1

runFunction · 0.85

Calls 1

predictMethod · 0.45

Tested by

no test coverage detected