| 371 | self.apply(_basic_init) |
| 372 | |
def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
    """Encode ``x`` into a latent sparse tensor, optionally sampling the posterior.

    Pipeline: input layer -> cast to ``self.dtype`` -> every block of every
    stage in ``self.blocks`` -> cast back to ``x.dtype`` -> parameter-free
    layer norm over the last feature dimension -> ``self.to_latent``.
    The latent features are then split in half along the last dimension
    into ``mean`` and ``logvar``.

    Args:
        x: Input sparse tensor.
        sample_posterior: When True, ``z = mean + exp(0.5 * logvar) * eps``
            with ``eps`` drawn from ``torch.randn_like``; when False,
            ``z`` is simply ``mean``.
        return_raw: When True, also return the ``mean`` and ``logvar``
            feature tensors.

    Returns:
        ``z`` (built via ``latent.replace(...)``, presumably keeping the
        latent's sparse coordinates -- NOTE(review): confirm), or the tuple
        ``(z, mean, logvar)`` when ``return_raw`` is True.
    """
    # The trunk runs in the module's own dtype, which may differ from the input's.
    hidden = self.input_layer(x).type(self.dtype)
    for stage in self.blocks:
        for layer in stage:
            hidden = layer(hidden)

    # Return to the caller's dtype before normalizing and projecting.
    hidden = hidden.type(x.dtype)
    normed = F.layer_norm(hidden.feats, hidden.feats.shape[-1:])
    latent = self.to_latent(hidden.replace(normed))

    # First half of the last dimension is the mean, second half the log-variance.
    mean, logvar = latent.feats.chunk(2, dim=-1)
    z_feats = mean
    if sample_posterior:
        # Reparameterization trick: z = mean + sigma * eps.
        sigma = torch.exp(0.5 * logvar)
        z_feats = mean + sigma * torch.randn_like(sigma)
    z = latent.replace(z_feats)

    return (z, mean, logvar) if return_raw else z
| 396 | |
| 397 | |
| 398 | class SparseUnetVaeDecoder(nn.Module): |