(clip_vision, image, mask=None, batch_size=0, tiles=1, ratio=1.0, clipvision_size=224)
| 238 | return out |
| 239 | |
def encode_image_masked(clip_vision, image, mask=None, batch_size=0, tiles=1, ratio=1.0, clipvision_size=224):
    """Encode ``image`` and optionally blend in embeds computed from image tiles.

    The full image is always encoded with ``encode_image_masked_``. When
    ``tiles`` is greater than 1 the image is also split with ``split_tiles``,
    every tile is encoded the same way, the per-tile results are merged with
    ``merge_embeddings`` / ``merge_hiddenstates`` and then combined with the
    full-image result.

    Args:
        clip_vision: Vision model wrapper, forwarded to ``encode_image_masked_``.
        image: Image batch to encode; also the input of ``split_tiles``.
        mask: Optional mask, forwarded unchanged for the full image and for
            every tile.
        batch_size: Forwarded to ``encode_image_masked_``.
        tiles: Tile count handed to ``split_tiles``; clamped to at most 16.
            Values of 1 or less disable tiling.
        ratio: Weight applied to the full-image embeds when tiling is active.
            With more than one full-image embed the tile embeds are weighted
            by ``1 - ratio``; with a single one the tile embeds are appended
            unscaled.
        clipvision_size: Forwarded to ``encode_image_masked_``.

    Returns:
        The object returned by ``encode_image_masked_`` for the full image.
        When tiling is active its ``"image_embeds"`` and
        ``"penultimate_hidden_states"`` entries are replaced by the combined
        values.
    """
    # Full-image embeds.
    embeds = encode_image_masked_(clip_vision, image, mask, batch_size, clipvision_size=clipvision_size)
    tiles = min(tiles, 16)

    if tiles <= 1:
        return embeds

    # Encode every tile, collecting the results so each tensor is concatenated
    # exactly once instead of re-copying a growing tensor on every iteration.
    tile_image_embeds = []
    tile_hidden_states = []
    for tile in split_tiles(image, tiles):
        encoded = encode_image_masked_(clip_vision, tile, mask, batch_size, clipvision_size=clipvision_size)
        tile_image_embeds.append(encoded["image_embeds"])
        tile_hidden_states.append(encoded["penultimate_hidden_states"])

    split_image_embeds = merge_embeddings(torch.cat(tile_image_embeds, dim=0), tiles)
    split_hidden_states = merge_hiddenstates(torch.cat(tile_hidden_states, dim=0), tiles)

    if embeds["image_embeds"].shape[0] > 1:
        # More than one image: blend full-image and tile embeds for consistency.
        embeds["image_embeds"] = embeds["image_embeds"] * ratio + split_image_embeds * (1 - ratio)
        embeds["penultimate_hidden_states"] = (
            embeds["penultimate_hidden_states"] * ratio + split_hidden_states * (1 - ratio)
        )
    else:
        # Single image: concatenate instead; the entries can be averaged later.
        embeds["image_embeds"] = torch.cat([embeds["image_embeds"] * ratio, split_image_embeds])
        embeds["penultimate_hidden_states"] = torch.cat(
            [embeds["penultimate_hidden_states"] * ratio, split_hidden_states]
        )

    return embeds
| 279 | |
| 280 | def encode_image_masked_(clip_vision, image, mask=None, batch_size=0, clipvision_size=224): |
| 281 | model_management.load_model_gpu(clip_vision.patcher) |
no test coverage detected