Generate & viz samples
(dataset, params, folder)
| 617 | |
| 618 | |
| 619 | def viz_dataset(dataset, params, folder): |
| 620 | """ Generate & viz samples """ |
| 621 | print("Visualization of the dataset") |
| 622 | |
| 623 | nspa = params["num_samples_per_action"] |
| 624 | nats = params["num_actions_to_sample"] |
| 625 | |
| 626 | num_classes = params["num_classes"] |
| 627 | |
| 628 | figname = "{}_{}_numframes_{}_sampling_{}_step_{}".format( |
| 629 | params["dataset"], params["pose_rep"], params["num_frames"], |
| 630 | params["sampling"], params["sampling_step"]) |
| 631 | |
| 632 | # define some classes |
| 633 | classes = torch.randperm(num_classes)[:nats] |
| 634 | |
| 635 | allclasses = classes.repeat(nspa, 1).reshape(nspa * nats) |
| 636 | # extract the real samples |
| 637 | real_samples, mask_real, real_lengths = dataset.get_label_sample_batch( |
| 638 | allclasses.numpy()) |
| 639 | # to visualize directly |
| 640 | |
| 641 | # Visualizaion of real samples |
| 642 | visualization = { |
| 643 | "x": real_samples, |
| 644 | "y": allclasses, |
| 645 | "mask": mask_real, |
| 646 | 'lengths': real_lengths, |
| 647 | "output": real_samples |
| 648 | } |
| 649 | |
| 650 | from mGPT.models.rotation2xyz import Rotation2xyz |
| 651 | |
| 652 | device = params["device"] |
| 653 | rot2xyz = Rotation2xyz(device=device) |
| 654 | |
| 655 | rot2xyz_params = { |
| 656 | "pose_rep": params["pose_rep"], |
| 657 | "glob_rot": params["glob_rot"], |
| 658 | "glob": params["glob"], |
| 659 | "jointstype": params["jointstype"], |
| 660 | "translation": params["translation"] |
| 661 | } |
| 662 | |
| 663 | output = visualization["output"] |
| 664 | visualization["output_xyz"] = rot2xyz(output.to(device), |
| 665 | visualization["mask"].to(device), |
| 666 | **rot2xyz_params) |
| 667 | |
| 668 | for key, val in visualization.items(): |
| 669 | if len(visualization[key].shape) == 1: |
| 670 | visualization[key] = val.reshape(nspa, nats) |
| 671 | else: |
| 672 | visualization[key] = val.reshape(nspa, nats, *val.shape[1:]) |
| 673 | |
| 674 | finalpath = os.path.join(folder, figname + ".gif") |
| 675 | tmp_path = os.path.join(folder, f"subfigures_{figname}") |
| 676 | os.makedirs(tmp_path, exist_ok=True) |
nothing calls this directly
no test coverage detected