github.com/alibaba/EasyCV

Method _base_predict

tests/test_tools/test_predict.py:32–65
(self, model_type, model_path, dist=False)


        super().tearDown()

    def _base_predict(self, model_type, model_path, dist=False):
        input_file = tempfile.NamedTemporaryFile('w').name
        input_line_num = 10
        with open(input_file, 'w') as ofile:
            for _ in range(input_line_num):
                ofile.write(
                    os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + '\n')
        output_file = tempfile.NamedTemporaryFile('w').name

        if dist:
            cmd = f'PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
                tools/predict.py \
                --input_file {input_file} \
                --output_file {output_file} \
                --model_type {model_type} \
                --model_path {model_path} \
                --launcher pytorch'

        else:
            cmd = f'PYTHONPATH=. python tools/predict.py \
                --input_file {input_file} \
                --output_file {output_file} \
                --model_type {model_type} \
                --model_path {model_path} '

        logging.info('run command: %s' % cmd)
        run_in_subprocess(cmd)

        with open(output_file, 'r') as infile:
            output_line_num = len(infile.readlines())
        self.assertEqual(input_line_num, output_line_num)

        io.remove(input_file)
        io.remove(output_file)

    def test_predict(self):
        model_type = 'YoloXPredictor'
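The two branches of `_base_predict` differ only in the launcher prefix (`-m torch.distributed.launch --nproc_per_node=2 --master_port=29527`) and the trailing `--launcher pytorch` flag. A minimal sketch, assuming one wanted the same invocation as an argument list rather than a backslash-continued shell string: `build_predict_argv` is a hypothetical helper, not part of EasyCV, and any model path passed to it is a placeholder.

```python
import os


def build_predict_argv(input_file, output_file, model_type, model_path, dist=False):
    """Hypothetical helper: return (argv, env) mirroring the command in _base_predict."""
    if dist:
        # Distributed branch: two processes on a fixed master port.
        argv = ['python', '-m', 'torch.distributed.launch',
                '--nproc_per_node=2', '--master_port=29527',
                'tools/predict.py']
    else:
        argv = ['python', 'tools/predict.py']
    argv += ['--input_file', input_file,
             '--output_file', output_file,
             '--model_type', model_type,
             '--model_path', model_path]
    if dist:
        # Only the distributed branch passes the launcher flag.
        argv += ['--launcher', 'pytorch']
    # The shell prefix PYTHONPATH=. becomes an explicit environment override.
    env = dict(os.environ, PYTHONPATH='.')
    return argv, env
```

Passing the result to `subprocess.run(argv, env=env, check=True)` would skip shell parsing, so paths containing spaces need no quoting. The original test instead builds a single string whose `PYTHONPATH=.` prefix is shell syntax.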

Callers (2)

test_predict · Method · 0.95
test_predict_dist · Method · 0.95

Calls (4)

run_in_subprocess · Function · 0.90
write · Method · 0.80
info · Method · 0.45
remove · Method · 0.45
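The calls listed above line up with the method's steps: `write` fills the input file, `info` logs the command, `run_in_subprocess` launches `tools/predict.py`, and `remove` deletes both temporary files once the line counts have been compared. A minimal sketch of that round trip, assuming a POSIX shell: `run_in_subprocess` here is a hypothetical stand-in built on `subprocess.run(..., shell=True, check=True)` (EasyCV's own helper may behave differently), and `fake_predict_cmd` is a hypothetical substitute for `tools/predict.py` that copies each input line to the output. Unlike the original, cleanup sits in a `finally` block, so the temporary files are removed even if the command raises.

```python
import os
import shlex
import subprocess
import sys
import tempfile


def run_in_subprocess(cmd):
    """Hypothetical stand-in: run a shell command, raise CalledProcessError on non-zero exit."""
    subprocess.run(cmd, shell=True, check=True)


def fake_predict_cmd(input_file, output_file):
    """Hypothetical substitute for tools/predict.py: one output line per input line."""
    script = 'import sys; open(sys.argv[2], "w").writelines(open(sys.argv[1]))'
    return ' '.join(
        shlex.quote(arg)
        for arg in [sys.executable, '-c', script, input_file, output_file])


def line_count_round_trip(image_path, n_lines, make_cmd):
    """Write n_lines copies of image_path, run the command, return the output line count."""
    fd_in, input_file = tempfile.mkstemp(suffix='.txt')
    fd_out, output_file = tempfile.mkstemp(suffix='.txt')
    os.close(fd_out)  # the subprocess reopens the output file by name
    try:
        with os.fdopen(fd_in, 'w') as ofile:
            for _ in range(n_lines):
                ofile.write(image_path + '\n')
        run_in_subprocess(make_cmd(input_file, output_file))
        with open(output_file) as infile:
            return len(infile.readlines())
    finally:
        # Runs even if the command raises, so no temporary files are left behind.
        os.remove(input_file)
        os.remove(output_file)


# Same shape as _base_predict: 10 input lines in, 10 output lines expected back.
assert line_count_round_trip('000000289059.jpg', 10, fake_predict_cmd) == 10
```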

Tested by

no test coverage detected