MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / viz_epoch

Function viz_epoch

mGPT/render/visualize.py:352–616  ·  view source on GitHub ↗

Generate & viz samples

(model,
              dataset,
              epoch,
              params,
              folder,
              module=None,
              writer=None,
              exps='')

Source from the content-addressed store, hash-verified

350
351
352def viz_epoch(model,
353 dataset,
354 epoch,
355 params,
356 folder,
357 module=None,
358 writer=None,
359 exps=''):
360 """ Generate & viz samples """
361 module = model if module is None else module
362
363 # visualize with joints3D
364 model.outputxyz = True
365
366 print(f"Visualization of the epoch {epoch}")
367
368 noise_same_action = params["noise_same_action"]
369 noise_diff_action = params["noise_diff_action"]
370 duration_mode = params["duration_mode"]
371 reconstruction_mode = params["reconstruction_mode"]
372 decoder_test = params["decoder_test"]
373
374 fact = params["fact_latent"]
375 figname = params["figname"].format(epoch)
376
377 nspa = params["num_samples_per_action"]
378 nats = params["num_actions_to_sample"]
379
380 num_classes = params["num_classes"]
381 # nats = min(num_classes, nats)
382
383 # define some classes
384 classes = torch.randperm(num_classes)[:nats]
385 # duplicate same classes when sampling too much
386 if nats > num_classes:
387 classes = classes.expand(nats)
388
389 meandurations = torch.from_numpy(
390 np.array([
391 round(dataset.get_mean_length_label(cl.item())) for cl in classes
392 ]))
393
394 if duration_mode == "interpolate" or decoder_test == "diffduration":
395 points, step = np.linspace(-nspa, nspa, nspa, retstep=True)
396 # points = np.round(10*points/step).astype(int)
397 points = np.array([5, 10, 16, 30, 60, 80]).astype(int)
398 # gendurations = meandurations.repeat((nspa, 1)) + points[:, None]
399 gendurations = torch.from_numpy(points[:, None]).expand(
400 (nspa, 1)).repeat((1, nats))
401 else:
402 gendurations = meandurations.repeat((nspa, 1))
403 print("Duration time: ")
404 print(gendurations[:, 0])
405
406 # extract the real samples
407 # real_samples, real_theta, mask_real, real_lengths, imgs, paths
408 batch = dataset.get_label_sample_batch(classes.numpy())
409

Callers

nothing calls this directly

Calls 5

generate_by_video (function) · 0.85
keys (method) · 0.80
items (method) · 0.80
save (method) · 0.80
to (method) · 0.45

Tested by

no test coverage detected