(FLAGS, cfg)
| 122 | |
| 123 | |
def run(FLAGS, cfg):
    """Evaluate a detector as directed by the command-line flags and config.

    Two modes are supported:

    * ``FLAGS.json_eval`` is truthy: the json result files found in
      ``FLAGS.output_eval`` are scored with ``json_eval_results`` against the
      dataset built from the ``'EvalDataset'`` component. No trainer is built
      and the function returns right after scoring.
    * otherwise: the parallel environment is initialised, an evaluation
      trainer is built (``Trainer_ARSL`` when ``cfg`` sets ``ssod_method`` to
      ``'ARSL'``, ``Trainer`` in every other case), ``cfg.weights`` is loaded
      into it, and evaluation runs either sliced (``FLAGS.slice_infer``) or
      on whole images.

    Args:
        FLAGS: parsed command-line options. The attributes read are
            ``json_eval``, ``output_eval``, ``slice_infer``, ``slice_size``,
            ``overlap_ratio``, ``combine_method``, ``match_threshold`` and
            ``match_metric``.
        cfg: loaded configuration object. It must support ``cfg.get(...)``
            and expose the ``metric`` and ``weights`` attributes.

    Returns:
        None.
    """
    if FLAGS.json_eval:
        logger.info(
            "In json_eval mode, PaddleDetection will evaluate json files in "
            "output_eval directly. And proposal.json, bbox.json and mask.json "
            "will be detected by default.")
        # Arguments are gathered in the same order the call site evaluates them.
        metric = cfg.metric
        json_dir = FLAGS.output_eval
        eval_dataset = create('EvalDataset')()
        json_eval_results(metric, json_directory=json_dir, dataset=eval_dataset)
        return

    # Multi-rank setup; the upstream note says this matters when nranks > 1.
    init_parallel_env()

    is_arsl = cfg.get('ssod_method', None) == 'ARSL'
    # The conditional expression only evaluates the name of the chosen class.
    trainer_cls = Trainer_ARSL if is_arsl else Trainer
    evaluator = trainer_cls(cfg, mode='eval')
    if is_arsl:
        # ARSL weights are loaded through the ARSL_eval code path.
        evaluator.load_weights(cfg.weights, ARSL_eval=True)
    else:
        evaluator.load_weights(cfg.weights)

    # Evaluation (not training): whole-image unless slicing was requested.
    if not FLAGS.slice_infer:
        evaluator.evaluate()
        return

    # Sliced evaluation, configured by the slice/merge options on FLAGS.
    evaluator.evaluate_slice(
        slice_size=FLAGS.slice_size,
        overlap_ratio=FLAGS.overlap_ratio,
        combine_method=FLAGS.combine_method,
        match_threshold=FLAGS.match_threshold,
        match_metric=FLAGS.match_metric)
| 161 | |
| 162 | def main(): |
no test coverage detected