()
| 46 | |
| 47 | @torch.no_grad() |
| 48 | def main(): |
| 49 | cprint('output: ', color='red', attrs=['bold'], end='') |
| 50 | log(cfg.octree_path, 'yellow') |
| 51 | |
| 52 | log('making network') |
| 53 | network: Network = make_network(cfg) |
| 54 | network = network.cuda() |
| 55 | network = network.eval() |
| 56 | epoch = load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) # load network into `network` |
| 57 | |
| 58 | log('getting dataset sample') |
| 59 | dataset: Dataset = make_dataset(cfg, is_train=False) |
| 60 | batch = to_tensor(dataset[cfg.octree.latent_index * dataset.num_cams]) # get a sample dataset, no batch dimension |
| 61 | tbounds = batch.tbounds.cuda() # 2, 3 |
| 62 | tverts = batch.tverts.cuda() # 6890, 3 |
| 63 | center = (tbounds[1] + tbounds[0]) / 2 # 3, |
| 64 | radius = (tbounds[1] - tbounds[0]) / 2 |
| 65 | # if cfg.octree.preserve_aspect_ratio: |
| 66 | # radius[:] = radius.max() # remove fancy scaling and leave samples where it's important |
| 67 | |
| 68 | log('creating octree structure') |
| 69 | sh_dim = (cfg.octree.sh_deg + 1) ** 2 |
| 70 | data_dim = cfg.octree.channel * sh_dim + 1 |
| 71 | data_format = f'SH{sh_dim}' |
| 72 | init_reserve = (2 ** cfg.octree.grid_depth) ** 2 # initially only reser on slab |
| 73 | octree = N3Tree(N=cfg.octree.branch_factor, |
| 74 | data_dim=data_dim, |
| 75 | init_reserve=init_reserve, |
| 76 | depth_limit=cfg.octree.grid_depth, |
| 77 | radius=radius, |
| 78 | center=center, |
| 79 | data_format=data_format, |
| 80 | device='cuda') |
| 81 | |
| 82 | log("performing grid evaluation on sigma and dists") |
| 83 | # create grid on cpu (too large to fit in VRAM) |
| 84 | reso: torch.Tensor = 2 ** (cfg.octree.grid_depth + 1) |
| 85 | offset = octree.offset.cpu() |
| 86 | scale = octree.invradius.cpu() |
| 87 | arr = (torch.arange(0, reso, dtype=torch.float32, device='cpu') + 0.5) / reso # reso grid |
| 88 | xx = (arr - offset[0]) / scale[0] |
| 89 | yy = (arr - offset[1]) / scale[1] |
| 90 | zz = (arr - offset[2]) / scale[2] |
| 91 | # the grid might be too large to fit in the graphical memory |
| 92 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T # 3, reso, reso, reso -> reso**3, 3 |
| 93 | log(f"initial grid shape: {grid.shape}") |
| 94 | |
| 95 | chunk_size = cfg.octree.chunk_size |
| 96 | # filter by distance to smpl |
| 97 | dists = [] |
| 98 | for i in tqdm(range(0, grid.shape[0], chunk_size)): |
| 99 | grid_chunk = grid[i:i+chunk_size].cuda() |
| 100 | dists_chunk = sample_K_closest_points(grid_chunk[None], tverts[None], K=1)[0][0, :, 0] # add and remove fake batch dimension |
| 101 | dists.append(dists_chunk) # this value will be later discarded |
| 102 | dists = torch.cat(dists, dim=0) # reso ** 3, |
| 103 | mask = dists < cfg.dist_th |
| 104 | grid = grid[mask] # first filter by smpl dists, might still be large |
| 105 |
no test coverage detected