(self, module_configs)
| 668 | ) |
| 669 | |
| 670 | def _build_modules(self, module_configs): |
| 671 | model_channels = self.hidden_size |
| 672 | # time_embed_dim = model_channels * 4 |
| 673 | time_embed_dim = self.time_embed_dim |
| 674 | self.time_embed = nn.Sequential( |
| 675 | linear(model_channels, time_embed_dim), |
| 676 | nn.SiLU(), |
| 677 | linear(time_embed_dim, time_embed_dim), |
| 678 | ) |
| 679 | |
| 680 | if self.num_classes is not None: |
| 681 | if isinstance(self.num_classes, int): |
| 682 | self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) |
| 683 | elif self.num_classes == "continuous": |
| 684 | print("setting up linear c_adm embedding layer") |
| 685 | self.label_emb = nn.Linear(1, time_embed_dim) |
| 686 | elif self.num_classes == "timestep": |
| 687 | self.label_emb = nn.Sequential( |
| 688 | Timestep(model_channels), |
| 689 | nn.Sequential( |
| 690 | linear(model_channels, time_embed_dim), |
| 691 | nn.SiLU(), |
| 692 | linear(time_embed_dim, time_embed_dim), |
| 693 | ), |
| 694 | ) |
| 695 | elif self.num_classes == "sequential": |
| 696 | assert self.adm_in_channels is not None |
| 697 | self.label_emb = nn.Sequential( |
| 698 | nn.Sequential( |
| 699 | linear(self.adm_in_channels, time_embed_dim), |
| 700 | nn.SiLU(), |
| 701 | linear(time_embed_dim, time_embed_dim), |
| 702 | ) |
| 703 | ) |
| 704 | if self.zero_init_y_embed: |
| 705 | nn.init.constant_(self.label_emb[0][2].weight, 0) |
| 706 | nn.init.constant_(self.label_emb[0][2].bias, 0) |
| 707 | else: |
| 708 | raise ValueError() |
| 709 | |
| 710 | pos_embed_config = module_configs["pos_embed_config"] |
| 711 | self.add_mixin( |
| 712 | "pos_embed", |
| 713 | instantiate_from_config( |
| 714 | pos_embed_config, |
| 715 | height=self.latent_height // self.patch_size, |
| 716 | width=self.latent_width // self.patch_size, |
| 717 | compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, |
| 718 | hidden_size=self.hidden_size, |
| 719 | ), |
| 720 | reinit=True, |
| 721 | ) |
| 722 | |
| 723 | patch_embed_config = module_configs["patch_embed_config"] |
| 724 | self.add_mixin( |
| 725 | "patch_embed", |
| 726 | instantiate_from_config( |
| 727 | patch_embed_config, |
no test coverage detected