(self, model_type, model_path, dist=False)
| 30 | super().tearDown() |
| 31 | |
def _base_predict(self, model_type, model_path, dist=False):
    """Run ``tools/predict.py`` in a subprocess and check its output size.

    An input list file containing ``input_line_num`` copies of the same
    test image path is written, the predict tool is invoked on it, and
    the output file is expected to contain one line per input line.

    Args:
        model_type: Value forwarded to the tool's ``--model_type`` option.
        model_path: Value forwarded to the tool's ``--model_path`` option.
        dist: When true, launch the tool through
            ``torch.distributed.launch`` with two processes and pass
            ``--launcher pytorch``; otherwise run it as a plain script.
    """
    input_line_num = 10
    image_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')

    # Work inside a private temporary directory instead of reusing the
    # name of an already-deleted NamedTemporaryFile: the paths cannot
    # collide with other processes, and everything is removed on exit
    # even when the subprocess or the assertion fails.
    with tempfile.TemporaryDirectory() as work_dir:
        input_file = os.path.join(work_dir, 'predict_input')
        output_file = os.path.join(work_dir, 'predict_output')

        with open(input_file, 'w') as ofile:
            ofile.writelines(
                image_path + '\n' for _ in range(input_line_num))

        if dist:
            entry_point = ('python -m torch.distributed.launch '
                           '--nproc_per_node=2 --master_port=29527 '
                           'tools/predict.py')
        else:
            entry_point = 'python tools/predict.py'

        cmd_parts = [
            'PYTHONPATH=.',
            entry_point,
            f'--input_file {input_file}',
            f'--output_file {output_file}',
            f'--model_type {model_type}',
            f'--model_path {model_path}',
        ]
        if dist:
            cmd_parts.append('--launcher pytorch')
        cmd = ' '.join(cmd_parts)

        # Pass cmd as a logging argument so formatting stays lazy.
        logging.info('run command: %s', cmd)
        run_in_subprocess(cmd)

        # Count lines by iterating; no need to hold them all in memory.
        with open(output_file, 'r') as infile:
            output_line_num = sum(1 for _ in infile)
        self.assertEqual(input_line_num, output_line_num)
| 66 | |
| 67 | def test_predict(self): |
| 68 | model_type = 'YoloXPredictor' |
no test coverage detected