Repository: github.com/showlab/Show-o / main

Function main()

training/train_w_clip_vit.py, lines 73–733

def main():
    #########################
    # SETUP Accelerator #
    #########################
    config = get_config()

    # Enable TF32 on Ampere GPUs
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs")
    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with="wandb",
        project_dir=config.experiment.logging_dir,
        split_batches=True,
    )

    total_batch_size_per_gpu = (config.training.batch_size_t2i
                                + config.training.batch_size_lm
                                + config.training.batch_size_mmu)
    total_batch_size = (
        (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu)
        * accelerator.num_processes * config.training.gradient_accumulation_steps
    )

    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = (
            total_batch_size_per_gpu
        )

    #####################################
    # SETUP LOGGING, SEED and CONFIG #
    #####################################
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        set_verbosity_info()
    else:
        set_verbosity_error()

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        resume_wandb_run = config.wandb.resume
        run_id = config.wandb.get("run_id", None)
        if run_id is None:
            resume_wandb_run = False
            run_id = wandb.util.generate_id()
            config.wandb.run_id = run_id
        # ... (excerpt truncated here; main() continues through line 733)
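The batch-size block above sums the three per-task batch sizes (text-to-image, language modelling, multimodal understanding) into a per-GPU figure, scales it by process count and gradient-accumulation steps for the global figure, and, under DeepSpeed, writes the per-GPU figure into train_micro_batch_size_per_gpu. A minimal sketch of that arithmetic, assuming a hypothetical TrainingConfig dataclass in place of the OmegaConf config.training node and made-up batch sizes. The helper deepspeed_batch_fields is also hypothetical; it only illustrates DeepSpeed's documented constraint train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × world size, which the excerpt's total_batch_size also satisfies.

```python
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for config.training; field names mirror the excerpt.
    batch_size_t2i: int
    batch_size_lm: int
    batch_size_mmu: int
    gradient_accumulation_steps: int


def per_gpu_batch(cfg: TrainingConfig) -> int:
    # Samples one process handles per micro-step across the three tasks.
    return cfg.batch_size_t2i + cfg.batch_size_lm + cfg.batch_size_mmu


def global_batch(cfg: TrainingConfig, num_processes: int) -> int:
    # Samples contributing to one optimizer step across all processes.
    return per_gpu_batch(cfg) * num_processes * cfg.gradient_accumulation_steps


def deepspeed_batch_fields(cfg: TrainingConfig, num_processes: int) -> dict:
    # Hypothetical helper: DeepSpeed requires
    # train_batch_size == micro_batch * grad_accum * world_size.
    micro = per_gpu_batch(cfg)
    return {
        "train_micro_batch_size_per_gpu": micro,
        "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
        "train_batch_size": micro * cfg.gradient_accumulation_steps * num_processes,
    }


if __name__ == "__main__":
    # Illustrative values only, not Show-o's actual configuration.
    cfg = TrainingConfig(batch_size_t2i=7, batch_size_lm=2, batch_size_mmu=3,
                         gradient_accumulation_steps=4)
    print(per_gpu_batch(cfg))    # 12
    print(global_batch(cfg, 8))  # 384
    print(deepspeed_batch_fields(cfg, 8))
```

With these illustrative values each process handles 12 samples per micro-step, and one optimizer step covers 384 samples on 8 processes. One caveat: with split_batches=True, Accelerate divides each batch yielded by a prepared dataloader across processes instead of giving every process a full batch, so whether the per-GPU figure matches what each process actually receives depends on how the rest of main() builds its dataloaders, which this excerpt does not show.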

Callers: 1

Calls: 15 (12 listed)

update (method)
reset (method)
get_config (function)
set_verbosity_info (function)
set_verbosity_error (function)
flatten_omega_conf (function)
set_seed (function)
UniversalPrompting (class)
Showo (class)
CLIPVisionTower (class)
get_mask_chedule (function)
get_scheduler (function)

Tested by: no test coverage detected
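With no test coverage reported, the wandb branch at the end of the excerpt is one piece that is easy to isolate: when config.wandb.run_id is missing, resume is forced off and a fresh id is generated and stored back on the config; otherwise the configured resume flag and id are kept. A minimal sketch as a pure function, assuming a hypothetical resolve_wandb_run helper and a uuid-based stand-in for wandb.util.generate_id (not wandb's own implementation), injected so the branch can be checked without wandb installed:

```python
import uuid
from typing import Callable, Optional, Tuple


def _stand_in_generate_id() -> str:
    # Stand-in for wandb.util.generate_id(): 8 hex characters from a random UUID.
    return uuid.uuid4().hex[:8]


def resolve_wandb_run(
    resume: bool,
    run_id: Optional[str],
    generate_id: Callable[[], str] = _stand_in_generate_id,
) -> Tuple[bool, str]:
    """Hypothetical helper returning (resume_wandb_run, run_id) per the excerpt's branch."""
    if run_id is None:
        # Nothing to resume from: disable resume and mint a new id.
        return False, generate_id()
    # An explicit id keeps the configured resume flag unchanged.
    return resume, run_id


if __name__ == "__main__":
    print(resolve_wandb_run(True, "abc123"))                 # (True, 'abc123')
    print(resolve_wandb_run(True, None, lambda: "fresh01"))  # (False, 'fresh01')
```

The caller would then assign the returned id to config.wandb.run_id, as the excerpt does inline.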