Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs.
(self, x, timesteps, y=None)
| 632 | self.output_blocks.apply(convert_module_to_f32) |
| 633 | |
def forward(self, x, timesteps, y=None):
    """
    Run the U-Net on a batch of inputs.

    The timestep embedding (plus a label embedding when the model is
    class-conditional) conditions every input, middle and output block.
    Each input block's activation is saved and later concatenated onto the
    decoder path in reverse order (skip connections).

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param y: an [N] Tensor of labels. It must be given exactly when
              ``self.num_classes`` is not None, and omitted otherwise.
    :return: the result of ``self.out`` applied to the final hidden state,
             which is first cast back to ``x``'s dtype.
    """
    is_conditional = self.num_classes is not None
    assert (y is not None) == is_conditional, "must specify y if and only if the model is class-conditional"

    # Sinusoidal-style features from the helper, projected by time_embed.
    time_features = timestep_embedding(timesteps, self.model_channels)
    emb = self.time_embed(time_features)

    if is_conditional:
        # One label per batch element.
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    # Encoder: run the blocks in self.dtype and remember each activation.
    skips = []
    feats = x.type(self.dtype)
    for block in self.input_blocks:
        feats = block(feats, emb)
        skips.append(feats)

    feats = self.middle_block(feats, emb)

    # Decoder: pop skips last-in-first-out and concatenate them along dim 1
    # (presumably the channel axis, as for the [N x C x ...] input).
    for block in self.output_blocks:
        feats = block(th.cat([feats, skips.pop()], dim=1), emb)

    # Cast back to the caller's dtype before the output head.
    return self.out(feats.type(x.dtype))
| 664 | |
| 665 | |
| 666 | class SuperResModel(UNetModel): |
nothing calls this directly
no test coverage detected