Estimates pass@k of each problem and returns them in an array.
(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int
)
| 433 | |
| 434 | |
def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.

    Args:
        num_samples: Number of samples generated per problem. Either a single
            integer (a Python ``int`` or a NumPy integer scalar) that applies
            to every problem, or a sequence with one entry per problem.
        num_correct: Number of correct samples for each problem.
        k: The ``k`` in pass@k.

    Returns:
        A 1-D array holding one estimate per entry of ``num_correct``.

    Raises:
        AssertionError: If ``num_samples`` is a sequence whose length differs
            from that of ``num_correct``.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if n - c < k:
            # Fewer than k incorrect samples exist, so comb(n - c, k) == 0.
            # NOTE: this branch is also taken when k > n, even if c == 0.
            return 1.0
        # comb(n - c, k) / comb(n, k) == prod_{i = n-c+1 .. n} (1 - k / i);
        # the product form avoids building very large binomial coefficients.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # NumPy integer scalars (e.g. np.int64) are not instances of ``int`` and
    # have no len(), so they must be recognised as scalars explicitly.
    if isinstance(num_samples, (int, np.integer)):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct), (
            "num_samples and num_correct must have the same length"
        )
        num_samples_it = iter(num_samples)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
    )
| 459 | |
| 460 | |
| 461 | class Logger: |
no test coverage detected