(categories=None)
| 291 | |
| 292 | |
def build_openset_label_embedding(categories=None):
    """Build one CLIP text embedding per open-set category.

    Every category is rendered through each prompt template in
    ``multiple_templates``, the prompts are encoded with the CLIP
    ``ViT-B/16`` text encoder, each prompt embedding is L2-normalised,
    and the mean over templates is L2-normalised again to give a single
    vector for the category.

    Args:
        categories: Iterable of category names. When ``None``, the
            module-level ``openimages_rare_unseen`` list is used.
            An empty iterable is not handled specially: ``torch.stack``
            is then called on an empty list.

    Returns:
        A ``(openset_label_embedding, categories)`` tuple.
        ``openset_label_embedding`` is a 2-D tensor with one row per
        category, in the iteration order of ``categories``; it is moved
        to CUDA when ``torch.cuda.is_available()``. ``categories`` is the
        list that was actually used (the default one if ``None`` was
        passed).
    """
    if categories is None:
        categories = openimages_rare_unseen
    print("Creating pretrained CLIP model")
    model, _ = clip.load("ViT-B/16")
    templates = multiple_templates

    run_on_gpu = torch.cuda.is_available()
    if run_on_gpu:
        # Loop-invariant: move the model to the GPU once, not once per category.
        model = model.cuda()

    with torch.no_grad():
        openset_label_embedding = []
        for category in categories:
            texts = [
                template.format(
                    processed_name(category, rm_dot=True), article=article(category)
                )
                for template in templates
            ]
            # Prompts that begin with an article are turned into a sentence.
            # NOTE(review): the "a" prefix test matches ANY text starting with
            # the letter "a" (e.g. "an ..."), not only the article "a ".
            # Kept as-is because changing it would change the embeddings.
            texts = [
                "This is " + text if text.startswith(("a", "the")) else text
                for text in texts
            ]
            texts = clip.tokenize(texts)  # tokenize
            if run_on_gpu:
                texts = texts.cuda()
            text_embeddings = model.encode_text(texts)
            # Normalise each template embedding, average over templates,
            # then re-normalise the averaged vector.
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            openset_label_embedding.append(text_embedding)
        # Stacking on dim=1 puts categories on the second axis; the final
        # transpose below puts them back on the first.
        openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
        if run_on_gpu:
            openset_label_embedding = openset_label_embedding.cuda()

    openset_label_embedding = openset_label_embedding.t()
    return openset_label_embedding, categories
| 330 | |
| 331 | |
| 332 |
nothing calls this directly
no test coverage detected