(args)
| 45 | ################################################################################# |
| 46 | |
| 47 | def main(args): |
| 48 | |
| 49 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 50 | |
| 51 | # Setup DDP: |
| 52 | setup_distributed() |
| 53 | # dist.init_process_group("nccl") |
| 54 | # assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." |
| 55 | # rank = dist.get_rank() |
| 56 | # device = rank % torch.cuda.device_count() |
| 57 | # local_rank = rank |
| 58 | |
| 59 | rank = int(os.environ["RANK"]) |
| 60 | local_rank = int(os.environ["LOCAL_RANK"]) |
| 61 | device = torch.device("cuda", local_rank) |
| 62 | |
| 63 | seed = args.global_seed + rank |
| 64 | torch.manual_seed(seed) |
| 65 | torch.cuda.set_device(device) |
| 66 | print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 67 | |
| 68 | # Setup an experiment folder: |
| 69 | if rank == 0: |
| 70 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) |
| 71 | experiment_index = len(glob(f"{args.results_dir}/*")) |
| 72 | model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders) |
| 73 | num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval) |
| 74 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder |
| 75 | experiment_dir = get_experiment_dir(experiment_dir, args) |
| 76 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints |
| 77 | os.makedirs(checkpoint_dir, exist_ok=True) |
| 78 | logger = create_logger(experiment_dir) |
| 79 | tb_writer = create_tensorboard(experiment_dir) |
| 80 | OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml')) |
| 81 | logger.info(f"Experiment directory created at {experiment_dir}") |
| 82 | else: |
| 83 | logger = create_logger(None) |
| 84 | tb_writer = None |
| 85 | |
| 86 | # Create model: |
| 87 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." |
| 88 | sample_size = args.image_size // 8 |
| 89 | args.latent_size = sample_size |
| 90 | model = get_models(args) |
| 91 | |
| 92 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule |
| 93 | # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device) |
| 94 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device) |
| 95 | |
| 96 | # # use pretrained model? |
| 97 | if args.pretrained: |
| 98 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage) |
| 99 | if "ema" in checkpoint: # supports checkpoints from train.py |
| 100 | logger.info('Using ema ckpt!') |
| 101 | checkpoint = checkpoint["ema"] |
| 102 | |
| 103 | model_dict = model.state_dict() |
| 104 | # 1. filter out unnecessary keys |
no test coverage detected