MCPcopy
hub / github.com/zai-org/CogVideo / __init__

Method __init__

sat/diffusion_video.py:27–91  ·  view source on GitHub ↗
(self, args, **kwargs)

Source from the content-addressed store, hash-verified

25
26class SATVideoDiffusionEngine(nn.Module):
27 def __init__(self, args, **kwargs):
28 super().__init__()
29
30 model_config = args.model_config
31 # model args preprocess
32 log_keys = model_config.get("log_keys", None)
33 input_key = model_config.get("input_key", "mp4")
34 network_config = model_config.get("network_config", None)
35 network_wrapper = model_config.get("network_wrapper", None)
36 denoiser_config = model_config.get("denoiser_config", None)
37 sampler_config = model_config.get("sampler_config", None)
38 conditioner_config = model_config.get("conditioner_config", None)
39 first_stage_config = model_config.get("first_stage_config", None)
40 loss_fn_config = model_config.get("loss_fn_config", None)
41 scale_factor = model_config.get("scale_factor", 1.0)
42 latent_input = model_config.get("latent_input", False)
43 disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
44 no_cond_log = model_config.get("disable_first_stage_autocast", False)
45 not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
46 compile_model = model_config.get("compile_model", False)
47 en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
48 lr_scale = model_config.get("lr_scale", None)
49 lora_train = model_config.get("lora_train", False)
50 self.use_pd = model_config.get("use_pd", False) # progressive distillation
51
52 self.log_keys = log_keys
53 self.input_key = input_key
54 self.not_trainable_prefixes = not_trainable_prefixes
55 self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
56 self.lr_scale = lr_scale
57 self.lora_train = lora_train
58 self.noised_image_input = model_config.get("noised_image_input", False)
59 self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
60 self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
61 if args.fp16:
62 dtype = torch.float16
63 dtype_str = "fp16"
64 elif args.bf16:
65 dtype = torch.bfloat16
66 dtype_str = "bf16"
67 else:
68 dtype = torch.float32
69 dtype_str = "fp32"
70 self.dtype = dtype
71 self.dtype_str = dtype_str
72
73 network_config["params"]["dtype"] = dtype_str
74 model = instantiate_from_config(network_config)
75 self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
76 model, compile_model=compile_model, dtype=dtype
77 )
78
79 self.denoiser = instantiate_from_config(denoiser_config)
80 self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
81 self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
82
83 self._init_first_stage(first_stage_config)
84

Callers

nothing calls this directly

Calls 4

_init_first_stage — Method · 0.95
instantiate_from_config — Function · 0.90
get_obj_from_str — Function · 0.90
default — Function · 0.90

Tested by

no test coverage detected