MCPcopy
hub / github.com/Vchitect/Latte / main

Function main

train_with_img.py:44–297  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

42#################################################################################
43
44def main(args):
45
46 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
47
48 # Setup DDP:
49 setup_distributed()
50
51 rank = int(os.environ["RANK"])
52 local_rank = int(os.environ["LOCAL_RANK"])
53 device = torch.device("cuda", local_rank)
54
55 seed = args.global_seed + rank
56 torch.manual_seed(seed)
57 torch.cuda.set_device(device)
58 print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.")
59
60 # Setup an experiment folder:
61 if rank == 0:
62 os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
63 experiment_index = len(glob(f"{args.results_dir}/*"))
64 model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders)
65 num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval)
66 experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder
67 experiment_dir = get_experiment_dir(experiment_dir, args)
68 checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
69 os.makedirs(checkpoint_dir, exist_ok=True)
70 logger = create_logger(experiment_dir)
71 tb_writer = create_tensorboard(experiment_dir)
72 OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml'))
73 logger.info(f"Experiment directory created at {experiment_dir}")
74 else:
75 logger = create_logger(None)
76 tb_writer = None
77
78 # Create model:
79 assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
80 sample_size = args.image_size // 8
81 args.latent_size = sample_size
82 model = get_models(args)
83
84 diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
85 # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
86 vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-mse").to(device)
87
88 # # use pretrained model?
89 if args.pretrained:
90 checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
91 if "ema" in checkpoint: # supports checkpoints from train.py
92 logger.info('Using ema ckpt!')
93 checkpoint = checkpoint["ema"]
94
95 model_dict = model.state_dict()
96 # 1. filter out unnecessary keys
97 pretrained_dict = {}
98 for k, v in checkpoint.items():
99 if k in model_dict:
100 pretrained_dict[k] = v
101 else:

Callers 1

train_with_img.py (File · 0.70)

Calls 15

setup_distributed (Function · 0.90)
get_experiment_dir (Function · 0.90)
create_logger (Function · 0.90)
create_tensorboard (Function · 0.90)
get_models (Function · 0.90)
create_diffusion (Function · 0.90)
requires_grad (Function · 0.90)
get_dataset (Function · 0.90)
update_ema (Function · 0.90)
clip_grad_norm_ (Function · 0.90)
write_tensorboard (Function · 0.90)
cleanup (Function · 0.90)

Tested by

no test coverage detected