(self, x: torch.Tensor)
| 91 | assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" |
| 92 | |
| 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 94 | if hasattr(self, "conv"): |
| 95 | x = self.conv(x) |
| 96 | return pixel_shuffle_3d(x, 2) |
| 97 | else: |
| 98 | return F.interpolate(x, scale_factor=2, mode="nearest") |
| 99 | |
| 100 | |
| 101 | class SparseStructureEncoder(nn.Module): |
nothing calls this directly
no test coverage detected