Repository: github.com/XPixelGroup/DiffBIR

Function build_openset_label_embedding

Defined in ram/utils/openset_utils.py, lines 293–329.
Signature: build_openset_label_embedding(categories=None)

```python
def build_openset_label_embedding(categories=None):
    # torch, clip, openimages_rare_unseen, multiple_templates, processed_name
    # and article are imported or defined elsewhere in ram/utils/openset_utils.py.
    if categories is None:
        # Fall back to the module-level openimages_rare_unseen category list.
        categories = openimages_rare_unseen
    print("Creating pretrained CLIP model")
    model, _ = clip.load("ViT-B/16")
    templates = multiple_templates

    run_on_gpu = torch.cuda.is_available()

    with torch.no_grad():
        openset_label_embedding = []
        for category in categories:
            # Fill every prompt template with the cleaned category name.
            texts = [
                template.format(
                    processed_name(category, rm_dot=True), article=article(category)
                )
                for template in templates
            ]
            # Prepend "This is " to prompts that begin with "a" or "the".
            texts = [
                "This is " + text if text.startswith("a") or text.startswith("the") else text
                for text in texts
            ]
            texts = clip.tokenize(texts)  # tokenize
            if run_on_gpu:
                texts = texts.cuda()
                model = model.cuda()
            # Encode all prompts and L2-normalize each prompt embedding.
            text_embeddings = model.encode_text(texts)
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            # Average over templates (prompt ensembling), then renormalize.
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            openset_label_embedding.append(text_embedding)
        # Stack per-category vectors as columns: (embed_dim, num_categories).
        openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
        if run_on_gpu:
            openset_label_embedding = openset_label_embedding.cuda()

    # Transpose to (num_categories, embed_dim): row i belongs to categories[i].
    openset_label_embedding = openset_label_embedding.t()
    return openset_label_embedding, categories
```
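The per-category loop above performs prompt ensembling: each category name is expanded through every template, prompts beginning with "a" or "the" get a "This is " prefix, each prompt embedding is L2-normalized, the normalized embeddings are averaged, and the mean is normalized again. The sketch below reproduces the prompt-building and averaging logic in plain Python so it runs without CLIP. `TEMPLATES`, `simple_article` and the name cleanup in `build_prompts` are hypothetical stand-ins, not the module's actual `multiple_templates`, `article` or `processed_name`.

```python
import math

# Hypothetical templates; the module's real multiple_templates list may differ.
TEMPLATES = [
    "a photo of {article} {}.",
    "the {} in the scene.",
    "There is {article} {} in the scene.",
]


def simple_article(name: str) -> str:
    """Assumed helper: 'an' before a vowel, otherwise 'a'."""
    return "an" if name[:1].lower() in "aeiou" else "a"


def build_prompts(category: str) -> list[str]:
    """Expand one category through every template, then apply the 'This is ' rule."""
    name = category.replace("_", " ").lower().rstrip(".")  # assumed name cleanup
    # str.format ignores unused keyword arguments, so templates without {article} still work.
    texts = [t.format(name, article=simple_article(name)) for t in TEMPLATES]
    # Same case-sensitive rule as the original: prefix prompts starting with "a" or "the".
    return ["This is " + t if t.startswith("a") or t.startswith("the") else t for t in texts]


def l2_normalize(v: list[float]) -> list[float]:
    """Scale a vector to unit Euclidean length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def ensemble(prompt_embeddings: list[list[float]]) -> list[float]:
    """Normalize each prompt embedding, average them, then renormalize the mean."""
    normed = [l2_normalize(e) for e in prompt_embeddings]
    dim = len(normed[0])
    mean = [sum(e[i] for e in normed) / len(normed) for i in range(dim)]
    return l2_normalize(mean)


print(build_prompts("apple"))
print(ensemble([[3.0, 4.0], [0.0, 2.0]]))
```

Note that the prefix check is case-sensitive and matches any prompt whose first letter is "a" (including "an ..."), so a template starting with a capitalized "There" is left unchanged. Renormalizing after the mean matters because the average of unit vectors is generally shorter than unit length.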

Callers

Nothing calls this function directly.
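Because every row returned by this function is unit-length, scoring a unit-length feature vector against all categories reduces to one dot product per row (cosine similarity). How RAM or DiffBIR actually consumes the embedding is not shown on this page; the sketch below only illustrates generic CLIP-style zero-shot scoring. The 3-dimensional label rows, the image vector, and the helpers `cosine_scores` and `top_k` are hypothetical.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit Euclidean length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def cosine_scores(image_embedding, label_embedding):
    """Dot each unit-length label row with the normalized image vector."""
    img = l2_normalize(image_embedding)
    return [sum(a * b for a, b in zip(row, img)) for row in label_embedding]


def top_k(categories, scores, k):
    """Return the k (category, score) pairs with the highest scores."""
    ranked = sorted(zip(categories, scores), key=lambda p: p[1], reverse=True)
    return ranked[:k]


# Toy stand-in for the (num_categories, embed_dim) output and its category list.
categories = ["cat", "dog", "car"]
label_embedding = [l2_normalize(r) for r in ([1, 0, 0], [0.8, 0.6, 0], [0, 0, 1])]
scores = cosine_scores([0.9, 0.1, 0.0], label_embedding)
print(top_k(categories, scores, 2))
```

Keeping `categories` alongside the embedding (as the function's second return value does) is what lets a score at row `i` be mapped back to a label name.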

Calls (4)

processed_name (function) · 0.85
article (function) · 0.85
encode_text (method) · 0.80
t (method) · 0.80
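The final `.t()` call in the function only undoes the column-wise stack: `torch.stack(..., dim=1)` yields an `(embed_dim, num_categories)` matrix, and transposing it gives `(num_categories, embed_dim)`, the same layout as stacking along `dim=0`. The plain-Python sketch below mirrors this with nested lists instead of tensors (an assumption made so it runs without PyTorch); `stack_columns` and `transpose` are illustrative helpers, not functions from the repository.

```python
def stack_columns(vectors):
    """Like torch.stack(vectors, dim=1): each input vector becomes a column."""
    return [list(row) for row in zip(*vectors)]


def transpose(matrix):
    """Like Tensor.t() on a 2-D tensor: swap rows and columns."""
    return [list(row) for row in zip(*matrix)]


# Three hypothetical category vectors of dimension 2.
per_category = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
stacked = stack_columns(per_category)  # (embed_dim, num_categories) = (2, 3)
result = transpose(stacked)            # (num_categories, embed_dim) = (3, 2)
print(stacked)
print(result)
```

After the transpose, row `i` is exactly the vector built for `categories[i]`.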

Tested by

No test coverage detected.