(args)
| 42 | ################################################################################# |
| 43 | |
| 44 | def main(args): |
| 45 | |
| 46 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 47 | |
| 48 | # Setup DDP: |
| 49 | setup_distributed() |
| 50 | |
| 51 | rank = int(os.environ["RANK"]) |
| 52 | local_rank = int(os.environ["LOCAL_RANK"]) |
| 53 | device = torch.device("cuda", local_rank) |
| 54 | |
| 55 | seed = args.global_seed + rank |
| 56 | torch.manual_seed(seed) |
| 57 | torch.cuda.set_device(device) |
| 58 | print(f"Starting rank={rank}, local rank={local_rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 59 | |
| 60 | # Setup an experiment folder: |
| 61 | if rank == 0: |
| 62 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) |
| 63 | experiment_index = len(glob(f"{args.results_dir}/*")) |
| 64 | model_string_name = args.model.replace("/", "-") # e.g., Latte-XL/2 --> Latte-XL-2 (for naming folders) |
| 65 | num_frame_string = 'F' + str(args.num_frames) + 'S' + str(args.frame_interval) |
| 66 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}" # Create an experiment folder |
| 67 | experiment_dir = get_experiment_dir(experiment_dir, args) |
| 68 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints |
| 69 | os.makedirs(checkpoint_dir, exist_ok=True) |
| 70 | logger = create_logger(experiment_dir) |
| 71 | tb_writer = create_tensorboard(experiment_dir) |
| 72 | OmegaConf.save(args, os.path.join(experiment_dir, 'config.yaml')) |
| 73 | logger.info(f"Experiment directory created at {experiment_dir}") |
| 74 | else: |
| 75 | logger = create_logger(None) |
| 76 | tb_writer = None |
| 77 | |
| 78 | # Create model: |
| 79 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." |
| 80 | sample_size = args.image_size // 8 |
| 81 | args.latent_size = sample_size |
| 82 | model = get_models(args) |
| 83 | |
| 84 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule |
| 85 | # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device) |
| 86 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-mse").to(device) |
| 87 | |
| 88 | # # use pretrained model? |
| 89 | if args.pretrained: |
| 90 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage) |
| 91 | if "ema" in checkpoint: # supports checkpoints from train.py |
| 92 | logger.info('Using ema ckpt!') |
| 93 | checkpoint = checkpoint["ema"] |
| 94 | |
| 95 | model_dict = model.state_dict() |
| 96 | # 1. filter out unnecessary keys |
| 97 | pretrained_dict = {} |
| 98 | for k, v in checkpoint.items(): |
| 99 | if k in model_dict: |
| 100 | pretrained_dict[k] = v |
| 101 | else: |
no test coverage detected