MCPcopy
hub / github.com/Yuanshi9815/OminiControl / main

Function main

omini/train_flux/train_subject.py:151–201  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

149
150
def main():
    """Train a subject-conditioned OminiControl model on Subjects200K.

    Steps:
      1. Read the run configuration via ``get_config()`` (training options live
         under ``config["train"]``).
      2. Pin this process to the CUDA device given by the ``LOCAL_RANK``
         environment variable (defaults to 0 when unset).
      3. Load the ``Yuanshi/Subjects200K`` dataset and keep only train samples
         whose quality assessment scores are all >= 5; the filtered split is
         cached at ``./cache/dataset/data_valid.arrow``.
      4. Wrap the filtered split in ``Subject200KDataset`` using the
         ``train.dataset`` settings, build an ``OminiModel``, and pass both to
         ``train`` together with ``test_function``.
    """
    # Initialize
    config = get_config()
    training_config = config["train"]
    # Bind this process to its local GPU before the model is built. PyTorch
    # resolves a bare "cuda" device to the current device, so (assuming
    # OminiModel forwards ``device`` to PyTorch) the model lands on this GPU.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Initialize raw dataset
    raw_dataset = load_dataset("Yuanshi/Subjects200K")

    # Quality criteria that must each score at least 5 for a sample to be kept.
    required_quality_keys = ("compositeStructure", "objectConsistency", "imageQuality")

    # Define filter function to filter out low-quality images from Subjects200K.
    # Samples with a missing/empty assessment are dropped; a criterion missing
    # from the assessment counts as 0 and therefore also drops the sample.
    def filter_func(item):
        assessment = item.get("quality_assessment")
        if not assessment:
            return False
        return all(assessment.get(key, 0) >= 5 for key in required_quality_keys)

    # Filter dataset. ``exist_ok=True`` replaces the former exists()/makedirs()
    # pair, which could raise FileExistsError when several ranks (one per
    # LOCAL_RANK) reach this point at the same time.
    os.makedirs("./cache/dataset", exist_ok=True)
    data_valid = raw_dataset["train"].filter(
        filter_func,
        num_proc=16,
        cache_file_name="./cache/dataset/data_valid.arrow",
    )

    # Initialize the dataset
    dataset_config = training_config["dataset"]
    dataset = Subject200KDataset(
        data_valid,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        image_size=dataset_config["image_size"],
        padding=dataset_config["padding"],
        condition_type=training_config["condition_type"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_image_prob=dataset_config["drop_image_prob"],
    )

    # Initialize model; config["dtype"] names a torch dtype attribute
    # (looked up via getattr on the torch module).
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)
202
203
204if __name__ == "__main__":

Callers 1

train_subject.py (File) · 0.70

Calls 4

get_config (Function) · 0.85
Subject200KDataset (Class) · 0.85
OminiModel (Class) · 0.85
train (Function) · 0.85

Tested by

no test coverage detected