MCP · Copy · Index your code
hub / github.com/POSTECH-CVLab/PyTorch-StudioGAN / stack_features

Function stack_features

src/metrics/features.py:107–142  ·  view source on GitHub ↗
(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm)

Source from the content-addressed store, hash-verified

105
106
def stack_features(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm):
    """Run ``eval_model`` over batches from ``data_loader`` and stack its outputs.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` pairs.
        eval_model: Evaluation network. It is put into eval mode and must expose
            ``get_outputs(images, quantize=...)`` returning ``(embeddings, logits)``.
        num_feats: Number of samples requested. The stacked results are trimmed
            to at most this many rows.
        batch_size: Samples per batch. It is only used to decide how many batches
            to draw (presumably equal to the loader's batch size -- confirm with
            callers).
        quantize: Forwarded unchanged to ``eval_model.get_outputs``.
        world_size: Number of DDP processes. Only read when ``DDP`` is true.
        DDP: If true, each process draws its share of batches, and the per-process
            results are combined with ``losses.GatherLayer`` before trimming.
        device: Device that images and labels are moved to before the forward pass.
        disable_tqdm: Suppress the progress bar when true.

    Returns:
        Tuple ``(feats, probs, labels)`` of NumPy arrays. ``feats`` are the model
        embeddings and ``probs`` are the softmax of the logits over dim 1; both are
        cast to float64. ``labels`` keep the dtype yielded by the loader.

    Raises:
        RuntimeError: If no batch could be drawn, e.g. an empty loader or
            ``num_feats <= 0`` without DDP.
    """
    from itertools import islice

    eval_model.eval()
    num_batches = math.ceil(float(num_feats) / float(batch_size))
    if DDP:
        # Each process handles about 1/world_size of the batches. The extra batch
        # over-fetches so the gathered total covers num_feats; the surplus is
        # trimmed after gathering.
        num_batches = num_batches // world_size + 1
    # islice rejects negative stops, so clamp (e.g. when num_feats is negative).
    num_batches = max(num_batches, 0)

    real_feats, real_probs, real_labels = [], [], []
    batches = islice(data_loader, num_batches)  # stops early if the loader runs out
    for images, labels in tqdm(batches, total=num_batches, disable=disable_tqdm):
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            embeddings, logits = eval_model.get_outputs(images, quantize=quantize)
            probs = torch.nn.functional.softmax(logits, dim=1)
        real_feats.append(embeddings)
        real_probs.append(probs)
        real_labels.append(labels)

    if not real_feats:
        # torch.cat on an empty list fails with an opaque message; fail clearly instead.
        raise RuntimeError(
            f"stack_features: data_loader produced no batches "
            f"(num_feats={num_feats}, batch_size={batch_size})"
        )

    real_feats = torch.cat(real_feats, dim=0)
    real_probs = torch.cat(real_probs, dim=0)
    real_labels = torch.cat(real_labels, dim=0)
    if DDP:
        # NOTE(review): presumably GatherLayer all-gathers one tensor per rank;
        # ranks that collected different row counts may not gather cleanly -- confirm.
        real_feats = torch.cat(losses.GatherLayer.apply(real_feats), dim=0)
        real_probs = torch.cat(losses.GatherLayer.apply(real_probs), dim=0)
        real_labels = torch.cat(losses.GatherLayer.apply(real_labels), dim=0)

    # Batch rounding (and the DDP over-fetch above) can yield more rows than
    # requested. Trim so the returned count does not depend on batch_size or
    # world_size.
    limit = max(int(num_feats), 0)
    real_feats = real_feats[:limit]
    real_probs = real_probs[:limit]
    real_labels = real_labels[:limit]

    real_feats = real_feats.detach().cpu().numpy().astype(np.float64)
    real_probs = real_probs.detach().cpu().numpy().astype(np.float64)
    real_labels = real_labels.detach().cpu().numpy()
    return real_feats, real_probs, real_labels

Callers

nothing calls this directly

Calls (2)

eval · method · 0.80
get_outputs · method · 0.80

Tested by

no test coverage detected