| 161 | |
@torch.no_grad()
def forward(self, pixel_values):
    """Run the vision tower over ``pixel_values`` and return the post-processed embeddings.

    Pipeline:
      1. ``pre_infer.forward`` with the shared ``pre_post_weight``;
      2. the first ``layers_num + select_layer + 1`` transformer layers, each
         with its own entry of ``trans_layers_weight``;
      3. ``post_infer.forward`` (again with ``pre_post_weight``) on the layer
         output with position 0 of dimension 1 removed.

    The whole pass runs between ``g_cache_manager.cache_env_in()`` and
    ``g_cache_manager.cache_env_out()``. The exit call sits in a ``finally``
    so the cache environment is always left, even when a layer raises.
    Gradient tracking is disabled via ``torch.no_grad``.

    Args:
        pixel_values: passed unchanged to ``pre_infer.forward``. This is presumably
            a batched image tensor; its shape and dtype are not established here
            (TODO confirm against the pre-infer layer).

    Returns:
        The value returned by ``post_infer.forward``.
    """
    g_cache_manager.cache_env_in()
    try:
        input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
        # NOTE(review): this count assumes select_layer is negative
        # (e.g. -1 runs every layer, -2 stops one layer early). Confirm.
        num_layers_to_run = self.layers_num + self.select_layer + 1
        for i in range(num_layers_to_run):
            input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
        # Drop index 0 along dim 1 before post-processing. This is presumably the
        # ViT CLS token; confirm against the pre-infer layer.
        input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
    finally:
        # Always release the cache environment acquired above, including on error.
        g_cache_manager.cache_env_out()
    return input_embs
| 171 | |
| 172 | @torch.no_grad() |
| 173 | def encode(self, images: List[ImageItem]): |