call function as forward. Args: image: type: torch.Tensor, shape: batch_size * 3 * 384 * 384; caption: type: list[string], len: batch_size; image_tag: type: torch.Tensor, shape: batch * class_num (e.g. 3429), value: positive sample is 1.0, negative sample is 0.0
(self, image, caption, image_tag, clip_feature, batch_text_embed)
| 199 | del layer.attention |
| 200 | |
| 201 | def forward(self, image, caption, image_tag, clip_feature, batch_text_embed): |
| 202 | """ |
| 203 | call function as forward |
| 204 | |
| 205 | Args: |
| 206 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 |
| 207 | caption: type: list[string] len: batch_size |
| 208 | image_tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 |
| 209 | |
| 210 | Returns: |
| 211 | loss: type: torch.Tensor |
| 212 | """ |
| 213 | |
| 214 | image_embeds = self.image_proj(self.visual_encoder(image)) |
| 215 | image_atts = torch.ones(image_embeds.size()[:-1], |
| 216 | dtype=torch.long).to(image.device) |
| 217 | |
| 218 | ##================= Distillation from CLIP ================## |
| 219 | image_cls_embeds = image_embeds[:, 0, :] |
| 220 | image_spatial_embeds = image_embeds[:, 1:, :] |
| 221 | |
| 222 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature) |
| 223 | |
| 224 | ###===========multi tag des reweight==============### |
| 225 | bs = image_embeds.shape[0] |
| 226 | |
| 227 | des_per_class = int(self.label_embed.shape[0] / self.num_class) |
| 228 | |
| 229 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) |
| 230 | reweight_scale = self.reweight_scale.exp() |
| 231 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) |
| 232 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) |
| 233 | |
| 234 | weight_normalized = F.softmax(logits_per_image, dim=2) |
| 235 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) |
| 236 | |
| 237 | for i in range(bs): |
| 238 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) |
| 239 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value |
| 240 | label_embed_reweight[i] = product.sum(dim=1) |
| 241 | |
| 242 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) |
| 243 | |
| 244 | ##================= Image Tagging ================## |
| 245 | |
| 246 | tagging_embed = self.tagging_head( |
| 247 | encoder_embeds=label_embed, |
| 248 | encoder_hidden_states=image_embeds, |
| 249 | encoder_attention_mask=image_atts, |
| 250 | return_dict=False, |
| 251 | mode='tagging', |
| 252 | ) |
| 253 | |
| 254 | logits = self.fc(tagging_embed[0]).squeeze(-1) |
| 255 | |
| 256 | loss_tag = self.tagging_loss_function(logits, image_tag) |
| 257 | |
| 258 | ##================= Image-text Alignment ================## |