Repository: github.com/Audio-AGI/AudioSep

Function zero_shot_classifier

Defined in models/CLAP/training/zero_shot.py, lines 13–27.
Signature: zero_shot_classifier(model, classnames, templates, args)
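The body calls `template(classname)`, so each entry of `templates` must be a callable that turns a class name into a prompt string. `args` must expose `device`, `distributed`, and `horovod`, the only three fields the function reads. Below is a minimal sketch of such inputs. It assumes plain lambdas for the templates and a `SimpleNamespace` in place of the parsed training arguments. The class names and prompt wording are hypothetical and do not come from the repository.

```python
from types import SimpleNamespace

# Hypothetical class names; zero_shot_classifier builds one column per entry.
classnames = ["dog barking", "rain", "violin"]

# Each template is a callable; zero_shot_classifier calls template(classname).
# The prompt wording here is illustrative only.
templates = [
    lambda c: f"This is a sound of {c}.",
    lambda c: f"A recording of {c}.",
]

# Stand-in for the parsed args; the function reads only these three fields.
args = SimpleNamespace(device="cpu", distributed=False, horovod=False)

# The prompts the function would build for the first class.
prompts = [template(classnames[0]) for template in templates]
print(prompts)
```

With `distributed=False`, the function calls `model.encode_text` directly. It only reaches through `model.module` when the model is wrapped for distributed training without Horovod.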


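```python
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Tokenizer import as in the upstream LAION CLAP module (assumed; the path may differ in AudioSep).
from open_clip import tokenize


def zero_shot_classifier(model, classnames, templates, args):
    """Build an [embed_dim, num_classes] matrix of unit-norm class text embeddings."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenize(texts).to(args.device)  # tokenize
            # Under DDP (without Horovod) the model is wrapped; call the inner module.
            if args.distributed and not args.horovod:
                class_embeddings = model.module.encode_text(texts)
            else:
                class_embeddings = model.encode_text(texts)
            # Normalize each prompt embedding, average over prompts, then renormalize.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Stack per-class vectors as columns: shape [embed_dim, num_classes].
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
    return zeroshot_weights
```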
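For each class, the function L2-normalizes every prompt embedding, averages them, and renormalizes the mean. Every column of the returned `[embed_dim, num_classes]` matrix is therefore a unit vector. The sketch below reproduces that arithmetic without dependencies. Hand-written vectors stand in for `encode_text` outputs, and plain lists replace tensors. The helper names `l2_normalize` and `build_classifier` are hypothetical, and the code is not the repository's implementation.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length.

    Mirrors F.normalize(..., dim=-1) applied to one row; F.normalize also
    clamps the norm with a small eps, which this sketch omits.
    """
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def build_classifier(prompt_embeddings_per_class):
    """Return a D x C matrix (list of rows) whose columns are unit class embeddings.

    prompt_embeddings_per_class holds one list of D-dim prompt embeddings per
    class, standing in for model.encode_text(tokenize(texts)).
    """
    columns = []
    for prompt_embeddings in prompt_embeddings_per_class:
        normalized = [l2_normalize(e) for e in prompt_embeddings]
        # Mean over prompts (like .mean(dim=0)), then renormalize to unit length.
        mean = [sum(vals) / len(normalized) for vals in zip(*normalized)]
        columns.append(l2_normalize(mean))
    # Like torch.stack(columns, dim=1): rows index embedding dimensions.
    return [list(row) for row in zip(*columns)]


# Two classes, two prompts each, 2-D embeddings (made-up numbers).
weights = build_classifier([
    [[3.0, 0.0], [0.0, 2.0]],    # class 0: prompts along +x and +y
    [[0.0, -5.0], [0.0, -1.0]],  # class 1: both prompts along -y
])
print(weights)
```

Normalizing each prompt before averaging stops a prompt with a large raw norm, such as `[3.0, 0.0]` above, from dominating the mean. The final renormalization restores unit length, which is the job of `class_embedding /= class_embedding.norm()` in the original.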

Callers (1)

zero_shot_eval (function) · 0.85
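This page does not show how `zero_shot_eval` uses the matrix. In CLIP-style zero-shot evaluation, a normalized audio (or image) embedding is usually scored against each class column by dot product, and the highest-scoring class is the prediction. The sketch below illustrates that step. It assumes a unit-norm query embedding and the `[embed_dim, num_classes]` layout the function returns. The name `predict_topk` is hypothetical, and the code is not the repository's `zero_shot_eval`.

```python
def predict_topk(audio_embedding, weights, k=1):
    """Return indices of the k classes with the highest cosine score.

    audio_embedding: unit-norm D-dim vector (assumed already normalized).
    weights: D x C matrix laid out like zero_shot_classifier's output (rows = dims).
    """
    num_dims = len(weights)
    num_classes = len(weights[0])
    # One score per class column; for unit vectors the dot product is the cosine.
    scores = [
        sum(audio_embedding[d] * weights[d][c] for d in range(num_dims))
        for c in range(num_classes)
    ]
    # Rank class indices by score, highest first, and keep the top k.
    return sorted(range(num_classes), key=lambda c: scores[c], reverse=True)[:k]


# D=2, C=3 weight matrix with unit-norm columns (made-up values).
weights = [
    [1.0, 0.0, -1.0],
    [0.0, 1.0, 0.0],
]
print(predict_topk([0.6, 0.8], weights, k=2))
```

In tensor form this is a single matrix product, `audio_features @ weights`, which gives one row of class scores per clip. Scaling the scores by a positive constant, as CLIP-style loops often do with a logit scale, leaves the ranking and the top-k predictions unchanged.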

Calls (3)

tokenize (function) · 0.90
encode_text (method) · 0.80
append (method) · 0.80

Tested by

no test coverage detected