(model, args)
| 265 | |
| 266 | |
def launch_training_task(model, args):
    """Train ``model`` with PyTorch Lightning and optionally publish it to ModelScope.

    Builds a ``TextImageDataset`` and ``DataLoader`` from ``args``, optionally
    attaches a SwanLab logger, and runs ``pl.Trainer.fit``. When both
    ``args.modelscope_model_id`` and ``args.modelscope_access_token`` are set,
    it also uploads the trainer's log directory to ModelScope.

    Args:
        model: Object passed to ``trainer.fit(model=...)`` (expected to be a
            LightningModule).
        args: Parsed option namespace. Attributes read: ``dataset_path``,
            ``steps_per_epoch``, ``batch_size``, ``height``, ``width``,
            ``center_crop``, ``random_flip``, ``dataloader_num_workers``,
            ``use_swanlab``, ``swanlab_mode`` (only when SwanLab is enabled),
            ``output_path``, ``max_epochs``, ``precision``,
            ``training_strategy``, ``accumulate_grad_batches``,
            ``modelscope_model_id`` and ``modelscope_access_token``. When
            SwanLab is enabled, ``vars(args)`` must work on it.
    """
    # Dataset and data loader.
    # The dataset receives steps_per_epoch * batch_size. Presumably this is its
    # length in samples, so one epoch is ``steps_per_epoch`` batches -- TODO confirm.
    dataset = TextImageDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch * args.batch_size,
        height=args.height,
        width=args.width,
        center_crop=args.center_crop,
        random_flip=args.random_flip,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Optional SwanLab experiment tracking. The import is local, so swanlab is
    # only needed when this option is enabled.
    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger

        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="diffsynth_studio",
            name="diffsynth_studio",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        logger = [swanlab_logger]
    else:
        # Passed through unchanged; Lightning applies its own default handling.
        logger = None

    # Train. save_top_k=-1 makes ModelCheckpoint keep every checkpoint.
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=logger,
    )
    trainer.fit(model=model, train_dataloaders=train_loader)

    # Upload models
    if args.modelscope_model_id is not None and args.modelscope_access_token is not None:
        # Resolve the directory on every rank. Lightning's ``log_dir`` property
        # may broadcast across processes, so it must not be called from rank 0 only.
        log_dir = trainer.log_dir
        if log_dir is None:
            # ``log_dir`` is Optional; fall back to the trainer's default_root_dir.
            log_dir = args.output_path
        # Only the global-zero process writes the config file and pushes. This
        # stops a multi-process run from racing on the file or uploading N times.
        if not trainer.is_global_zero:
            return
        print(f"Uploading models to modelscope. model_id: {args.modelscope_model_id} local_path: {log_dir}")
        with open(os.path.join(log_dir, "configuration.json"), "w", encoding="utf-8") as f:
            f.write('{"framework":"Pytorch","task":"text-to-image-synthesis"}\n')
        api = HubApi()
        api.login(args.modelscope_access_token)
        api.push_model(model_id=args.modelscope_model_id, model_dir=log_dir)
no test coverage detected