(model, classnames, templates, args)
| 11 | |
| 12 | |
def zero_shot_classifier(model, classnames, templates, args):
    """Build zero-shot classifier weights from templated class-name prompts.

    For each class name, every template is applied to produce prompt
    strings. The prompts are tokenized, moved to ``args.device``, and
    encoded with the model's ``encode_text``. The resulting embeddings are
    normalised with ``F.normalize`` along the last dim and averaged over
    prompts. The mean is then rescaled by its own norm.

    Args:
        model: Object exposing ``encode_text``. When ``args.distributed`` is
            truthy and ``args.horovod`` is falsy, ``encode_text`` is taken
            from ``model.module`` instead (presumably a DDP-style wrapper —
            confirm against the caller).
        classnames: Iterable of class names, iterated through ``tqdm``.
        templates: Iterable of callables, each mapping a class name to a
            prompt string.
        args: Namespace read for ``device``, ``distributed`` and ``horovod``.

    Returns:
        Tensor formed by stacking the per-class embeddings along ``dim=1``,
        moved to ``args.device``.
    """
    with torch.no_grad():
        per_class = []
        for name in tqdm(classnames):
            prompts = [fill(name) for fill in templates]
            tokens = tokenize(prompts).to(args.device)
            # Non-Horovod distributed runs wrap the model, so the text
            # encoder is reached through the wrapper's ``.module``.
            use_inner = args.distributed and not args.horovod
            encoder = model.module if use_inner else model
            embeddings = encoder.encode_text(tokens)
            # Average the per-prompt normalised vectors, then rescale the
            # mean back to unit length.
            mean_embedding = F.normalize(embeddings, dim=-1).mean(dim=0)
            per_class.append(mean_embedding / mean_embedding.norm())
        weights = torch.stack(per_class, dim=1).to(args.device)
    return weights
| 28 | |
| 29 | |
| 30 | def accuracy(output, target, topk=(1,)): |
no test coverage detected