Method
forward
(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False)
Source from the content-addressed store, hash-verified
| 184 | self.middle_block.apply(convert_module_to_f32) |
| 185 | |
def forward(
    self,
    x: torch.Tensor,
    sample_posterior: bool = False,
    return_raw: bool = False,
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Encode ``x`` into a latent, optionally sampling from the posterior.

    The input goes through these steps in order:

    1. It is projected by ``self.input_layer``.
    2. It is cast to ``self.dtype`` and run through every module in
       ``self.blocks``, then through ``self.middle_block``.
    3. It is cast back to ``x``'s dtype and projected by ``self.out_layer``.

    The output is then split into two chunks along dim 1, interpreted as the
    posterior ``mean`` and ``logvar``.

    Args:
        x: Input tensor. The split below happens on dim 1, so this presumably
            is channel-first (batch, channels, ...). TODO confirm against
            callers.
        sample_posterior: If True, draw a sample
            ``z = mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``
            (reparameterized sampling). If False, return the mean directly.
        return_raw: If True, also return the ``mean`` and ``logvar`` tensors.

    Returns:
        ``z`` alone, or the tuple ``(z, mean, logvar)`` when ``return_raw``
        is True. The return annotation reflects both cases; the original
        ``-> torch.Tensor`` was wrong for ``return_raw=True``.
    """
    h = self.input_layer(x)
    # Run the trunk in the module's working dtype, which may differ from
    # the input dtype.
    h = h.type(self.dtype)
    for block in self.blocks:
        h = block(h)
    h = self.middle_block(h)

    # Restore the caller's dtype before the output projection.
    h = h.type(x.dtype)
    h = self.out_layer(h)

    # First half of dim 1 is the mean, second half is the log-variance.
    # NOTE(review): presumably out_layer emits an even channel count on
    # dim 1; chunk() would silently return unequal halves otherwise.
    mean, logvar = h.chunk(2, dim=1)

    if sample_posterior:
        # NOTE(review): logvar is not clamped here, so extreme values could
        # overflow exp(). Confirm whether clamping is wanted upstream.
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
    else:
        z = mean

    if return_raw:
        return z, mean, logvar
    return z
| 208 | |
| 209 | |
| 210 | class SparseStructureDecoder(nn.Module): |
Callers
nothing calls this directly
Tested by
no test coverage detected