hub / github.com/XPixelGroup/DiffBIR / forward

Method forward

ram/models/ram_plus.py:201–277


(self, image, caption, image_tag, clip_feature, batch_text_embed)
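
As a rough illustration of the `image_tag` argument: the method's docstring describes it as a batch * class_num tensor holding 1.0 for positive tags and 0.0 otherwise. The sketch below builds such a multi-hot tensor. `num_class = 6` and the `tag_indices` lists are hypothetical values chosen for illustration; they are not taken from the repository.

```python
import torch

# Hypothetical sizes: the docstring mentions class_num of e.g. 3429.
num_class = 6
# Hypothetical positive tag ids for a batch of two images.
tag_indices = [[0, 3], [2]]

# One row per image, one column per class; positives are set to 1.0.
image_tag = torch.zeros(len(tag_indices), num_class)
for row, ids in enumerate(tag_indices):
    image_tag[row, ids] = 1.0

print(image_tag.shape)  # batch * class_num
```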

Source from the content-addressed store, hash-verified

            del layer.attention

    def forward(self, image, caption, image_tag, clip_feature, batch_text_embed):
        """
        call function as forward

        Args:
            image: type: torch.Tensor  shape: batch_size * 3 * 384 * 384
            caption: type: list[string]  len: batch_size
            image_tag: type: torch.Tensor  shape: batch * class_num (e.g. 3429)  value: positive sample is 1.0, negative sample is 0.0

        Returns:
            loss: type: torch.Tensor
        """

        image_embeds = self.image_proj(self.visual_encoder(image))
        # All-ones attention mask over every image token.
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        ##================= Distillation from CLIP ================##
        # Token 0 is taken as the class embedding; the rest are spatial tokens.
        image_cls_embeds = image_embeds[:, 0, :]
        image_spatial_embeds = image_embeds[:, 1:, :]

        loss_dis = F.l1_loss(image_cls_embeds, clip_feature)

        ###===========multi tag des reweight==============###
        bs = image_embeds.shape[0]

        des_per_class = int(self.label_embed.shape[0] / self.num_class)

        image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
        reweight_scale = self.reweight_scale.exp()
        logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t())
        # (bs, num_class, des_per_class)
        logits_per_image = logits_per_image.view(bs, -1, des_per_class)

        # Softmax over the descriptions of each class.
        weight_normalized = F.softmax(logits_per_image, dim=2)
        label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype)

        # Per image: weighted sum of each class's description embeddings.
        for i in range(bs):
            reshaped_value = self.label_embed.view(-1, des_per_class, 512)
            product = weight_normalized[i].unsqueeze(-1) * reshaped_value
            label_embed_reweight[i] = product.sum(dim=1)

        label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight))

        ##================= Image Tagging ================##

        tagging_embed = self.tagging_head(
            encoder_embeds=label_embed,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
            mode='tagging',
        )

        logits = self.fc(tagging_embed[0]).squeeze(-1)

        loss_tag = self.tagging_loss_function(logits, image_tag)

        ##================= Image-text Alignment ================##
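
The "Distillation from CLIP" step in the listing splits the encoder output into a class token and spatial tokens, then takes an L1 loss between the class token and `clip_feature`. A minimal sketch of just that step follows, assuming a toy `image_embeds` of shape (2, 3, 4) and an all-zero `clip_feature`. Both are hypothetical stand-ins for the real encoder output and CLIP features.

```python
import torch
import torch.nn.functional as F

# Hypothetical encoder output: (batch=2, tokens=3, dim=4).
image_embeds = torch.arange(24, dtype=torch.float32).view(2, 3, 4)

# Same mask construction as in the listing: ones over (batch, tokens).
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

# Token 0 is the class embedding; the remaining tokens are spatial.
image_cls_embeds = image_embeds[:, 0, :]       # (2, 4)
image_spatial_embeds = image_embeds[:, 1:, :]  # (2, 2, 4)

# Hypothetical CLIP target with the same shape as the class embedding.
clip_feature = torch.zeros(2, 4)

# F.l1_loss defaults to reduction="mean": the mean absolute difference.
loss_dis = F.l1_loss(image_cls_embeds, clip_feature)
manual = (image_cls_embeds - clip_feature).abs().mean()

print(loss_dis.item(), manual.item())
```

With the default reduction, the loss is a single scalar averaged over both the batch and the embedding dimension.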

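The "multi tag des reweight" block in the listing computes, for each image, a softmax-weighted sum over the description embeddings of every class, using a Python loop over the batch. The sketch below reproduces that loop on small random tensors and compares it with a single `torch.einsum` call. The sizes, the random `label_embed`, and the fixed `reweight_scale` are hypothetical stand-ins for the module's attributes, and the einsum form is offered as an equivalent formulation, not as the repository's implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small sizes; the listing hard-codes an embedding width of 512.
bs, num_class, des_per_class, dim = 2, 4, 3, 8

image_cls_embeds = torch.randn(bs, dim)
label_embed = torch.randn(num_class * des_per_class, dim)  # stand-in for self.label_embed
reweight_scale = torch.tensor(2.0)  # stand-in for self.reweight_scale.exp()

# Same steps as the listing: normalize, score, group by class, softmax.
image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
logits = (reweight_scale * image_cls_embeds @ label_embed.t()).view(bs, -1, des_per_class)
weights = F.softmax(logits, dim=2)  # (bs, num_class, des_per_class)

# Loop form, as in the listing.
grouped = label_embed.view(-1, des_per_class, dim)  # (num_class, des_per_class, dim)
looped = torch.empty(bs, num_class, dim)
for i in range(bs):
    looped[i] = (weights[i].unsqueeze(-1) * grouped).sum(dim=1)

# Batched form: sum over the description axis d for every (image, class) pair.
vectorized = torch.einsum("bcd,cde->bce", weights, grouped)

print(torch.allclose(looped, vectorized, atol=1e-6))
```

The batched form avoids the per-image loop, which may matter at larger batch sizes.
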
Callers

nothing calls this directly

Calls 2

t (Method) · 0.80
repeat (Method) · 0.80

Tested by

no test coverage detected