MCPcopy
hub / github.com/Vchitect/Latte / main

Function main

train.py:47–277  ·  view source on GitHub ↗
(args)

Sourced from the content-addressed store (hash-verified)

45#################################################################################
46
47def main(args):
48
49 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
50
51 # Setup DDP:
52 setup_distributed()
53 # dist.init_process_group("nccl")
54 # assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
55 # rank = dist.get_rank()
56 # device = rank % torch.cuda.device_count()
57 # local_rank = rank
58
59 rank = int(os.environ["RANK"])
60 local_rank = int(os.environ["LOCAL_RANK"])
61 device = torch.device("cuda", local_rank)
62
63 seed = args.global_seed + rank
64 torch.manual_seed(seed)
65 torch.cuda.set_device(device)
66 print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.")
67
68 # Setup an experiment folder:
69 if rank == 0:
70 os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
71 experiment_index = len(glob(f"{args.results_dir}/*"))
72 model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders)
73 num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval)
74 experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder
75 experiment_dir = get_experiment_dir(experiment_dir, args)
76 checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
77 os.makedirs(checkpoint_dir, exist_ok=True)
78 logger = create_logger(experiment_dir)
79 tb_writer = create_tensorboard(experiment_dir)
80 OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml'))
81 logger.info(f"Experiment directory created at {experiment_dir}")
82 else:
83 logger = create_logger(None)
84 tb_writer = None
85
86 # Create model:
87 assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
88 sample_size = args.image_size // 8
89 args.latent_size = sample_size
90 model = get_models(args)
91
92 diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
93 # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
94 vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
95
96 # # use pretrained model?
97 if args.pretrained:
98 checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
99 if "ema" in checkpoint: # supports checkpoints from train.py
100 logger.info('Using ema ckpt!')
101 checkpoint = checkpoint["ema"]
102
103 model_dict = model.state_dict()
104 # 1. filter out unnecessary keys

Callers 1

train.py — File · 0.70

Calls 15

setup_distributed — Function · 0.90
get_experiment_dir — Function · 0.90
create_logger — Function · 0.90
create_tensorboard — Function · 0.90
get_models — Function · 0.90
create_diffusion — Function · 0.90
requires_grad — Function · 0.90
get_dataset — Function · 0.90
update_ema — Function · 0.90
clip_grad_norm_ — Function · 0.90
write_tensorboard — Function · 0.90
cleanup — Function · 0.90

Tested by

no test coverage detected