MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / update_extra_state

Method update_extra_state

modules/radnerfs/renderer.py:202–284  ·  view source on GitHub ↗
(self, decay=0.95, S=128)

Source from the content-addressed store, hash-verified

200
201 @torch.no_grad()
202 def update_extra_state(self, decay=0.95, S=128):
203 # call before each epoch to update extra states.
204 if not self.cuda_ray:
205 return
206 # use random cond (different expressions should have similar density grid...)
207 rand_idx = random.randint(0, self.conds.shape[0] - 1)
208 cond = get_audio_features(self.conds, 2, rand_idx).to(self.density_bitfield.device)
209 if hparams.get("to_heatmap", False):
210 from tasks.radnerfs.dataset_utils import transform_normed_lm_to_pixel_lm, convert_to_tensor, get_ldm_heatmap
211 inds = torch.arange(512*512)
212 lm478 = cond.reshape([hparams['smo_win_size'], 478, 3])
213 lm478_px = convert_to_tensor(transform_normed_lm_to_pixel_lm(lm478)).reshape([hparams['smo_win_size'], 1, 478, 3])
214 heatmaps = []
215 for i in range(hparams['smo_win_size']):
216 lm = lm478_px[i:i+1]
217 # heatmap = get_ldm_heatmap([512,512], lm.cuda()).cpu() # [1, 1, W, H, C=478]
218 heatmap = get_ldm_heatmap([512,512], lm.cuda()) # [1, 1, W, H, C=478]
219 heatmaps.append(heatmap.reshape([1, 512, 512, 478]))
220 heatmaps = torch.stack(heatmaps) # [5,1,W,H,C]
221 heatmaps = heatmaps.reshape([hparams['smo_win_size'], 1, 512*512, 478])
222 cond = heatmaps # [:, :, :, :]
223 enc_a = self.cal_cond_feat(cond.to(self.density_bitfield.device))
224 del cond
225 torch.cuda.empty_cache()
226 # encode audio
227 else:
228 enc_a = self.cal_cond_feat(cond)
229
230 ### update density grid
231 tmp_grid = torch.zeros_like(self.density_grid)
232
233 # full update
234 X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
235 Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
236 Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
237
238 for xs in X:
239 for ys in Y:
240 for zs in Z:
241
242 # construct points
243 xx, yy, zz = custom_meshgrid(xs, ys, zs)
244 coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
245 indices = raymarching.morton3D(coords).long() # [N]
246 xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
247
248 # cascading
249 if hparams.get("to_heatmap", False) is True:
250 enc_a_ = enc_a[torch.randint(low=0, high=len(enc_a), size=(len(xyzs),))]
251
252 for cas in range(self.cascade):
253 bound = min(2 ** cas, self.bound)
254 half_grid_size = bound / self.grid_size
255 # scale to current cascade's resolution
256 cas_xyzs = xyzs * (bound - half_grid_size)
257 # add noise in [-hgs, hgs]
258 cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
259 # query density

Callers 4

_training_step · Method · 0.45
_training_step · Method · 0.45
_training_step · Method · 0.45
_training_step · Method · 0.45

Calls 10

cal_cond_feat · Method · 0.95
density · Method · 0.95
get_audio_features · Function · 0.90
get_ldm_heatmap · Function · 0.90
custom_meshgrid · Function · 0.90
convert_to_tensor · Function · 0.85
append · Method · 0.80
mean · Method · 0.80
to · Method · 0.45

Tested by

no test coverage detected