()
| 149 | |
| 150 | |
def main():
    """Prepare the Subjects200K data and the OminiModel, then start training.

    Steps, in order:
      1. Read the configuration via ``get_config()`` and select the CUDA
         device from the ``LOCAL_RANK`` environment variable (0 if unset).
      2. Load ``Yuanshi/Subjects200K`` and keep only the ``train`` items that
         pass the quality filter, caching the filtered split under
         ``./cache/dataset``.
      3. Wrap the filtered split in ``Subject200KDataset`` and build the
         ``OminiModel`` from the configuration.
      4. Hand everything to ``train`` together with ``test_function``.

    Raises:
        KeyError: if a required configuration key is missing.
    """
    # Initialize
    config = get_config()
    training_config = config["train"]
    # One process per GPU: presumably LOCAL_RANK is set by the distributed
    # launcher (TODO confirm); a plain single-process run falls back to GPU 0.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Initialize raw dataset
    raw_dataset = load_dataset("Yuanshi/Subjects200K")

    # Define filter function to filter out low-quality images from Subjects200K
    def filter_func(item):
        """Return True only if all three quality scores of ``item`` are >= 5.

        Items without a (non-empty) ``quality_assessment`` are rejected, and
        a missing score counts as 0, so it is rejected as well.
        """
        if not item.get("quality_assessment"):
            return False
        return all(
            item["quality_assessment"].get(key, 0) >= 5
            for key in ["compositeStructure", "objectConsistency", "imageQuality"]
        )

    # Filter dataset
    # exist_ok=True instead of an exists() check followed by makedirs(): when
    # several ranks start at once, the check-then-create sequence can raise
    # FileExistsError in whichever process loses the race.
    os.makedirs("./cache/dataset", exist_ok=True)
    # NOTE(review): every rank runs this filter against the same cache file;
    # confirm that concurrent first-time runs are safe for the cache writer.
    data_valid = raw_dataset["train"].filter(
        filter_func,
        num_proc=16,
        cache_file_name="./cache/dataset/data_valid.arrow",
    )

    # Initialize the dataset
    dataset_config = training_config["dataset"]
    dataset = Subject200KDataset(
        data_valid,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        image_size=dataset_config["image_size"],
        padding=dataset_config["padding"],
        condition_type=training_config["condition_type"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_image_prob=dataset_config["drop_image_prob"],
    )

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        # Plain "cuda" resolves to the current device selected above.
        device="cuda",
        # config["dtype"] names a torch attribute (e.g. a dtype such as
        # "bfloat16" — TODO confirm the accepted values).
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)
| 202 | |
| 203 | |
| 204 | if __name__ == "__main__": |
no test coverage detected