(FLAGS, cfg)
| 120 | |
| 121 | |
def run(FLAGS, cfg):
    """Set up the distributed environment, build a trainer, load weights and train.

    Trainer selection: a non-None ``ssod_method`` picks a semi-supervised
    trainer; otherwise ``use_cot`` picks ``TrainerCot``; otherwise the plain
    ``Trainer`` is used. All trainers are built with ``mode='train'``.

    Weight-loading precedence (first match wins):
      1. ``FLAGS.resume`` (if not None) -> ``trainer.resume_weights``
      2. both ``pretrain_teacher_weights`` and ``pretrain_student_weights``
         truthy in ``cfg`` -> ``trainer.load_semi_weights``
      3. ``pretrain_weights`` truthy in ``cfg`` -> ``trainer.load_weights``
    If none match, training starts from the trainer's initial weights.

    Args:
        FLAGS: parsed command-line flags; this function reads ``enable_ce``,
            ``resume`` and ``eval``.
        cfg: dict-like experiment config supporting ``.get`` and attribute
            access; this function reads ``fleet``, ``find_unused_parameters``,
            ``ssod_method``, ``use_cot``, ``pretrain_teacher_weights``,
            ``pretrain_student_weights`` and ``pretrain_weights``.

    Raises:
        ValueError: if ``cfg.ssod_method`` is set to a method other than
            'DenseTeacher', 'ARSL' or 'Semi_RTDETR'.
    """
    # init fleet environment
    if cfg.fleet:
        init_fleet_env(cfg.get('find_unused_parameters', False))
    else:
        # init parallel environment if nranks > 1
        init_parallel_env()

    # Fix the random seed when enable_ce is set (presumably continuous-
    # evaluation runs that need reproducible results).
    if FLAGS.enable_ce:
        set_random_seed(0)

    # build trainer: a semi-supervised method takes precedence over use_cot
    ssod_method = cfg.get('ssod_method', None)
    if ssod_method is not None:
        if ssod_method == 'DenseTeacher':
            trainer = Trainer_DenseTeacher(cfg, mode='train')
        elif ssod_method == 'ARSL':
            trainer = Trainer_ARSL(cfg, mode='train')
        elif ssod_method == 'Semi_RTDETR':
            trainer = Trainer_Semi_RTDETR(cfg, mode='train')
        else:
            raise ValueError(
                "Semi-Supervised Object Detection does not support method "
                "'{}'; expected one of 'DenseTeacher', 'ARSL', "
                "'Semi_RTDETR'.".format(ssod_method))
    elif cfg.get('use_cot', False):
        trainer = TrainerCot(cfg, mode='train')
    else:
        trainer = Trainer(cfg, mode='train')

    # load weights: resume > teacher+student pair > single pretrain weights.
    # cfg.get(...) returns None for missing keys, so a missing or empty entry
    # is skipped exactly like the former "'key' in cfg and cfg.key" checks.
    teacher_weights = cfg.get('pretrain_teacher_weights')
    student_weights = cfg.get('pretrain_student_weights')
    pretrain_weights = cfg.get('pretrain_weights')
    if FLAGS.resume is not None:
        trainer.resume_weights(FLAGS.resume)
    elif teacher_weights and student_weights:
        trainer.load_semi_weights(teacher_weights, student_weights)
    elif pretrain_weights:
        trainer.load_weights(pretrain_weights)

    # training
    trainer.train(FLAGS.eval)
| 162 | |
| 163 | |
| 164 | def main(): |
no test coverage detected