MCPcopy
hub / github.com/Yuanshi9815/OminiControl / main

Function main

omini/train_flux/train_spatial_alignment.py:170–207  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

168
169
def main():
    """Run spatial-alignment LoRA training.

    Steps:
      1. Load the run configuration via ``get_config()``.
      2. Bind this process to the CUDA device given by ``LOCAL_RANK``.
      3. Load the webdataset shards listed in the config.
      4. Wrap them in ``ImageConditionDataset``.
      5. Build an ``OminiModel`` and pass everything to ``train`` together
         with ``test_function``.

    Config keys read:
      * ``config["flux_path"]``, ``config["dtype"]`` (a ``torch`` attribute
        name), and optionally ``config["model"]`` (default ``{}``).
      * ``config["train"]``: ``condition_type``, ``lora_config``,
        ``optimizer``, optional ``gradient_checkpointing`` (default False).
      * ``config["train"]["dataset"]``: ``urls``, ``condition_size``,
        ``target_size``, ``drop_text_prob``, ``drop_image_prob``, plus
        these optional keys:
          - ``position_scale`` (default 1.0)
          - ``cache_dir`` (default ``"cache/t2i2m"``)
          - ``num_proc`` (default 32)
    """
    # Initialize
    config = get_config()
    training_config = config["train"]
    # NOTE(review): LOCAL_RANK is presumably set by the distributed launcher
    # (e.g. torchrun); without it this process uses GPU 0.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dataset_config = training_config["dataset"]

    # Load dataset text-to-image-2M.
    # cache_dir / num_proc used to be hard-coded. They are now configurable,
    # with the previous values as defaults.
    raw_dataset = load_dataset(
        "webdataset",
        data_files={"train": dataset_config["urls"]},
        split="train",
        cache_dir=dataset_config.get("cache_dir", "cache/t2i2m"),
        num_proc=dataset_config.get("num_proc", 32),
    )

    # Initialize custom dataset
    dataset = ImageConditionDataset(
        raw_dataset,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        condition_type=training_config["condition_type"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_image_prob=dataset_config["drop_image_prob"],
        position_scale=dataset_config.get("position_scale", 1.0),
    )

    # Initialize model.
    # A bare "cuda" device string means the current CUDA device, i.e. the one
    # selected by set_device above.
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # test_function is defined elsewhere in this module and passed through to train.
    train(dataset, trainable_model, config, test_function)
208
209
210if __name__ == "__main__":

Callers 1

Calls 4

get_config · Function · 0.85
OminiModel · Class · 0.85
train · Function · 0.85

Tested by

no test coverage detected