MCPcopy
hub / github.com/FoundationVision/LlamaGen / main

Function main

tokenizer/tokenizer_image/vq_train.py:36–269  ·  view source on GitHub ↗

Trains a new model.

(args)

Source from the content-addressed store, hash-verified

34#################################################################################
35
36def main(args):
37 """
38 Trains a new model.
39 """
40 assert torch.cuda.is_available(), "Training currently requires at least one GPU."
41
42 # Setup DDP:
43 init_distributed_mode(args)
44 assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
45 rank = dist.get_rank()
46 device = rank % torch.cuda.device_count()
47 seed = args.global_seed * dist.get_world_size() + rank
48 torch.manual_seed(seed)
49 torch.cuda.set_device(device)
50
51 # Setup an experiment folder:
52 if rank == 0:
53 os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
54 experiment_index = len(glob(f"{args.results_dir}/*"))
55 model_string_name = args.vq_model.replace("/", "-")
56 experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
57 checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
58 os.makedirs(checkpoint_dir, exist_ok=True)
59 logger = create_logger(experiment_dir)
60 logger.info(f"Experiment directory created at {experiment_dir}")
61
62 time_record = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
63 cloud_results_dir = f"{args.cloud_save_path}/{time_record}"
64 cloud_checkpoint_dir = f"{cloud_results_dir}/{experiment_index:03d}-{model_string_name}/checkpoints"
65 os.makedirs(cloud_checkpoint_dir, exist_ok=True)
66 logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}")
67
68 else:
69 logger = create_logger(None)
70
71 # training args
72 logger.info(f"{args}")
73
74 # training env
75 logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
76
77 # create and load model
78 vq_model = VQ_models[args.vq_model](
79 codebook_size=args.codebook_size,
80 codebook_embed_dim=args.codebook_embed_dim,
81 commit_loss_beta=args.commit_loss_beta,
82 entropy_loss_ratio=args.entropy_loss_ratio,
83 dropout_p=args.dropout_p,
84 )
85 logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}")
86 if args.ema:
87 ema = deepcopy(vq_model).to(device) # Create an EMA of the model for use after training
88 requires_grad(ema, False)
89 logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")
90 vq_model = vq_model.to(device)
91
92 vq_loss = VQLoss(
93 disc_start=args.disc_start,

Callers 1

vq_train.py · File · 0.70

Calls 10

init_distributed_mode · Function · 0.90
create_logger · Function · 0.90
requires_grad · Function · 0.90
VQLoss · Class · 0.90
random_crop_arr · Function · 0.90
build_dataset · Function · 0.90
update_ema · Function · 0.90
load · Method · 0.80
step · Method · 0.80
update · Method · 0.80

Tested by

no test coverage detected