(self, x: torch.Tensor)
| 293 | self.middle_block.apply(convert_module_to_f32) |
| 294 | |
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through input layer, middle block, blocks, and output layer.

    The activation produced by ``self.input_layer`` is cast to ``self.dtype``.
    ``self.middle_block`` and then every module in ``self.blocks`` (in order)
    run at that working precision. The result is cast back to ``x.dtype``
    before ``self.out_layer`` is applied.

    Args:
        x: Input tensor. NOTE(review): the expected shape is not visible in
            this method; confirm it against the caller.

    Returns:
        Whatever ``self.out_layer`` returns for the re-cast activation.
    """
    h = self.input_layer(x)

    # Switch to the module's working precision for the core of the network.
    # NOTE(review): self.dtype is presumably toggled by the fp16/fp32
    # conversion helpers defined on this class; confirm.
    h = h.type(self.dtype)

    # The middle block runs first, then the sequential blocks.
    h = self.middle_block(h)
    for block in self.blocks:
        h = block(h)

    # Restore the caller's dtype before the output layer.
    h = h.type(x.dtype)
    h = self.out_layer(h)
    return h