(self, args, **kwargs)
| 25 | |
| 26 | class SATVideoDiffusionEngine(nn.Module): |
| 27 | def __init__(self, args, **kwargs): |
| 28 | super().__init__() |
| 29 | |
| 30 | model_config = args.model_config |
| 31 | # model args preprocess |
| 32 | log_keys = model_config.get("log_keys", None) |
| 33 | input_key = model_config.get("input_key", "mp4") |
| 34 | network_config = model_config.get("network_config", None) |
| 35 | network_wrapper = model_config.get("network_wrapper", None) |
| 36 | denoiser_config = model_config.get("denoiser_config", None) |
| 37 | sampler_config = model_config.get("sampler_config", None) |
| 38 | conditioner_config = model_config.get("conditioner_config", None) |
| 39 | first_stage_config = model_config.get("first_stage_config", None) |
| 40 | loss_fn_config = model_config.get("loss_fn_config", None) |
| 41 | scale_factor = model_config.get("scale_factor", 1.0) |
| 42 | latent_input = model_config.get("latent_input", False) |
| 43 | disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) |
| 44 | no_cond_log = model_config.get("disable_first_stage_autocast", False) |
| 45 | not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"]) |
| 46 | compile_model = model_config.get("compile_model", False) |
| 47 | en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) |
| 48 | lr_scale = model_config.get("lr_scale", None) |
| 49 | lora_train = model_config.get("lora_train", False) |
| 50 | self.use_pd = model_config.get("use_pd", False) # progressive distillation |
| 51 | |
| 52 | self.log_keys = log_keys |
| 53 | self.input_key = input_key |
| 54 | self.not_trainable_prefixes = not_trainable_prefixes |
| 55 | self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time |
| 56 | self.lr_scale = lr_scale |
| 57 | self.lora_train = lora_train |
| 58 | self.noised_image_input = model_config.get("noised_image_input", False) |
| 59 | self.noised_image_all_concat = model_config.get("noised_image_all_concat", False) |
| 60 | self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0) |
| 61 | if args.fp16: |
| 62 | dtype = torch.float16 |
| 63 | dtype_str = "fp16" |
| 64 | elif args.bf16: |
| 65 | dtype = torch.bfloat16 |
| 66 | dtype_str = "bf16" |
| 67 | else: |
| 68 | dtype = torch.float32 |
| 69 | dtype_str = "fp32" |
| 70 | self.dtype = dtype |
| 71 | self.dtype_str = dtype_str |
| 72 | |
| 73 | network_config["params"]["dtype"] = dtype_str |
| 74 | model = instantiate_from_config(network_config) |
| 75 | self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( |
| 76 | model, compile_model=compile_model, dtype=dtype |
| 77 | ) |
| 78 | |
| 79 | self.denoiser = instantiate_from_config(denoiser_config) |
| 80 | self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None |
| 81 | self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) |
| 82 | |
| 83 | self._init_first_stage(first_stage_config) |
| 84 |
nothing calls this directly
no test coverage detected