(self, x_inp, extract_layers=(), skip=False, mask=None)
| 123 | return torch.cat([self.model.positional_embedding[:1], b]) |
| 124 | |
| 125 | def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): |
| 126 | |
| 127 | |
| 128 | with torch.no_grad(): |
| 129 | |
| 130 | inp_size = x_inp.shape[2:] |
| 131 | |
| 132 | if self.n_tokens is not None: |
| 133 | stride2 = x_inp.shape[2] // self.n_tokens |
| 134 | conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) |
| 135 | x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) |
| 136 | else: |
| 137 | x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] |
| 138 | |
| 139 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] |
| 140 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] |
| 141 | |
| 142 | x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] |
| 143 | |
| 144 | standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 |
| 145 | |
| 146 | if x.shape[1] != standard_n_tokens: |
| 147 | new_shape = int(math.sqrt(x.shape[1]-1)) |
| 148 | x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] |
| 149 | else: |
| 150 | x = x + self.model.positional_embedding.to(x.dtype) |
| 151 | |
| 152 | x = self.model.ln_pre(x) |
| 153 | |
| 154 | x = x.permute(1, 0, 2) # NLD -> LND |
| 155 | |
| 156 | activations, affinities = [], [] |
| 157 | for i, res_block in enumerate(self.model.transformer.resblocks): |
| 158 | |
| 159 | if mask is not None: |
| 160 | mask_layer, mask_type, mask_tensor = mask |
| 161 | if mask_layer == i or mask_layer == 'all': |
| 162 | # import ipdb; ipdb.set_trace() |
| 163 | size = int(math.sqrt(x.shape[0] - 1)) |
| 164 | |
| 165 | attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) |
| 166 | |
| 167 | else: |
| 168 | attn_mask = None |
| 169 | else: |
| 170 | attn_mask = None |
| 171 | |
| 172 | x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) |
| 173 | |
| 174 | if i in extract_layers: |
| 175 | affinities += [aff_per_head] |
| 176 | |
| 177 | #if self.n_tokens is not None: |
| 178 | # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] |
| 179 | #else: |
| 180 | activations += [x] |
| 181 | |
| 182 | if len(extract_layers) > 0 and i == max(extract_layers) and skip: |
no test coverage detected