(logits, labels, k, reduction="mean")
| 183 | |
| 184 | |
def compute_top_k(logits, labels, k, reduction="mean"):
    """Compute top-k hits: whether each label is among its row's k highest scores.

    Args:
        logits: Score tensor; ``th.topk`` is taken over its last dimension.
            Presumably shaped (batch, num_classes) — confirm against callers.
        labels: Target class indices. They are broadcast as ``labels[:, None]``
            against the top-k index tensor, so presumably shaped (batch,).
        k: Number of highest-scoring indices to consider per row.
        reduction: ``"mean"`` returns the fraction of rows whose label is in
            the top k, as a Python float. ``"none"`` returns the per-row hit
            tensor (1.0 for a hit, 0.0 for a miss).

    Returns:
        A Python float for ``reduction="mean"``, or a float tensor of per-row
        hit indicators for ``reduction="none"``.

    Raises:
        ValueError: If ``reduction`` is not ``"mean"`` or ``"none"``.
            (Previously an unknown value silently returned ``None``.)
    """
    _, top_ks = th.topk(logits, k, dim=-1)
    # Indices returned by topk are distinct within a row, so at most one of
    # them can equal the label. Summing the matches gives a 0/1 indicator.
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "mean":
        return hits.mean().item()
    if reduction == "none":
        return hits
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'mean' or 'none'."
    )
| 191 | |
| 192 | |
| 193 | def split_microbatches(microbatch, *args): |
no outgoing calls
no test coverage detected