(self, n_rows, block_m)
| 83 | expected_tokens_per_expt: int = field(default=None) |
| 84 | |
| 85 | def n_blocks(self, n_rows, block_m): |
| 86 | if n_rows <= self.n_expts_tot: |
| 87 | return n_rows |
| 88 | else: |
| 89 | return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 |
| 90 | |
| 91 | |
| 92 | # -------------------------- |
no test coverage detected