(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
)
| 7 | |
class LightningModel(LightningModelForT2ILoRA):
    """LoRA training module built on a FLUX image pipeline.

    Loads base weights through ``ModelManager``, optionally loads preset LoRAs,
    builds a ``FluxImagePipeline``, freezes parameters and attaches LoRA
    adapters to the pipeline's denoising model.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        pretrained_weights=None,
        preset_lora_path=None,
        learning_rate=1e-4,
        use_gradient_checkpointing=True,
        lora_rank=4,
        lora_alpha=4,
        lora_target_modules="to_q,to_k,to_v,to_out",
        init_lora_weights="kaiming",
        pretrained_lora_path=None,
        state_dict_converter=None,
        quantize=None,
    ):
        """Build the pipeline and attach LoRA adapters.

        Args:
            torch_dtype: dtype handed to ``ModelManager`` for loading weights.
            pretrained_weights: checkpoint paths passed to
                ``ModelManager.load_models``. ``None`` means no weights. When
                ``quantize`` is set, the first entry is loaded separately via
                ``load_model(..., torch_dtype=quantize)`` and the rest via
                ``load_models``.
            preset_lora_path: optional comma-separated string of LoRA paths,
                each loaded with ``model_manager.load_lora`` before the pipeline
                is built. Whitespace around entries and empty entries are ignored.
            learning_rate, use_gradient_checkpointing, state_dict_converter:
                forwarded unchanged to the base class constructor.
            lora_rank, lora_alpha, lora_target_modules, init_lora_weights,
            pretrained_lora_path: forwarded unchanged to ``add_lora_to_model``.
            quantize: optional dtype. When not ``None``, it is used for the first
                checkpoint and ``self.pipe.dit.quantize()`` is called afterwards.

        Raises:
            ValueError: if ``quantize`` is set but ``pretrained_weights`` is empty.
        """
        super().__init__(
            learning_rate=learning_rate,
            use_gradient_checkpointing=use_gradient_checkpointing,
            state_dict_converter=state_dict_converter,
        )
        # Fresh list per call instead of a shared mutable default.
        if pretrained_weights is None:
            pretrained_weights = []

        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        if quantize is None:
            model_manager.load_models(pretrained_weights)
        else:
            if not pretrained_weights:
                raise ValueError(
                    "quantize is set but pretrained_weights is empty; the first "
                    "entry is the checkpoint loaded with the quantize dtype"
                )
            # The first checkpoint gets the quantize dtype; presumably it is the
            # DiT, since pipe.dit.quantize() is called below -- confirm with callers.
            model_manager.load_models(pretrained_weights[1:])
            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)

        if preset_lora_path is not None:
            # Comma-separated list; tolerate "a, b" spacing and trailing commas.
            for path in preset_lora_path.split(","):
                path = path.strip()
                if path:
                    model_manager.load_lora(path)

        self.pipe = FluxImagePipeline.from_model_manager(model_manager)

        if quantize is not None:
            self.pipe.dit.quantize()

        self.pipe.scheduler.set_timesteps(1000, training=True)

        # freeze_parameters / add_lora_to_model come from LightningModelForT2ILoRA.
        # Freezing happens before the adapters are added, presumably so that only
        # the LoRA weights remain trainable -- verify against the base class.
        self.freeze_parameters()
        self.add_lora_to_model(
            self.pipe.denoising_model(),
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_target_modules=lora_target_modules,
            init_lora_weights=init_lora_weights,
            pretrained_lora_path=pretrained_lora_path,
            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format,
        )
| 46 | |
| 47 | |
| 48 | def parse_args(): |
nothing calls this directly
no test coverage detected