MCPcopy
hub / github.com/apple/ml-4m / main

Function main

run_generation.py:491–629  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

489
490
491def main(args):
492 args = copy.deepcopy(args)
493 utils.init_distributed_mode(args)
494
495 device = torch.device(args.device)
496
497 # Fix the seed for reproducibility
498 args.seed = args.seed + utils.get_rank()
499 torch.manual_seed(args.seed)
500 np.random.seed(args.seed)
501 # random.seed(args.seed)
502
503 cudnn.benchmark = True
504
505 if not args.show_user_warnings:
506 warnings.filterwarnings("ignore", category=UserWarning)
507
508 if args.dtype in ['float16', 'fp16']:
509 dtype = torch.float16
510 elif args.dtype in ['bfloat16', 'bf16']:
511 dtype = torch.bfloat16
512 elif args.dtype in ['float32', 'fp32']:
513 dtype = torch.float32
514 else:
515 raise ValueError(f"Invalid dtype: {args.dtype}")
516
517 if args.data_name == 'auto':
518 args.data_name = Path(args.data_config_path).stem
519 if args.name == 'auto':
520 args.name = Path(args.gen_config_path).stem
521 if args.sr_name == 'auto':
522 args.sr_name = Path(args.sr_config_path).stem
523
524 # Output directory
525 args.output_dir = os.path.join(args.output_dir, args.data_name, f'{args.name}--{args.sr_name}' if args.sr_name else args.name)
526 Path(args.output_dir).mkdir(parents=True, exist_ok=True)
527
528 # Prepare args
529 delim = '-'
530
531 # Generation parameters
532 args.cond_domains = sorted(list(string_to_list(args.cond_domains, dtype=str, delim=delim)))
533 args.target_domains = string_to_list(args.target_domains, dtype=str, delim=delim)
534 args.all_domains = sorted(list(set(args.cond_domains) | set(args.target_domains)))
535 args.loaded_domains = sorted(list(set(args.cond_domains) | set(['rgb'])))
536 n_targets = len(args.target_domains)
537 args.tokens_per_target = repeat_if_necessary(string_to_list(args.tokens_per_target, dtype=int, delim=delim), n_targets)
538 args.autoregression_schemes = repeat_if_necessary(string_to_list(args.autoregression_schemes, dtype=str, delim=delim), n_targets)
539 args.decoding_steps = repeat_if_necessary(string_to_list(args.decoding_steps, dtype=int, delim=delim), n_targets)
540 args.token_decoding_schedules = repeat_if_necessary(string_to_list(args.token_decoding_schedules, dtype=str, delim=delim), n_targets)
541 args.temps = repeat_if_necessary(string_to_list(args.temps, dtype=float, delim=delim), n_targets)
542 args.temp_schedules = repeat_if_necessary(string_to_list(args.temp_schedules, dtype=str, delim=delim), n_targets)
543 args.cfg_scales = repeat_if_necessary(string_to_list(args.cfg_scales, dtype=float, delim=delim), n_targets)
544 args.cfg_schedules = repeat_if_necessary(string_to_list(args.cfg_schedules, dtype=str, delim=delim), n_targets)
545
546 # Super-resolution parameters
547 if args.sr_cond_domains is None:
548 args.sr_cond_domains = args.cond_domains + args.target_domains

Callers 1

run_generation.py — File · 0.70

Calls 11

set_step — Method · 0.95
update — Method · 0.95
GenerationSampler — Class · 0.90
string_to_list — Function · 0.85
repeat_if_necessary — Function · 0.85
load_tokenizers — Function · 0.85
get_dataset — Function · 0.85
print — Function · 0.85
generate — Function · 0.85
device — Method · 0.80
load_model — Function · 0.70

Tested by

no test coverage detected