Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs.
(self, x, timesteps)
| 869 | self.middle_block.apply(convert_module_to_f32) |
| 870 | |
def forward(self, x, timesteps):
    """
    Apply the model to an input batch.

    The timestep embedding is fed to every input block and to the middle
    block. What happens next depends on ``self.pool``:

    * if it starts with ``"spatial"``, the activation after each input
      block and after the middle block is cast back to ``x.dtype``,
      averaged over its spatial axes, and all of these per-block feature
      vectors are concatenated along the last axis before ``self.out``;
    * otherwise only the middle-block activation is cast back to
      ``x.dtype`` and handed to ``self.out``.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :return: an [N x K] Tensor of outputs.
    """
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

    # The pooling mode cannot change during a call, so evaluate it once
    # instead of once per input block.
    spatial_pool = self.pool.startswith("spatial")

    def pooled(feat):
        # Average over every axis after (batch, channel). For 4-D
        # activations this is exactly dim=(2, 3); deriving the axes from
        # the tensor keeps the documented [N x C x ...] contract for other
        # spatial ranks as well.
        feat = feat.type(x.dtype)
        return feat.mean(dim=tuple(range(2, feat.dim())))

    results = []
    h = x.type(self.dtype)
    for module in self.input_blocks:
        h = module(h, emb)
        if spatial_pool:
            results.append(pooled(h))
    h = self.middle_block(h, emb)

    if spatial_pool:
        results.append(pooled(h))
        # Each entry is [N x C_i]; join them along the channel axis.
        return self.out(th.cat(results, dim=-1))

    return self.out(h.type(x.dtype))
nothing calls this directly
no test coverage detected