Test a model on a single GPU. This method runs the model over every batch of the data loader on a single GPU. Args: model (nn.Module): Model to be tested. data_loader (nn.Dataloader): PyTorch data loader. mode (str): Mode passed to the model's forward call. use_fp16: Use fp16 inference. Returns: list: The prediction results.
(model, data_loader, mode='test', use_fp16=False, **kwargs)
| 74 | |
| 75 | |
| 76 | def single_gpu_test(model, data_loader, mode='test', use_fp16=False, **kwargs): |
| 77 | """Test model with single. |
| 78 | |
| 79 | This method tests model with single |
| 80 | |
| 81 | Args: |
| 82 | model (nn.Module): Model to be tested. |
| 83 | data_loader (nn.Dataloader): Pytorch data loader. |
| 84 | model (str): mode for model to forward |
| 85 | use_fp16: Use fp16 inference |
| 86 | |
| 87 | Returns: |
| 88 | list: The prediction results. |
| 89 | """ |
| 90 | |
| 91 | if use_fp16: |
| 92 | device = next(model.parameters()).device |
| 93 | assert device.type == 'cuda', 'fp16 can only be used in gpu, model is placed on cpu' |
| 94 | model.half() |
| 95 | |
| 96 | model.eval() |
| 97 | if hasattr(data_loader, 'dataset'): # normal dataloader |
| 98 | data_len = len(data_loader.dataset) |
| 99 | else: |
| 100 | data_len = len(data_loader) * data_loader.batch_size |
| 101 | |
| 102 | prog_bar = mmcv.ProgressBar(data_len) |
| 103 | results = {} |
| 104 | for i, data in enumerate(data_loader): |
| 105 | # use scatter_kwargs to unpack DataContainer data for raw torch.nn.module |
| 106 | if not isinstance(model, |
| 107 | (MMDistributedDataParallel, |
| 108 | MMDataParallel)) and not is_torchacc_enabled(): |
| 109 | input_args, kwargs = scatter_kwargs(None, data, |
| 110 | [torch.cuda.current_device()]) |
| 111 | with torch.no_grad(): |
| 112 | result = model(**kwargs[0], mode=mode) |
| 113 | else: |
| 114 | with torch.no_grad(): |
| 115 | result = model(**data, mode=mode) |
| 116 | |
| 117 | for k, v in result.items(): |
| 118 | if k not in results: |
| 119 | results[k] = [] |
| 120 | results[k].append(v) |
| 121 | |
| 122 | if 'img_metas' in data: |
| 123 | if isinstance(data['img_metas'], list): |
| 124 | batch_size = len(data['img_metas'][0].data[0]) |
| 125 | else: |
| 126 | batch_size = len(data['img_metas'].data[0]) |
| 127 | |
| 128 | else: |
| 129 | if isinstance(data['img'], list): |
| 130 | batch_size = data['img'][0].size(0) |
| 131 | else: |
| 132 | batch_size = data['img'].size(0) |
| 133 |
no test coverage detected