github.com/modelscope/FunASR (branch: main)

Function main

funasr/bin/train.py, lines 61–285

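Training entry point. Every option reaches `main` through `**kwargs`: the random seed, the cuDNN and TF32 switches, distributed settings such as `use_fsdp` and `backend`, the target `device`, the model-construction options forwarded to `AutoModel`, and `freeze_param`.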

(**kwargs)
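Because `main` receives nothing but `**kwargs`, its setup reduces to a series of `kwargs.get(key, default)` lookups plus two environment variables set by the launcher (`LOCAL_RANK` and `WORLD_SIZE`). The sketch below restates that logic in plain Python so the defaults and the config-writing guard can be checked without torch. The helper names `resolve_setup` and `writes_config` are hypothetical, and the global rank is passed in as an argument instead of calling `dist.get_rank()`.

```python
from typing import Any, Mapping


def resolve_setup(env: Mapping[str, str], **kwargs: Any) -> dict:
    """Hypothetical helper mirroring the kwargs/env lookups at the top of main()."""
    local_rank = int(env.get("LOCAL_RANK", 0))
    # DDP is inferred from the launcher's WORLD_SIZE; FSDP is an explicit opt-in flag.
    use_ddp = "WORLD_SIZE" in env and int(env["WORLD_SIZE"]) > 1
    use_fsdp = bool(kwargs.get("use_fsdp", False))
    return {
        "seed": kwargs.get("seed", 0),
        "cudnn_deterministic": kwargs.get("cudnn_deterministic", True),
        "enable_tf32": kwargs.get("enable_tf32", True),
        "backend": kwargs.get("backend", "nccl"),  # only used when distributed
        "device": kwargs.get("device", "cuda"),
        "local_rank": local_rank,
        "use_ddp": use_ddp,
        "use_fsdp": use_fsdp,
    }


def writes_config(setup: dict, global_rank: int) -> bool:
    """Hypothetical helper: same truth table as the prepare_model_dir guard."""
    distributed = setup["use_ddp"] or setup["use_fsdp"]
    # Distributed runs write on global rank 0; single-process runs on local rank 0.
    return (distributed and global_rank == 0) or (not distributed and setup["local_rank"] == 0)
```

The guard in the listing has no inner parentheses and relies on `and` binding tighter than `or`; writing the grouping out, as in `writes_config`, keeps the same behavior while making it obvious that exactly one process writes `config.yaml`.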


```python
# Module-level imports are not shown in this excerpt: os, logging, torch,
# torch.distributed as dist, and FunASR helpers such as AutoModel, tables,
# set_all_random_seed and prepare_model_dir.
def main(**kwargs):
    """Main.

    Args:
        **kwargs: Additional keyword arguments.
    """
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # enable TF32 for CUDA matmul (on by default)
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    # use_ddp = False if use_fsdp else use_fsdp
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    kwargs["device"] = "cpu"  # build on CPU; the requested device is written back into kwargs below
    model = AutoModel(**kwargs)

    # save config.yaml (global rank 0 when distributed, otherwise local rank 0)
    if ((use_ddp or use_fsdp) and dist.get_rank() == 0) or (
        not (use_ddp or use_fsdp) and local_rank == 0
    ):
        prepare_model_dir(**kwargs)

    # parse kwargs
    kwargs = model.kwargs
    kwargs["device"] = device
    tokenizer = kwargs["tokenizer"]
    frontend = kwargs["frontend"]
    model = model.model
    del kwargs["model"]

    # freeze_param: comma-separated string or list of parameter names
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if "," in freeze_param:
            freeze_param = freeze_param.split(",")
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                ...  # excerpt ends here; the loop body and the rest of main() (through line 285) are not shown
```
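The excerpt stops inside the `freeze_param` loop, so the matching rule is not visible. The sketch below reproduces the normalization that is shown (a comma-separated string becomes a list, a single name becomes a one-element tuple) and then applies an assumed rule: a parameter matches when its dotted name equals an entry or starts with `"<entry>."`. Both helper names are hypothetical, and the matching rule is an assumption, not FunASR's confirmed behavior.

```python
from typing import Iterable, List, Sequence, Tuple, Union


def normalize_freeze_param(freeze_param: Union[str, Sequence[str], None]) -> Tuple[str, ...]:
    """Hypothetical helper; equivalent to the listing for string and list inputs."""
    if freeze_param is None:
        return ()
    if isinstance(freeze_param, str) and "," in freeze_param:
        freeze_param = freeze_param.split(",")
    if not isinstance(freeze_param, (list, tuple)):
        freeze_param = (freeze_param,)
    return tuple(freeze_param)


def params_to_freeze(names: Iterable[str], prefixes: Sequence[str]) -> List[str]:
    """ASSUMED rule: exact name match, or prefix match on a module boundary ('.')."""
    return [n for n in names if any(n == t or n.startswith(t + ".") for t in prefixes)]
```

Matching on `t + "."` rather than a bare `startswith(t)` keeps `encoder` from also catching an unrelated `encoder2`. In `main`, each matched parameter would then presumably be excluded from training, for example by setting `requires_grad = False`.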

Callers (1)

main_hydra (function)

Calls (15; 11 listed)

resume_checkpoint (method)
train_epoch (method)
validate_epoch (method)
save_checkpoint (method)
close (method)
set_all_random_seed (function)
AutoModel (class)
prepare_model_dir (function)
model_summary (function)
Trainer (class)
average_checkpoints (function)
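`average_checkpoints` appears among the calls, which suggests `main` averages several saved checkpoints after training; which checkpoints it picks, and how many, is not visible in this excerpt. The sketch below shows only the generic technique of uniform parameter averaging, with plain Python lists standing in for tensors. `average_state_dicts` is a hypothetical name and is not FunASR's implementation.

```python
from typing import Dict, List, Sequence

# Stand-in for a model state dict: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def average_state_dicts(checkpoints: Sequence[StateDict]) -> StateDict:
    """Element-wise mean of each parameter across checkpoints with identical layout."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    keys = checkpoints[0].keys()
    for ckpt in checkpoints[1:]:
        if ckpt.keys() != keys:
            raise ValueError("checkpoints have different parameter names")
    n = len(checkpoints)
    averaged: StateDict = {}
    for key in keys:
        # Refuse to average parameters whose sizes differ (zip would silently truncate).
        if len({len(ckpt[key]) for ckpt in checkpoints}) != 1:
            raise ValueError(f"shape mismatch for {key!r}")
        columns = zip(*(ckpt[key] for ckpt in checkpoints))
        averaged[key] = [sum(values) / n for values in columns]
    return averaged
```

Every checkpoint gets equal weight here; a weighted or best-N variant would change only how `checkpoints` is selected and how each term is scaled.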

Tested by

no test coverage detected
