hub / github.com/alibaba/EasyCV / single_gpu_test

Function single_gpu_test

easycv/apis/test.py:76–152

Test a model on a single GPU. Args: model (nn.Module): Model to be tested. data_loader (DataLoader): PyTorch data loader. mode (str): Mode for the model's forward pass. use_fp16 (bool): Use fp16 inference. Returns: list: The prediction results.

(model, data_loader, mode='test', use_fp16=False, **kwargs)
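The signature implies a calling contract that the source listing makes explicit: each batch is a dict unpacked into keyword arguments, the model is called with `mode=mode`, and the dict it returns is collected per key. A minimal pure-Python sketch of that contract follows. It assumes a hypothetical `toy_model` and a hypothetical `collect_results` helper (neither is part of EasyCV), and it uses plain lists in place of tensors.

```python
from collections import defaultdict


def toy_model(img, mode='test'):
    """Hypothetical stand-in for a model: returns a dict of per-batch outputs."""
    assert mode == 'test'
    return {
        'prob': [x * 0.5 for x in img],
        'label': [int(x > 1) for x in img],
    }


def collect_results(model, data_loader, mode='test'):
    """Gather each output key into a list with one entry per batch."""
    results = defaultdict(list)
    for data in data_loader:
        # Each batch is a dict that is unpacked into keyword arguments.
        result = model(**data, mode=mode)
        for key, value in result.items():
            results[key].append(value)
    return dict(results)


# Two batches of two samples each (assumed toy data).
batches = [{'img': [0, 1]}, {'img': [2, 3]}]
collected = collect_results(toy_model, batches)
```

Here `collected` holds one list per output key, each with one entry per batch. How the real function merges those per-batch entries afterwards is not visible in the listing, so the sketch stops at the collection step.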

Source from the content-addressed store, hash-verified

def single_gpu_test(model, data_loader, mode='test', use_fp16=False, **kwargs):
    """Test model with single GPU.

    This method tests model with single GPU.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        mode (str): mode for model to forward
        use_fp16: Use fp16 inference

    Returns:
        list: The prediction results.
    """

    if use_fp16:
        device = next(model.parameters()).device
        assert device.type == 'cuda', 'fp16 can only be used in gpu, model is placed on cpu'
        model.half()

    model.eval()
    if hasattr(data_loader, 'dataset'):  # normal dataloader
        data_len = len(data_loader.dataset)
    else:
        data_len = len(data_loader) * data_loader.batch_size

    prog_bar = mmcv.ProgressBar(data_len)
    results = {}
    for i, data in enumerate(data_loader):
        # use scatter_kwargs to unpack DataContainer data for raw torch.nn.module
        if not isinstance(model,
                          (MMDistributedDataParallel,
                           MMDataParallel)) and not is_torchacc_enabled():
            input_args, kwargs = scatter_kwargs(None, data,
                                                [torch.cuda.current_device()])
            with torch.no_grad():
                result = model(**kwargs[0], mode=mode)
        else:
            with torch.no_grad():
                result = model(**data, mode=mode)

        # collect each output key into a list with one entry per batch
        for k, v in result.items():
            if k not in results:
                results[k] = []
            results[k].append(v)

        # infer the batch size from img_metas if present, otherwise from img
        if 'img_metas' in data:
            if isinstance(data['img_metas'], list):
                batch_size = len(data['img_metas'][0].data[0])
            else:
                batch_size = len(data['img_metas'].data[0])

        else:
            if isinstance(data['img'], list):
                batch_size = data['img'][0].size(0)
            else:
                batch_size = data['img'].size(0)
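Two bookkeeping steps in the listing can be shown in isolation: estimating the total sample count for the progress bar, and inferring the batch size from the `img` entry. The sketch below is an illustration only. `FakeTensor`, `FakeLoader`, `estimate_data_len`, and `infer_batch_size` are hypothetical names that do not exist in EasyCV, and the `img_metas` branch is left out because it depends on MMCV's `DataContainer`.

```python
class FakeTensor:
    """Hypothetical stand-in exposing only the `size(dim)` call used here."""

    def __init__(self, *shape):
        self.shape = shape

    def size(self, dim):
        return self.shape[dim]


class FakeLoader:
    """Hypothetical loader that has no `dataset` attribute."""

    batch_size = 4

    def __len__(self):
        return 3  # number of batches


def estimate_data_len(data_loader):
    # Prefer the exact dataset size; otherwise fall back to
    # batches * batch_size, which may overestimate if the last
    # batch is not full.
    if hasattr(data_loader, 'dataset'):
        return len(data_loader.dataset)
    return len(data_loader) * data_loader.batch_size


def infer_batch_size(data):
    # Only the `img` branch is sketched. When `img` is a list,
    # the first element determines the batch size.
    img = data['img']
    if isinstance(img, list):
        return img[0].size(0)
    return img.size(0)


data_len = estimate_data_len(FakeLoader())
single = infer_batch_size({'img': FakeTensor(2, 3, 8, 8)})
multi = infer_batch_size({'img': [FakeTensor(5, 3, 8, 8), FakeTensor(5, 3, 4, 4)]})
```

The fallback in `estimate_data_len` mirrors the `else` branch of the listing, which applies when the loader does not expose a `dataset`.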

Callers 4

main · Function · 0.90
quantize_eval · Function · 0.90
after_train_epoch · Method · 0.90
test_model_test · Method · 0.90

Calls 5

is_torchacc_enabled · Function · 0.90
ValueError · Class · 0.90
size · Method · 0.45
update · Method · 0.45
cat · Method · 0.45

Tested by

no test coverage detected