(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm)
| 105 | |
| 106 | |
def stack_features(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm):
    """Run ``eval_model`` over ``data_loader`` and stack embeddings, softmax probabilities and labels.

    Args:
        data_loader: iterable yielding ``(images, labels)`` pairs.
        eval_model: model exposing ``eval()`` and ``get_outputs(images, quantize=...)``,
            which returns ``(embeddings, logits)``. It is switched to eval mode and
            left in eval mode.
        num_feats: number of samples requested; only used to derive how many batches
            to pull from the loader.
        batch_size: batch size used to turn ``num_feats`` into a batch count
            (presumably equal to the loader's batch size — TODO confirm with callers).
        quantize: forwarded unchanged to ``eval_model.get_outputs``.
        world_size: number of DDP processes; only read when ``DDP`` is true.
        DDP: when true, each rank processes its share of batches and the per-rank
            results are gathered across ranks with ``losses.GatherLayer``.
        device: device every batch is moved to before the forward pass.
        disable_tqdm: disables the progress bar when true.

    Returns:
        ``(feats, probs, labels)`` as NumPy arrays. ``feats`` and ``probs`` are cast to
        float64; ``labels`` keep the dtype produced by the loader. Whole batches are
        kept and nothing is truncated to ``num_feats``. The row count may therefore
        exceed ``num_feats``. It may also fall short if the loader is exhausted early.

    Raises:
        ValueError: if ``data_loader`` yields no batches at all.
    """
    eval_model.eval()
    data_iter = iter(data_loader)
    num_batches = math.ceil(float(num_feats) / float(batch_size))
    if DDP:
        # Per-rank share of the batches. The "+ 1" can overshoot when the count
        # divides evenly; that is harmless because the loop below stops as soon
        # as the loader is exhausted.
        num_batches = num_batches // world_size + 1

    real_feats, real_probs, real_labels = [], [], []
    for _ in tqdm(range(num_batches), disable=disable_tqdm):
        try:
            images, labels = next(data_iter)
        except StopIteration:
            # The loader ran out before num_batches; keep whatever was collected.
            break

        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            embeddings, logits = eval_model.get_outputs(images, quantize=quantize)
            probs = torch.nn.functional.softmax(logits, dim=1)
        real_feats.append(embeddings)
        real_probs.append(probs)
        real_labels.append(labels)

    if not real_feats:
        # torch.cat on an empty list fails with an opaque error; report the real cause.
        raise ValueError("stack_features: data_loader yielded no batches")

    real_feats = torch.cat(real_feats, dim=0)
    real_probs = torch.cat(real_probs, dim=0)
    real_labels = torch.cat(real_labels, dim=0)
    if DDP:
        # Gather on-device tensors from every rank before moving them to the host.
        real_feats = torch.cat(losses.GatherLayer.apply(real_feats), dim=0)
        real_probs = torch.cat(losses.GatherLayer.apply(real_probs), dim=0)
        real_labels = torch.cat(losses.GatherLayer.apply(real_labels), dim=0)

    real_feats = real_feats.detach().cpu().numpy().astype(np.float64)
    real_probs = real_probs.detach().cpu().numpy().astype(np.float64)
    real_labels = real_labels.detach().cpu().numpy()
    return real_feats, real_probs, real_labels
nothing calls this directly
no test coverage detected