(self, inp_image, conditional=None, return_features=False, mask=None)
| 344 | |
| 345 | |
| 346 | def forward(self, inp_image, conditional=None, return_features=False, mask=None): |
| 347 | |
| 348 | assert type(return_features) == bool |
| 349 | |
| 350 | inp_image = inp_image.to(self.model.positional_embedding.device) |
| 351 | |
| 352 | if mask is not None: |
| 353 | raise ValueError('mask not supported') |
| 354 | |
| 355 | # x_inp = normalize(inp_image) |
| 356 | x_inp = inp_image |
| 357 | |
| 358 | bs, dev = inp_image.shape[0], x_inp.device |
| 359 | |
| 360 | cond = self.get_cond_vec(conditional, bs) |
| 361 | |
| 362 | visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) |
| 363 | |
| 364 | activation1 = activations[0] |
| 365 | activations = activations[1:] |
| 366 | |
| 367 | _activations = activations[::-1] if not self.rev_activations else activations |
| 368 | |
| 369 | a = None |
| 370 | for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): |
| 371 | |
| 372 | if a is not None: |
| 373 | a = reduce(activation) + a |
| 374 | else: |
| 375 | a = reduce(activation) |
| 376 | |
| 377 | if i == self.cond_layer: |
| 378 | if self.reduce_cond is not None: |
| 379 | cond = self.reduce_cond(cond) |
| 380 | |
| 381 | a = self.film_mul(cond) * a + self.film_add(cond) |
| 382 | |
| 383 | a = block(a) |
| 384 | |
| 385 | for block in self.extra_blocks: |
| 386 | a = a + block(a) |
| 387 | |
| 388 | a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens |
| 389 | |
| 390 | size = int(math.sqrt(a.shape[2])) |
| 391 | |
| 392 | a = a.view(bs, a.shape[1], size, size) |
| 393 | |
| 394 | a = self.trans_conv(a) |
| 395 | |
| 396 | if self.n_tokens is not None: |
| 397 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) |
| 398 | |
| 399 | if self.upsample_proj is not None: |
| 400 | a = self.upsample_proj(a) |
| 401 | a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') |
| 402 | |
| 403 | if return_features: |
no test coverage detected