 * Train the discriminator for one step.
 *
 * In this step, only the weights of the discriminator are updated. The
 * generator is not involved.
 *
 * The following steps are involved:
 *
 * - Slice the training features and labels to get a batch of real data.
 * - Generate a random latent-space vector a
(
xTrain, yTrain, batchStart, batchSize, latentSize, generator,
discriminator)
| 296 | * @returns {number[]} The loss values from the one-step training as numbers. |
| 297 | */ |
async function trainDiscriminatorOneStep(
    xTrain, yTrain, batchStart, batchSize, latentSize, generator,
    discriminator) {
  // Build the discriminator's inputs and targets inside tidy() so that all
  // intermediate tensors are released; only the three returned tensors
  // survive and must be disposed explicitly below.
  // TODO(cais): Remove tidy() once the current memory leak issue in tfjs-node
  // and tfjs-node-gpu is fixed.
  const [x, y, auxY] = tf.tidy(() => {
    // Real data: `batchSize` rows starting at `batchStart` (Tensor.slice takes
    // a begin index and a size, not begin/end).
    const imageBatch = xTrain.slice(batchStart, batchSize);
    const labelBatch = yTrain.slice(batchStart, batchSize).asType('float32');

    // Latent vectors and random class labels that drive the generator.
    const zVectors = tf.randomUniform([batchSize, latentSize], -1, 1);
    const sampledLabels =
        tf.randomUniform([batchSize, 1], 0, NUM_CLASSES, 'int32')
            .asType('float32');

    // Generated (fake) images for this batch.
    const generatedImages =
        generator.predict([zVectors, sampledLabels], {batchSize});

    // Real images first, then generated ones; both target tensors below use
    // the same real-then-fake ordering.
    const x = tf.concat([imageBatch, generatedImages], 0);

    // Real/fake targets: SOFT_ONE for the real rows, 0 for the generated rows.
    const y = tf.concat(
        [tf.ones([batchSize, 1]).mul(SOFT_ONE), tf.zeros([batchSize, 1])], 0);

    // Auxiliary class-label targets, real labels then sampled labels.
    const auxY = tf.concat([labelBatch, sampledLabels], 0);
    return [x, y, auxY];
  });

  try {
    return await discriminator.trainOnBatch(x, [y, auxY]);
  } finally {
    // Release the tensors that escaped tidy() even when training rejects,
    // so a failed step does not leak them.
    tf.dispose([x, y, auxY]);
  }
}
| 330 | |
| 331 | /** |
| 332 | * Train the combined ACGAN for one step. |