Repository: github.com/modelscope/DiffSynth-Studio

Function launch_training_task

Defined in diffsynth/trainers/text_to_image.py, lines 267–318
Signature: launch_training_task(model, args)
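The function reads its whole configuration from `args` through attribute access. A minimal sketch of a pre-flight check, assuming `args` is an `argparse.Namespace` like the one the calling `train_*_lora.py` scripts build; `LAUNCH_ARGS` and `missing_launch_args` are hypothetical helpers, and the attribute names are copied from the function body. `swanlab_mode` is only read when `use_swanlab` is true, and the two `modelscope_*` fields may legitimately be `None`.

```python
from argparse import Namespace

# Attribute names read by launch_training_task, copied from its body.
# swanlab_mode is only read when use_swanlab is truthy.
LAUNCH_ARGS = (
    "dataset_path", "steps_per_epoch", "batch_size", "height", "width",
    "center_crop", "random_flip", "dataloader_num_workers",
    "use_swanlab", "swanlab_mode", "output_path", "max_epochs",
    "precision", "training_strategy", "accumulate_grad_batches",
    "modelscope_model_id", "modelscope_access_token",
)


def missing_launch_args(args: Namespace) -> list:
    """Hypothetical check: attributes launch_training_task would fail to find."""
    required = [name for name in LAUNCH_ARGS if name != "swanlab_mode"]
    if getattr(args, "use_swanlab", False):
        required.append("swanlab_mode")
    return [name for name in required if not hasattr(args, name)]


args = Namespace(dataset_path="data/dog", steps_per_epoch=500, batch_size=1)
print(missing_launch_args(args))  # lists the fields that still need values
```

Attributes that are present but set to `None` count as supplied, which matches how the function treats the optional ModelScope fields.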


# Names used below come from module-level imports in text_to_image.py:
import os
import torch
import lightning as pl  # `pl.pytorch.callbacks` below implies the `lightning` package
from modelscope.hub.api import HubApi
# TextImageDataset is DiffSynth-Studio's own dataset class.


def launch_training_task(model, args):
    # dataset and data loader
    # The dataset is sized to steps_per_epoch * batch_size samples.
    dataset = TextImageDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch * args.batch_size,
        height=args.height,
        width=args.width,
        center_crop=args.center_crop,
        random_flip=args.random_flip
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )
    # train
    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="diffsynth_studio",
            name="diffsynth_studio",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        logger = [swanlab_logger]
    else:
        logger = None
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        # save_top_k=-1 keeps every checkpoint instead of only the best ones
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=logger,
    )
    trainer.fit(model=model, train_dataloaders=train_loader)

    # Upload models
    if args.modelscope_model_id is not None and args.modelscope_access_token is not None:
        print(f"Uploading models to modelscope. model_id: {args.modelscope_model_id} local_path: {trainer.log_dir}")
        with open(os.path.join(trainer.log_dir, "configuration.json"), "w", encoding="utf-8") as f:
            f.write('{"framework":"Pytorch","task":"text-to-image-synthesis"}\n')
        api = HubApi()
        api.login(args.modelscope_access_token)
        api.push_model(model_id=args.modelscope_model_id, model_dir=trainer.log_dir)
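Because the dataset is built with `steps_per_epoch * batch_size` samples, the epoch length is set by `steps_per_epoch`. The sketch below works through that arithmetic under stated assumptions: `len(TextImageDataset)` returns the `steps_per_epoch` value passed to it; Lightning's default `use_distributed_sampler=True` splits the samples across `world_size` devices, padding the way `torch.utils.data.DistributedSampler` does with `drop_last=False`; and an incomplete accumulation window at the end of an epoch still triggers an optimizer step. The helper names are hypothetical.

```python
import math


def batches_per_epoch(steps_per_epoch: int, batch_size: int, world_size: int = 1) -> int:
    """Batches each device sees per epoch (hypothetical helper)."""
    dataset_len = steps_per_epoch * batch_size        # value passed to TextImageDataset
    per_device = math.ceil(dataset_len / world_size)  # DistributedSampler pads up when drop_last=False
    return math.ceil(per_device / batch_size)         # DataLoader keeps the last partial batch


def optimizer_steps_per_epoch(num_batches: int, accumulate_grad_batches: int) -> int:
    """Optimizer updates per epoch when gradients are accumulated over several batches."""
    return math.ceil(num_batches / accumulate_grad_batches)


single = batches_per_epoch(steps_per_epoch=500, batch_size=4)
multi = batches_per_epoch(steps_per_epoch=500, batch_size=4, world_size=2)
print(single, multi, optimizer_steps_per_epoch(single, accumulate_grad_batches=3))  # 500 250 167
```

Under these assumptions, `devices="auto"` on a multi-GPU machine means each device runs fewer than `steps_per_epoch` batches per epoch, and `accumulate_grad_batches` further divides the number of optimizer updates.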

Callers 6

train_sd3_lora.py  ·  File  ·  0.90
train_sdxl_lora.py  ·  File  ·  0.90
train_flux_lora.py  ·  File  ·  0.90
train_sd_lora.py  ·  File  ·  0.90

Calls 2

TextImageDataset  ·  Class  ·  0.85
update  ·  Method  ·  0.45
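The `update` call in the function is `swanlab_config.update(vars(args))`, which copies every parsed argument into the config handed to `SwanLabLogger`, including `modelscope_access_token` when one is supplied. A minimal sketch of the same merge with secret fields masked before logging; `SECRET_KEYS` and `build_swanlab_config` are hypothetical, and the choice of which fields count as secret is an assumption.

```python
from argparse import Namespace

# Hypothetical list of argument names that should never reach an experiment tracker.
SECRET_KEYS = frozenset({"modelscope_access_token"})


def build_swanlab_config(args: Namespace, secret_keys=SECRET_KEYS) -> dict:
    """Mirror {"UPPERFRAMEWORK": ...} + vars(args), masking secret values."""
    config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
    for key, value in vars(args).items():
        # Mask only secrets that were actually provided; keep None visible.
        config[key] = "***" if key in secret_keys and value is not None else value
    return config


args = Namespace(batch_size=1, output_path="./models", modelscope_access_token="my-token")
print(build_swanlab_config(args))
```

Like the original `update`, later keys from `args` overwrite earlier ones, so the tracker sees the same field set apart from the masked values.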

Tested by

no test coverage detected
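The upload step splits into two pieces that can be checked without a GPU or network access: the gating condition (both `modelscope_model_id` and `modelscope_access_token` must be set) and the `configuration.json` file written into `trainer.log_dir` before `HubApi.push_model` runs. A minimal sketch with hypothetical helper names; it produces the same file content with `json.dump` instead of a hand-written string, and a temporary directory stands in for `trainer.log_dir`.

```python
import json
import os
import tempfile
from argparse import Namespace

# Same content the function writes by hand before pushing to ModelScope.
MODELSCOPE_CONFIGURATION = {"framework": "Pytorch", "task": "text-to-image-synthesis"}


def should_upload(args: Namespace) -> bool:
    """Upload only when both a model id and an access token were supplied."""
    return args.modelscope_model_id is not None and args.modelscope_access_token is not None


def write_modelscope_configuration(log_dir: str) -> str:
    """Write configuration.json into log_dir and return its path (hypothetical helper)."""
    path = os.path.join(log_dir, "configuration.json")
    with open(path, "w", encoding="utf-8") as f:
        # Compact separators reproduce the original one-line string exactly.
        json.dump(MODELSCOPE_CONFIGURATION, f, separators=(",", ":"))
        f.write("\n")
    return path


with tempfile.TemporaryDirectory() as log_dir:  # stands in for trainer.log_dir
    path = write_modelscope_configuration(log_dir)
    with open(path, encoding="utf-8") as f:
        print(f.read().strip())
```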