(self, decay=0.95, S=128)
| 200 | |
| 201 | @torch.no_grad() |
| 202 | def update_extra_state(self, decay=0.95, S=128): |
| 203 | # call before each epoch to update extra states. |
| 204 | if not self.cuda_ray: |
| 205 | return |
| 206 | # use random cond (different expressions should have similar density grid...) |
| 207 | rand_idx = random.randint(0, self.conds.shape[0] - 1) |
| 208 | cond = get_audio_features(self.conds, 2, rand_idx).to(self.density_bitfield.device) |
| 209 | if hparams.get("to_heatmap", False): |
| 210 | from tasks.radnerfs.dataset_utils import transform_normed_lm_to_pixel_lm, convert_to_tensor, get_ldm_heatmap |
| 211 | inds = torch.arange(512*512) |
| 212 | lm478 = cond.reshape([hparams['smo_win_size'], 478, 3]) |
| 213 | lm478_px = convert_to_tensor(transform_normed_lm_to_pixel_lm(lm478)).reshape([hparams['smo_win_size'], 1, 478, 3]) |
| 214 | heatmaps = [] |
| 215 | for i in range(hparams['smo_win_size']): |
| 216 | lm = lm478_px[i:i+1] |
| 217 | # heatmap = get_ldm_heatmap([512,512], lm.cuda()).cpu() # [1, 1, W, H, C=478] |
| 218 | heatmap = get_ldm_heatmap([512,512], lm.cuda()) # [1, 1, W, H, C=478] |
| 219 | heatmaps.append(heatmap.reshape([1, 512, 512, 478])) |
| 220 | heatmaps = torch.stack(heatmaps) # [5,1,W,H,C] |
| 221 | heatmaps = heatmaps.reshape([hparams['smo_win_size'], 1, 512*512, 478]) |
| 222 | cond = heatmaps # [:, :, :, :] |
| 223 | enc_a = self.cal_cond_feat(cond.to(self.density_bitfield.device)) |
| 224 | del cond |
| 225 | torch.cuda.empty_cache() |
| 226 | # encode audio |
| 227 | else: |
| 228 | enc_a = self.cal_cond_feat(cond) |
| 229 | |
| 230 | ### update density grid |
| 231 | tmp_grid = torch.zeros_like(self.density_grid) |
| 232 | |
| 233 | # full update |
| 234 | X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) |
| 235 | Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) |
| 236 | Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) |
| 237 | |
| 238 | for xs in X: |
| 239 | for ys in Y: |
| 240 | for zs in Z: |
| 241 | |
| 242 | # construct points |
| 243 | xx, yy, zz = custom_meshgrid(xs, ys, zs) |
| 244 | coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) |
| 245 | indices = raymarching.morton3D(coords).long() # [N] |
| 246 | xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] |
| 247 | |
| 248 | # cascading |
| 249 | if hparams.get("to_heatmap", False) is True: |
| 250 | enc_a_ = enc_a[torch.randint(low=0, high=len(enc_a), size=(len(xyzs),))] |
| 251 | |
| 252 | for cas in range(self.cascade): |
| 253 | bound = min(2 ** cas, self.bound) |
| 254 | half_grid_size = bound / self.grid_size |
| 255 | # scale to current cascade's resolution |
| 256 | cas_xyzs = xyzs * (bound - half_grid_size) |
| 257 | # add noise in [-hgs, hgs] |
| 258 | cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size |
| 259 | # query density |
no test coverage detected