()
| 168 | |
| 169 | |
def main():
    """Entry point: build the dataset and model from config, then train.

    Steps, in order:
      1. Read the full config via ``get_config()`` and take its ``"train"``
         section.
      2. Pin this process to the CUDA device named by the ``LOCAL_RANK``
         environment variable (0 when unset). Presumably ``LOCAL_RANK`` is
         set by a distributed launcher such as torchrun -- confirm against
         the launch scripts.
      3. Load the ``"train"`` split of a webdataset from
         ``train.dataset.urls`` and wrap it in ``ImageConditionDataset``.
      4. Build ``OminiModel`` and pass dataset, model, config and
         ``test_function`` to ``train``.

    The optional config keys ``train.dataset.cache_dir`` (default
    ``"cache/t2i2m"``) and ``train.dataset.num_proc`` (default ``32``)
    override the dataset cache directory and the number of loader
    processes. When they are absent, behavior matches the previous
    hard-coded values.
    """
    # Initialize
    config = get_config()
    training_config = config["train"]
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dataset_config = training_config["dataset"]

    # Load dataset text-to-image-2M
    dataset = load_dataset(
        "webdataset",
        data_files={"train": dataset_config["urls"]},
        split="train",
        cache_dir=dataset_config.get("cache_dir", "cache/t2i2m"),
        num_proc=dataset_config.get("num_proc", 32),
    )

    # Initialize custom dataset
    dataset = ImageConditionDataset(
        dataset,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        condition_type=training_config["condition_type"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_image_prob=dataset_config["drop_image_prob"],
        position_scale=dataset_config.get("position_scale", 1.0),
    )

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        # Plain "cuda" (no index) refers to the current CUDA device, which
        # set_device() above pinned to this process's LOCAL_RANK.
        device="cuda",
        # config["dtype"] is a torch attribute name, e.g. "bfloat16".
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)
| 208 | |
| 209 | |
| 210 | if __name__ == "__main__": |
no test coverage detected