MCPcopy
hub / github.com/showlab/Show-o / main

Function main

show-o2/train_mixed_modality_simple.py:56–558  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

54
55
56def main():
57 #########################
58 # SETUP Accelerator #
59 #########################
60 config = get_config()
61
62 # Enable TF32 on Ampere GPUs
63 if config.training.enable_tf32:
64 torch.backends.cuda.matmul.allow_tf32 = True
65 torch.backends.cudnn.benchmark = True
66 torch.backends.cudnn.deterministic = False
67
68 config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs")
69 accelerator = Accelerator(
70 gradient_accumulation_steps=config.training.gradient_accumulation_steps,
71 mixed_precision=config.training.mixed_precision,
72 log_with="wandb",
73 project_dir=config.experiment.logging_dir,
74 split_batches=True,
75 )
76
77 bs_mixed_modal = config.training.batch_size_mixed_modal
78
79 if "concat" in config.dataset.mixed_loader_mode:
80 raise NotImplementedError
81 else:
82 total_batch_size_per_gpu = bs_mixed_modal * config.dataset.accumulation
83 total_batch_size_without_accum = total_batch_size_per_gpu * accelerator.num_processes
84 total_batch_size = total_batch_size_without_accum * config.training.gradient_accumulation_steps
85
86 if accelerator.distributed_type == DistributedType.DEEPSPEED:
87 accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = (
88 total_batch_size_per_gpu
89 )
90
91 #####################################
92 # SETUP LOGGING, SEED and CONFIG #
93 #####################################
94 # Make one log on every process with the configuration for debugging.
95 logging.basicConfig(
96 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
97 datefmt="%m/%d/%Y %H:%M:%S",
98 level=logging.INFO,
99 )
100 logger.info(accelerator.state, main_process_only=False)
101 if accelerator.is_local_main_process:
102 set_verbosity_info()
103 else:
104 set_verbosity_error()
105
106 # We need to initialize the trackers we use, and also store our configuration.
107 # The trackers initialize automatically on the main process.
108 if accelerator.is_main_process:
109 resume_wandb_run = config.wandb.resume
110 run_id = config.wandb.get("run_id", None)
111 if run_id is None:
112 resume_wandb_run = False
113 run_id = wandb.util.generate_id()

Callers 1

Calls 15

updateMethod · 0.95
resetMethod · 0.95
get_configFunction · 0.90
set_verbosity_infoFunction · 0.90
set_verbosity_errorFunction · 0.90
flatten_omega_confFunction · 0.90
set_seedFunction · 0.90
get_weight_typeFunction · 0.90
WanVAEClass · 0.90
get_text_tokenizerFunction · 0.90
Showo2Qwen2_5Class · 0.90
_freeze_paramsFunction · 0.90

Tested by

no test coverage detected