MCP · Copy · Index your code
hub / github.com/PaddlePaddle/PaddleDetection / run

Function run

tools/train.py:122–161  ·  view source on GitHub ↗
(FLAGS, cfg)

Source from the content-addressed store, hash-verified

120
121
def run(FLAGS, cfg):
    """Initialise the runtime, build a trainer from ``cfg``, load weights and train.

    Args:
        FLAGS: parsed command-line options. Reads ``FLAGS.enable_ce``,
            ``FLAGS.resume`` and ``FLAGS.eval``.
        cfg: mapping-like training config supporting both ``cfg.get(key)``
            and attribute access. Reads ``cfg.fleet`` and, optionally,
            ``find_unused_parameters``, ``ssod_method``, ``use_cot``,
            ``pretrain_teacher_weights``, ``pretrain_student_weights`` and
            ``pretrain_weights``.

    Raises:
        ValueError: if ``cfg.ssod_method`` is set but is not one of
            ``'DenseTeacher'``, ``'ARSL'`` or ``'Semi_RTDETR'``.
    """
    # init fleet environment, or the plain parallel environment
    # (the original author notes the latter matters only when nranks > 1)
    if cfg.fleet:
        init_fleet_env(cfg.get('find_unused_parameters', False))
    else:
        init_parallel_env()

    # Fix the random seed when FLAGS.enable_ce is set
    # (presumably for reproducible continuous-evaluation runs).
    if FLAGS.enable_ce:
        set_random_seed(0)

    # Build the trainer. A semi-supervised (SSOD) method takes precedence
    # over `use_cot`; otherwise the plain Trainer is used.
    ssod_method = cfg.get('ssod_method', None)
    if ssod_method is not None:
        if ssod_method == 'DenseTeacher':
            trainer = Trainer_DenseTeacher(cfg, mode='train')
        elif ssod_method == 'ARSL':
            trainer = Trainer_ARSL(cfg, mode='train')
        elif ssod_method == 'Semi_RTDETR':
            trainer = Trainer_Semi_RTDETR(cfg, mode='train')
        else:
            # Name the offending value and the accepted ones so a config
            # typo is easy to spot.
            raise ValueError(
                "Semi-supervised object detection does not support "
                "ssod_method '{}'; supported methods are: "
                "DenseTeacher, ARSL, Semi_RTDETR.".format(ssod_method))
    elif cfg.get('use_cot', False):
        trainer = TrainerCot(cfg, mode='train')
    else:
        trainer = Trainer(cfg, mode='train')

    # Load initial weights. Precedence:
    #   1. FLAGS.resume (checkpoint to resume from)
    #   2. both teacher and student pretrain weights (semi-supervised setup)
    #   3. pretrain_weights
    # Empty/falsy config values are treated as "not set".
    teacher_weights = cfg.get('pretrain_teacher_weights')
    student_weights = cfg.get('pretrain_student_weights')
    pretrain_weights = cfg.get('pretrain_weights')
    if FLAGS.resume is not None:
        trainer.resume_weights(FLAGS.resume)
    elif teacher_weights and student_weights:
        trainer.load_semi_weights(teacher_weights, student_weights)
    elif pretrain_weights:
        trainer.load_weights(pretrain_weights)

    # Training; FLAGS.eval is passed straight through to the trainer
    # (presumably toggles evaluation during training — confirm in Trainer.train).
    trainer.train(FLAGS.eval)
162
163
164def main():

Callers 1

main — Function · 0.70

Calls 13

resume_weights — Method · 0.95
load_weights — Method · 0.95
train — Method · 0.95
init_fleet_env — Function · 0.90
init_parallel_env — Function · 0.90
set_random_seed — Function · 0.90
Trainer_ARSL — Class · 0.90
Trainer_Semi_RTDETR — Class · 0.90
TrainerCot — Class · 0.90
Trainer — Class · 0.90
load_semi_weights — Method · 0.80

Tested by

no test coverage detected