MCPcopy
hub / github.com/zai-org/CogVideo / _build_modules

Method _build_modules

sat/dit_video_concat.py:670–773  ·  view source on GitHub ↗
(self, module_configs)

Source from the content-addressed store, hash-verified

668 )
669
670 def _build_modules(self, module_configs):
671 model_channels = self.hidden_size
672 # time_embed_dim = model_channels * 4
673 time_embed_dim = self.time_embed_dim
674 self.time_embed = nn.Sequential(
675 linear(model_channels, time_embed_dim),
676 nn.SiLU(),
677 linear(time_embed_dim, time_embed_dim),
678 )
679
680 if self.num_classes is not None:
681 if isinstance(self.num_classes, int):
682 self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
683 elif self.num_classes == "continuous":
684 print("setting up linear c_adm embedding layer")
685 self.label_emb = nn.Linear(1, time_embed_dim)
686 elif self.num_classes == "timestep":
687 self.label_emb = nn.Sequential(
688 Timestep(model_channels),
689 nn.Sequential(
690 linear(model_channels, time_embed_dim),
691 nn.SiLU(),
692 linear(time_embed_dim, time_embed_dim),
693 ),
694 )
695 elif self.num_classes == "sequential":
696 assert self.adm_in_channels is not None
697 self.label_emb = nn.Sequential(
698 nn.Sequential(
699 linear(self.adm_in_channels, time_embed_dim),
700 nn.SiLU(),
701 linear(time_embed_dim, time_embed_dim),
702 )
703 )
704 if self.zero_init_y_embed:
705 nn.init.constant_(self.label_emb[0][2].weight, 0)
706 nn.init.constant_(self.label_emb[0][2].bias, 0)
707 else:
708 raise ValueError()
709
710 pos_embed_config = module_configs["pos_embed_config"]
711 self.add_mixin(
712 "pos_embed",
713 instantiate_from_config(
714 pos_embed_config,
715 height=self.latent_height // self.patch_size,
716 width=self.latent_width // self.patch_size,
717 compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
718 hidden_size=self.hidden_size,
719 ),
720 reinit=True,
721 )
722
723 patch_embed_config = module_configs["patch_embed_config"]
724 self.add_mixin(
725 "patch_embed",
726 instantiate_from_config(
727 patch_embed_config,

Callers 1

__init__ · Method · 0.95

Calls 3

linear · Function · 0.90
Timestep · Class · 0.90
instantiate_from_config · Function · 0.90

Tested by

no test coverage detected