(self, dataloader, modelpath=None)
| 136 | |
| 137 | |
def evaluate(self, dataloader, modelpath=None):
    """Compute top-1 text-matching accuracy of ``self.model`` over ``dataloader``.

    The candidate texts are taken from ``dataloader.dataset.labels``. For each
    image, the prediction is the index of the text with the highest similarity
    score, and it is compared against the batch's ``label``.

    Args:
        dataloader: Iterable of ``(image, _, label)`` batches that also
            exposes ``dataset.labels`` (the candidate texts).
        modelpath: Optional path to a saved state dict that is loaded into
            ``self.model`` before evaluating.

    Returns:
        ``(acc, res)``. ``res`` is a NumPy array with one row per sample
        holding ``[predicted_index, label]``. ``acc`` is the mean of
        ``res[:, 0] == res[:, 1]``. NOTE(review): this presumably assumes
        labels are indices into ``dataset.labels`` -- confirm with the dataset.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if modelpath is not None:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded on this one. NOTE(review): torch.load unpickles the file;
        # only load trusted checkpoints (consider weights_only=True if the
        # installed torch version supports it).
        state_dict = torch.load(modelpath, map_location=self.device)
        self.model.load_state_dict(state_dict)
    # NOTE(review): the model is not switched to eval() here; confirm the
    # feature helpers do so, otherwise dropout/batch-norm stay in train mode.
    batch_results = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        texts = dataloader.dataset.labels
        text_features = self.get_text_features_list(texts)
        for image, _, label in tqdm(dataloader):
            image = image.to(self.device)
            label = label.to(self.device)
            image_features = self.get_image_features(image)
            similarity = self.get_similarity(image_features, text_features)
            # topk(1) works along the last dim; presumably similarity is
            # (batch, num_texts), so each index is the predicted text -- TODO confirm.
            _, indices = similarity.topk(1)
            pred = indices.reshape(-1, 1)
            batch_results.append(torch.cat([pred, label.view(-1, 1)], dim=1))
    if not batch_results:
        raise ValueError("dataloader yielded no batches; cannot compute accuracy")
    # Concatenate once rather than growing a tensor each batch (avoids O(n^2) copies).
    res = torch.cat(batch_results, dim=0).cpu().numpy()
    acc = np.mean(res[:, 0] == res[:, 1])
    return acc, res
| 161 | |
| 162 | |
| 163 | if __name__ == '__main__': |
no test coverage detected