hub / github.com/mseitzer/pytorch-fid / get_activations

Function get_activations

src/pytorch_fid/fid_score.py:93–149

Calculates the activations of the pool_3 layer for all images. Params: files (list of image file paths), model (instance of the Inception model), batch_size (number of images the model processes at once).

get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=1)
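The batch_size parameter interacts with the number of files in two ways that are visible in the source listing: it is clamped to len(files) when it is larger, and the loader is built with drop_last=False, which in PyTorch keeps a shorter final batch instead of discarding it. The following is a minimal pure-Python sketch of that arithmetic. plan_batches is a hypothetical helper written for illustration; it is not part of pytorch-fid and only mimics the batching rule, it does not load any images.

```python
def plan_batches(num_files, batch_size=50, drop_last=False):
    """Return the batch sizes a sequential loader would yield.

    Hypothetical helper for illustration only; not part of pytorch-fid.
    """
    # Guard for this sketch only: nothing to batch.
    if num_files <= 0:
        return []

    # Same clamp as in the listing: never ask for more items than exist.
    batch_size = min(batch_size, num_files)

    full, remainder = divmod(num_files, batch_size)
    sizes = [batch_size] * full

    # With drop_last=False the shorter final batch is kept.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


sizes = plan_batches(num_files=120, batch_size=50)
print(sizes, sum(sizes))                       # [50, 50, 20] 120
print(plan_batches(120, 50, drop_last=True))   # [50, 50]
print(plan_batches(30, 50))                    # [30]
```

Under these assumptions, every file contributes one row to the result when drop_last is False, even if the file count is not a multiple of the batch size.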

Source from the content-addressed store, hash-verified

def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
                    num_workers=1):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    pred_arr = np.empty((len(files), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx:start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr
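The loop in the listing does two things per batch: it reduces each (N, C, H, W) model output to one C-dimensional vector per image, and it writes those vectors into a preallocated array at a running offset. A minimal NumPy-only sketch of that pattern follows. It assumes random arrays as stand-ins for the Inception outputs, and pool_to_vector and collect are hypothetical names introduced here, not functions from pytorch-fid. Taking the mean over the two spatial axes should match adaptive average pooling to (1, 1) followed by the two squeezes.

```python
import numpy as np


def pool_to_vector(pred):
    """Average an (N, C, H, W) array over H and W, giving shape (N, C)."""
    # Same effect as average pooling to a 1x1 map and then dropping
    # the two singleton spatial axes.
    return pred.mean(axis=(2, 3))


def collect(batches, num_items, dims):
    """Write pooled batches into one (num_items, dims) array by offset."""
    out = np.empty((num_items, dims))  # float64 by NumPy default
    start = 0
    for pred in batches:
        feats = pool_to_vector(pred)
        # Slice length follows the actual batch size, so a shorter
        # final batch fills exactly the remaining rows.
        out[start:start + feats.shape[0]] = feats
        start += feats.shape[0]
    return out


# Stand-in "model outputs": batches of 3 and 2 items, 4 channels, 2x2 maps.
rng = np.random.default_rng(0)
batches = [rng.random((3, 4, 2, 2)), rng.random((2, 4, 2, 2))]
features = collect(batches, num_items=5, dims=4)
print(features.shape)  # (5, 4)
```

Because the offset advances by the size of each batch rather than by a fixed batch_size, the rows stay in file order and no row of the preallocated array is left unwritten as long as the batches together cover num_items.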

Callers 1

Calls 2

ImagePathDataset · Class · 0.85
tqdm · Function · 0.85

Tested by

no test coverage detected