Repository: github.com/zju3dv/4K4D (branch main)

Function main

scripts/plenoctree/convert_nerf_to_plenoctree.py:48–180
Signature: main() (no parameters)

Source from the content-addressed store, hash-verified

```python
@torch.no_grad()
def main():
    cprint('output: ', color='red', attrs=['bold'], end='')
    log(cfg.octree_path, 'yellow')

    log('making network')
    network: Network = make_network(cfg)
    network = network.cuda()
    network = network.eval()
    epoch = load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)  # load weights into `network`
```
```python
    log('getting dataset sample')
    dataset: Dataset = make_dataset(cfg, is_train=False)
    batch = to_tensor(dataset[cfg.octree.latent_index * dataset.num_cams])  # one dataset sample, no batch dimension
    tbounds = batch.tbounds.cuda()  # 2, 3: (min, max) corners
    tverts = batch.tverts.cuda()  # 6890, 3: SMPL vertices
    center = (tbounds[1] + tbounds[0]) / 2  # 3,
    radius = (tbounds[1] - tbounds[0]) / 2  # 3, per-axis half-extent
    # if cfg.octree.preserve_aspect_ratio:
    #     radius[:] = radius.max()  # remove fancy scaling and leave samples where it's important
```
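The bounds step reduces a 2 x 3 (min, max) box to a per-axis center and half-extent. The standalone sketch below, in plain Python so it runs without a GPU, shows that arithmetic and what the commented-out preserve_aspect_ratio branch would do. It assumes the box is the tight box around the vertices; how the dataset actually builds tbounds (for example, any padding) is not shown in this listing, and both helper names are hypothetical.

```python
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]


def bounds_from_vertices(verts: Sequence[Vec3]) -> Tuple[Vec3, Vec3]:
    """Tight axis-aligned (min, max) corners of a point set (hypothetical helper)."""
    lo = tuple(min(v[a] for v in verts) for a in range(3))
    hi = tuple(max(v[a] for v in verts) for a in range(3))
    return lo, hi


def center_and_radius(lo: Vec3, hi: Vec3, isotropic: bool = False) -> Tuple[Vec3, Vec3]:
    """Per-axis center and half-extent, mirroring the listing's center/radius lines."""
    center = tuple((h + l) / 2 for l, h in zip(lo, hi))
    radius = tuple((h - l) / 2 for l, h in zip(lo, hi))
    if isotropic:
        # same effect as the commented-out `radius[:] = radius.max()`:
        # a cube that keeps the aspect ratio but spends samples on empty space
        r = max(radius)
        radius = (r, r, r)
    return center, radius
```

For the points (0, 0, 0), (2, 1, 4) and (1, -1, 2), the per-axis radius is (1.0, 1.0, 2.0), while the isotropic variant inflates every axis to 2.0.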
```python
    log('creating octree structure')
    sh_dim = (cfg.octree.sh_deg + 1) ** 2  # SH coefficients per color channel
    data_dim = cfg.octree.channel * sh_dim + 1  # SH coefficients for every channel, plus one density value
    data_format = f'SH{sh_dim}'
    init_reserve = (2 ** cfg.octree.grid_depth) ** 2  # initially reserve only one slab
    octree = N3Tree(N=cfg.octree.branch_factor,
                    data_dim=data_dim,
                    init_reserve=init_reserve,
                    depth_limit=cfg.octree.grid_depth,
                    radius=radius,
                    center=center,
                    data_format=data_format,
                    device='cuda')
```
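The sizing arithmetic follows from real spherical harmonics: band l has 2l + 1 basis functions, so degrees 0 through sh_deg give (sh_deg + 1)^2 coefficients per color channel, and the extra 1 in data_dim is read here as the per-leaf density value. A minimal sketch with hypothetical config values (sh_deg=2, channel=3, grid_depth=9), not taken from the project's configs:

```python
def sh_coeffs_per_channel(sh_deg: int) -> int:
    # band l of real spherical harmonics contributes 2l + 1 basis functions
    return sum(2 * l + 1 for l in range(sh_deg + 1))


def octree_layout(sh_deg: int, channel: int, grid_depth: int):
    """Recompute the listing's sizing values from plain integers (hypothetical helper)."""
    sh_dim = (sh_deg + 1) ** 2           # closed form of sh_coeffs_per_channel
    data_dim = channel * sh_dim + 1      # SH per channel + 1 density value (assumed layout)
    data_format = f"SH{sh_dim}"
    init_reserve = (2 ** grid_depth) ** 2  # cells in one 2D layer ("slab") at resolution 2 ** grid_depth
    return sh_dim, data_dim, data_format, init_reserve


if __name__ == "__main__":
    for deg in range(5):
        assert sh_coeffs_per_channel(deg) == (deg + 1) ** 2
    print(octree_layout(sh_deg=2, channel=3, grid_depth=9))  # prints (9, 28, 'SH9', 262144)
```

With these values each leaf stores 28 floats: 9 SH coefficients for each of 3 color channels plus one density.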
```python
    log("performing grid evaluation on sigma and dists")
    # create the grid on cpu (too large to fit in VRAM)
    reso: int = 2 ** (cfg.octree.grid_depth + 1)  # samples per axis
    offset = octree.offset.cpu()
    scale = octree.invradius.cpu()
    arr = (torch.arange(0, reso, dtype=torch.float32, device='cpu') + 0.5) / reso  # cell centers in normalized [0, 1] coordinates
    xx = (arr - offset[0]) / scale[0]  # map normalized coordinates back to world space
    yy = (arr - offset[1]) / scale[1]
    zz = (arr - offset[2]) / scale[2]
    # the grid might be too large to fit in graphics memory
    grid = torch.stack(torch.meshgrid(xx, yy, zz, indexing='ij')).reshape(3, -1).T  # 3, reso, reso, reso -> reso**3, 3
    log(f"initial grid shape: {grid.shape}")
```
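The grid step samples cell centers inside the tree's normalized cube and maps them back to world space with octree.offset and octree.invradius. The sketch below assumes an svox-style convention, u = x * invradius + offset with invradius = 0.5 / radius and offset = 0.5 * (1 - center / radius), which sends [center - radius, center + radius] to [0, 1]; check N3Tree for the exact definition. It also estimates the float32 size of the full coordinate array, which motivates building it on the CPU. Both function names are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def cell_centers_world(reso: int, center: Sequence[float], radius: Sequence[float]) -> List[Point]:
    # assumed svox-style normalization: u = x * invradius + offset
    invradius = [0.5 / r for r in radius]
    offset = [0.5 * (1.0 - c / r) for c, r in zip(center, radius)]
    # cell centers of a reso-per-axis grid inside the unit cube
    arr = [(i + 0.5) / reso for i in range(reso)]
    # invert the normalization per axis: x = (u - offset) / invradius
    axes = [[(u - offset[a]) / invradius[a] for u in arr] for a in range(3)]
    # product() varies the last axis fastest, matching meshgrid(indexing='ij').reshape(3, -1).T
    return list(product(*axes))


def grid_nbytes(grid_depth: int) -> int:
    reso = 2 ** (grid_depth + 1)
    return reso ** 3 * 3 * 4  # reso**3 points, 3 float32 coordinates of 4 bytes each
```

With a hypothetical grid_depth of 9, reso is 1024 and the coordinate array alone takes 12 GiB, before any distances are computed.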
```python
    chunk_size = cfg.octree.chunk_size
    # filter by distance to the SMPL vertices
    dists = []
    for i in tqdm(range(0, grid.shape[0], chunk_size)):
        grid_chunk = grid[i:i+chunk_size].cuda()
        dists_chunk = sample_K_closest_points(grid_chunk[None], tverts[None], K=1)[0][0, :, 0]  # add and remove fake batch dimension
        dists.append(dists_chunk.cpu())  # keep on cpu: `grid` lives on cpu, and the full distance array is discarded after masking
    dists = torch.cat(dists, dim=0)  # reso ** 3,
    mask = dists < cfg.dist_th
    grid = grid[mask]  # first filter by SMPL distance; the result might still be large
```
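The filtering loop is a chunked nearest-neighbour query: only chunk_size points sit on the GPU at a time, and each keeps its distance to the closest template vertex. The pure-Python sketch below swaps in a brute-force K=1 search for the project's sample_K_closest_points, whose internals are not shown here. In particular, whether that helper returns squared or plain Euclidean distances is not visible in this listing, so the scale of cfg.dist_th may differ from the plain distance used below. The helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def nearest_dists_chunked(grid: Sequence[Point], verts: Sequence[Point], chunk_size: int) -> List[float]:
    dists: List[float] = []
    for start in range(0, len(grid), chunk_size):
        chunk = grid[start:start + chunk_size]  # in the listing, only this slice is moved to the GPU
        # brute-force K=1 search: plain Euclidean distance to the closest vertex
        dists.extend(min(math.dist(p, v) for v in verts) for p in chunk)
    return dists


def filter_near_template(grid: Sequence[Point], verts: Sequence[Point],
                         dist_th: float, chunk_size: int = 4096) -> List[Point]:
    dists = nearest_dists_chunked(grid, verts, chunk_size)
    # same role as `mask = dists < cfg.dist_th; grid = grid[mask]`
    return [p for p, d in zip(grid, dists) if d < dist_th]
```

Chunking bounds peak memory by chunk_size rather than by the full grid; the brute-force inner loop is only for illustration and would be far too slow for a real 6890-vertex template at this grid size.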

Callers 1

Calls 12

log · Function · 0.90
load_network · Function · 0.90
to_tensor · Function · 0.90
sample_K_closest_points · Function · 0.90
add_batch · Function · 0.90
to_cuda · Function · 0.90
project_nerf_to_sh · Function · 0.85
cuda · Method · 0.80
cpu · Method · 0.80
save · Method · 0.80
occ · Method · 0.45
sample · Method · 0.45

Tested by

no test coverage detected