Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs.
(self, x, timesteps, y=None)
| 460 | return next(self.input_blocks.parameters()).dtype |
| 461 | |
def forward(self, x, timesteps, y=None):
    """
    Run the model on a batch of inputs.

    The input is cast to ``self.inner_dtype`` and passed through every
    module in ``self.input_blocks``; each of their outputs is kept as a skip
    tensor. After ``self.middle_block``, every module in
    ``self.output_blocks`` receives the current activation concatenated
    (along dim 1) with the most recently stored skip tensor, so skips are
    consumed in reverse order. The result is cast back to ``x.dtype`` and
    fed through ``self.out``.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    conditional = self.num_classes is not None
    assert (y is not None) == conditional, "must specify y if and only if the model is class-conditional"

    # Sinusoidal-style features for the timesteps are produced by the
    # module-level helper, then projected by the learned time embedding.
    t_features = timestep_embedding(timesteps, self.model_channels)
    emb = self.time_embed(t_features)

    if conditional:
        # One label per batch element is required.
        assert y.shape == (x.shape[0],)
        # Non-in-place add: keeps the time-embedding output tensor untouched.
        emb = emb + self.label_emb(y)

    skips = []
    h = x.type(self.inner_dtype)
    for block in self.input_blocks:
        h = block(h, emb)
        skips.append(h)

    h = self.middle_block(h, emb)

    for block in self.output_blocks:
        # pop() takes the newest skip first, pairing each decoder stage with
        # the encoder stage at the mirrored position.
        h = block(th.cat([h, skips.pop()], dim=1), emb)

    # Restore the caller's dtype before the final output projection.
    return self.out(h.type(x.dtype))
| 492 | |
| 493 | def get_feature_vectors(self, x, timesteps, y=None): |
| 494 | """ |
nothing calls this directly
no test coverage detected