(args)
| 489 | |
| 490 | |
| 491 | def main(args): |
| 492 | args = copy.deepcopy(args) |
| 493 | utils.init_distributed_mode(args) |
| 494 | |
| 495 | device = torch.device(args.device) |
| 496 | |
| 497 | # Fix the seed for reproducibility |
| 498 | args.seed = args.seed + utils.get_rank() |
| 499 | torch.manual_seed(args.seed) |
| 500 | np.random.seed(args.seed) |
| 501 | # random.seed(args.seed) |
| 502 | |
| 503 | cudnn.benchmark = True |
| 504 | |
| 505 | if not args.show_user_warnings: |
| 506 | warnings.filterwarnings("ignore", category=UserWarning) |
| 507 | |
| 508 | if args.dtype in ['float16', 'fp16']: |
| 509 | dtype = torch.float16 |
| 510 | elif args.dtype in ['bfloat16', 'bf16']: |
| 511 | dtype = torch.bfloat16 |
| 512 | elif args.dtype in ['float32', 'fp32']: |
| 513 | dtype = torch.float32 |
| 514 | else: |
| 515 | raise ValueError(f"Invalid dtype: {args.dtype}") |
| 516 | |
| 517 | if args.data_name == 'auto': |
| 518 | args.data_name = Path(args.data_config_path).stem |
| 519 | if args.name == 'auto': |
| 520 | args.name = Path(args.gen_config_path).stem |
| 521 | if args.sr_name == 'auto': |
| 522 | args.sr_name = Path(args.sr_config_path).stem |
| 523 | |
| 524 | # Output directory |
| 525 | args.output_dir = os.path.join(args.output_dir, args.data_name, f'{args.name}--{args.sr_name}' if args.sr_name else args.name) |
| 526 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) |
| 527 | |
| 528 | # Prepare args |
| 529 | delim = '-' |
| 530 | |
| 531 | # Generation parameters |
| 532 | args.cond_domains = sorted(list(string_to_list(args.cond_domains, dtype=str, delim=delim))) |
| 533 | args.target_domains = string_to_list(args.target_domains, dtype=str, delim=delim) |
| 534 | args.all_domains = sorted(list(set(args.cond_domains) | set(args.target_domains))) |
| 535 | args.loaded_domains = sorted(list(set(args.cond_domains) | set(['rgb']))) |
| 536 | n_targets = len(args.target_domains) |
| 537 | args.tokens_per_target = repeat_if_necessary(string_to_list(args.tokens_per_target, dtype=int, delim=delim), n_targets) |
| 538 | args.autoregression_schemes = repeat_if_necessary(string_to_list(args.autoregression_schemes, dtype=str, delim=delim), n_targets) |
| 539 | args.decoding_steps = repeat_if_necessary(string_to_list(args.decoding_steps, dtype=int, delim=delim), n_targets) |
| 540 | args.token_decoding_schedules = repeat_if_necessary(string_to_list(args.token_decoding_schedules, dtype=str, delim=delim), n_targets) |
| 541 | args.temps = repeat_if_necessary(string_to_list(args.temps, dtype=float, delim=delim), n_targets) |
| 542 | args.temp_schedules = repeat_if_necessary(string_to_list(args.temp_schedules, dtype=str, delim=delim), n_targets) |
| 543 | args.cfg_scales = repeat_if_necessary(string_to_list(args.cfg_scales, dtype=float, delim=delim), n_targets) |
| 544 | args.cfg_schedules = repeat_if_necessary(string_to_list(args.cfg_schedules, dtype=str, delim=delim), n_targets) |
| 545 | |
| 546 | # Super-resolution parameters |
| 547 | if args.sr_cond_domains is None: |
| 548 | args.sr_cond_domains = args.cond_domains + args.target_domains |
no test coverage detected