Generate and visualize samples
(model,
dataset,
epoch,
params,
folder,
module=None,
writer=None,
exps='')
| 350 | |
| 351 | |
| 352 | def viz_epoch(model, |
| 353 | dataset, |
| 354 | epoch, |
| 355 | params, |
| 356 | folder, |
| 357 | module=None, |
| 358 | writer=None, |
| 359 | exps=''): |
| 360 | """ Generate & viz samples """ |
| 361 | module = model if module is None else module |
| 362 | |
| 363 | # visualize with joints3D |
| 364 | model.outputxyz = True |
| 365 | |
| 366 | print(f"Visualization of the epoch {epoch}") |
| 367 | |
| 368 | noise_same_action = params["noise_same_action"] |
| 369 | noise_diff_action = params["noise_diff_action"] |
| 370 | duration_mode = params["duration_mode"] |
| 371 | reconstruction_mode = params["reconstruction_mode"] |
| 372 | decoder_test = params["decoder_test"] |
| 373 | |
| 374 | fact = params["fact_latent"] |
| 375 | figname = params["figname"].format(epoch) |
| 376 | |
| 377 | nspa = params["num_samples_per_action"] |
| 378 | nats = params["num_actions_to_sample"] |
| 379 | |
| 380 | num_classes = params["num_classes"] |
| 381 | # nats = min(num_classes, nats) |
| 382 | |
| 383 | # define some classes |
| 384 | classes = torch.randperm(num_classes)[:nats] |
| 385 | # duplicate same classes when sampling too much |
| 386 | if nats > num_classes: |
| 387 | classes = classes.expand(nats) |
| 388 | |
| 389 | meandurations = torch.from_numpy( |
| 390 | np.array([ |
| 391 | round(dataset.get_mean_length_label(cl.item())) for cl in classes |
| 392 | ])) |
| 393 | |
| 394 | if duration_mode == "interpolate" or decoder_test == "diffduration": |
| 395 | points, step = np.linspace(-nspa, nspa, nspa, retstep=True) |
| 396 | # points = np.round(10*points/step).astype(int) |
| 397 | points = np.array([5, 10, 16, 30, 60, 80]).astype(int) |
| 398 | # gendurations = meandurations.repeat((nspa, 1)) + points[:, None] |
| 399 | gendurations = torch.from_numpy(points[:, None]).expand( |
| 400 | (nspa, 1)).repeat((1, nats)) |
| 401 | else: |
| 402 | gendurations = meandurations.repeat((nspa, 1)) |
| 403 | print("Duration time: ") |
| 404 | print(gendurations[:, 0]) |
| 405 | |
| 406 | # extract the real samples |
| 407 | # real_samples, real_theta, mask_real, real_lengths, imgs, paths |
| 408 | batch = dataset.get_label_sample_batch(classes.numpy()) |
| 409 |
Nothing calls this function directly.
No test coverage was detected for it.