hub / github.com/TencentARC/Pixal3D / forward

Method forward

pixal3d/models/sc_vaes/sparse_unet_vae.py:373–395
(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False)

Source from the content-addressed store, hash-verified

        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
        h = self.input_layer(x)
        h = h.type(self.dtype)
        for i, res in enumerate(self.blocks):
            for j, block in enumerate(res):
                h = block(h)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.to_latent(h)

        # Sample from the posterior distribution
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z


class SparseUnetVaeDecoder(nn.Module):
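
The posterior step at the end of `forward` can be illustrated without any sparse-tensor machinery. The following is a minimal sketch, not the repository's code: it works on one plain Python list standing in for a single feature row, and the names `split_posterior` and `encode_row` are hypothetical. It assumes an even feature width, with the first half holding the mean and the second half the log-variance, as the `chunk(2, dim=-1)` call in the source suggests.

```python
import math
import random


def split_posterior(row):
    """Split one feature row into (mean, logvar) halves, mirroring chunk(2, dim=-1)."""
    if len(row) % 2 != 0:
        # Assumption of this sketch: the width is even, so both halves match.
        raise ValueError("feature width must be even")
    half = len(row) // 2
    return row[:half], row[half:]


def encode_row(row, sample_posterior=False, rng=random):
    """Return the latent for one row: the mean, or a reparameterized sample."""
    mean, logvar = split_posterior(row)
    if not sample_posterior:
        return list(mean)
    # std = exp(0.5 * logvar); z = mean + std * eps, with eps drawn from N(0, 1)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, logvar)]


row = [0.5, -1.0, 0.0, math.log(4.0)]  # mean = [0.5, -1.0], logvar = [0, log 4]
print(encode_row(row))  # deterministic path: prints [0.5, -1.0]
print(encode_row(row, sample_posterior=True, rng=random.Random(0)))  # seeded sample
```

With `sample_posterior=False` the latent is just the mean, so the result is deterministic. With `sample_posterior=True` the noise is scaled by the standard deviation derived from the log-variance, which is the same arithmetic the source applies to whole tensors at once.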

Callers

nothing calls this directly

Calls 2

type · Method · 0.45
replace · Method · 0.45
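
Judging only from how `replace` is used in the source (`h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))`), it appears to return a tensor that keeps the sparse structure and swaps in new features. The sketch below illustrates that pattern with a toy container. `ToySparseTensor` and `layer_norm_row` are hypothetical stand-ins and make no claim about how `sp.SparseTensor` is implemented. The normalization uses no learned weight or bias, a biased variance, and `eps=1e-5`, matching the defaults of `F.layer_norm` when only the input and normalized shape are passed.

```python
import math
from dataclasses import dataclass, replace as dc_replace
from typing import Tuple


def layer_norm_row(row, eps=1e-5):
    """Normalize one feature row over its last dimension, without affine parameters."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance
    return tuple((v - mean) / math.sqrt(var + eps) for v in row)


@dataclass(frozen=True)
class ToySparseTensor:
    """Hypothetical stand-in: integer coordinates plus one feature row per coordinate."""
    coords: Tuple[Tuple[int, int, int], ...]
    feats: Tuple[Tuple[float, ...], ...]

    def replace(self, new_feats):
        """Return a new tensor with the same coords and the given feature rows."""
        new_feats = tuple(tuple(r) for r in new_feats)
        if len(new_feats) != len(self.coords):
            raise ValueError("need exactly one feature row per coordinate")
        return dc_replace(self, feats=new_feats)


h = ToySparseTensor(
    coords=((0, 0, 0), (1, 2, 3)),
    feats=((1.0, 2.0, 3.0), (4.0, 4.0, 4.0)),
)
h2 = h.replace(layer_norm_row(r) for r in h.feats)
print(h2.coords == h.coords)  # prints True: only the features changed
```

The original object is left untouched, which matches the way `forward` rebinds `h` to the result of each `replace` call instead of mutating it in place.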

Tested by

no test coverage detected