MCPcopy Index your code
hub / github.com/modelscope/DiffSynth-Studio / __init__

Method __init__

examples/train/flux/train_flux_lora.py:9–45  ·  view source on GitHub ↗
(
        self,
        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
        state_dict_converter=None, quantize = None
    )

Source from the content-addressed store, hash-verified

7
class LightningModel(LightningModelForT2ILoRA):
    """FLUX LoRA training module built on ``LightningModelForT2ILoRA``.

    The constructor loads the base weights into a ``ModelManager`` and
    optionally loads preset LoRAs. It builds a ``FluxImagePipeline``,
    freezes the base parameters, and attaches LoRA adapters to the
    pipeline's denoising model.
    """

    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=None, preset_lora_path=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
        state_dict_converter=None, quantize=None
    ):
        """Build the FLUX pipeline and attach LoRA adapters for training.

        Args:
            torch_dtype: dtype handed to ``ModelManager`` for loading the
                base weights.
            pretrained_weights: Sequence of checkpoint paths passed to
                ``ModelManager.load_models``. ``None`` (the default) means
                no weights. When ``quantize`` is set, the FIRST entry is
                loaded separately (see ``quantize``).
            preset_lora_path: Optional comma-separated string of LoRA paths.
                Each path is loaded via ``ModelManager.load_lora`` before
                the pipeline is built. Whitespace around entries is ignored,
                and empty entries (e.g. from a trailing comma) are skipped.
            learning_rate: Forwarded to the base class.
            use_gradient_checkpointing: Forwarded to the base class.
            lora_rank: Forwarded to ``add_lora_to_model``.
            lora_alpha: Forwarded to ``add_lora_to_model``.
            lora_target_modules: Forwarded to ``add_lora_to_model``.
            init_lora_weights: Forwarded to ``add_lora_to_model``.
            pretrained_lora_path: Forwarded to ``add_lora_to_model``.
            state_dict_converter: Forwarded to the base-class constructor
                only. ``add_lora_to_model`` below always uses
                ``FluxLoRAConverter.align_to_diffsynth_format``.
            quantize: If not ``None``, ``pretrained_weights[0]`` is loaded
                with ``torch_dtype=quantize``, and ``pipe.dit.quantize()``
                is called after the pipeline is built.
                NOTE(review): this presumably requires the first weight
                path to be the DiT checkpoint — confirm with callers.

        Raises:
            ValueError: If ``quantize`` is set but ``pretrained_weights``
                is empty. Previously this surfaced as an ``IndexError``.
        """
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
        # Avoid a shared mutable ``[]`` default; the sequence is only read here.
        if pretrained_weights is None:
            pretrained_weights = []
        if quantize is not None and not pretrained_weights:
            raise ValueError(
                "`quantize` requires at least one path in `pretrained_weights`; "
                "the first path is loaded with the quantized dtype."
            )

        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        if quantize is None:
            model_manager.load_models(pretrained_weights)
        else:
            # Load everything except the first path normally. Then load the
            # first path on its own with the requested ``quantize`` dtype.
            model_manager.load_models(pretrained_weights[1:])
            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
        if preset_lora_path is not None:
            for path in preset_lora_path.split(","):
                path = path.strip()
                # Skip blanks produced by "a,,b" or a trailing comma.
                if path:
                    model_manager.load_lora(path)

        self.pipe = FluxImagePipeline.from_model_manager(model_manager)

        if quantize is not None:
            self.pipe.dit.quantize()

        # Configure the scheduler for training with 1000 timesteps.
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze the base model first, then add the LoRA adapters.
        self.freeze_parameters()
        self.add_lora_to_model(
            self.pipe.denoising_model(),
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_target_modules=lora_target_modules,
            init_lora_weights=init_lora_weights,
            pretrained_lora_path=pretrained_lora_path,
            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
        )
46
47
48def parse_args():

Callers

nothing calls this directly

Calls 10

load_modelsMethod · 0.95
load_modelMethod · 0.95
load_loraMethod · 0.95
ModelManagerClass · 0.90
freeze_parametersMethod · 0.80
from_model_managerMethod · 0.45
quantizeMethod · 0.45
set_timestepsMethod · 0.45
add_lora_to_modelMethod · 0.45
denoising_modelMethod · 0.45

Tested by

no test coverage detected