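 * Partitions array so that all elements ordered before its (k+1)th smallest
 * element (as ordered by comparePair) are found to the left of it, and all
 * elements ordered after it are found to the right of it.
 * Based on the Floyd-Rivest Algorithm, ref:
 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
 * @param array: Array to partition (rearranged in place)
 * @param k: Desired index value, where array[k] is the (k+1)th smallest element
 *           when left = 0
 * @param left: Left index of the interval to search (defaults to 0)
 * @param right: Right index of the interval to search (defaults to
 *               array.length - 1)
 */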
function select(array: Pair[], k: number, left = 0, right = array.length - 1) {
  // Rearranges array[left..right] in place so that array[k] ends up holding
  // the element that would sit at index k if that range were fully ordered by
  // comparePair. Elements ordering before it land to its left, elements
  // ordering after it land to its right. Each pass of the loop partitions the
  // window once and then shrinks it to the side that still contains k.
  while (left < right) {
    const span = right - left;

    // On a large window, first recurse on a small sample window around k so
    // the pivot picked below is already close to its final position. The
    // constants 600 and 0.5 come from the original Floyd-Rivest formulation,
    // where they were tuned for execution time.
    if (span > 600) {
      const size = span + 1;
      const rank = k - left + 1;
      const logSize = Math.log(size);
      const sampleSize = 0.5 * Math.exp((2 * logSize) / 3);
      const deviation =
          0.5 * Math.sqrt((logSize * sampleSize * (size - sampleSize)) / size) *
          Math.sign(rank - size / 2);
      const sampleLeft =
          Math.max(left, Math.floor(k - (rank * sampleSize) / size + deviation));
      const sampleRight = Math.min(
          right, Math.floor(k + ((size - rank) * sampleSize) / size + deviation));
      select(array, k, sampleLeft, sampleRight);
    }

    // Partition the window around the element currently stored at index k.
    const pivot = array[k];
    util.swap(array, left, k);
    // Seed the two ends so that the inner scans below always stop inside the
    // window. The pivot sits at one end and an element ordering after it (or
    // the pivot itself) sits at the other.
    if (comparePair(array[right], pivot) > 0) {
      util.swap(array, left, right);
    }

    let lo = left;
    let hi = right;
    while (lo < hi) {
      util.swap(array, lo, hi);
      lo += 1;
      hi -= 1;
      while (comparePair(array[lo], pivot) < 0) {
        lo += 1;
      }
      while (comparePair(array[hi], pivot) > 0) {
        hi -= 1;
      }
    }

    // Drop the pivot into its final slot; afterwards `hi` is that slot.
    if (comparePair(array[left], pivot) === 0) {
      util.swap(array, left, hi);
    } else {
      hi += 1;
      util.swap(array, hi, right);
    }

    // Keep only the side of the pivot that still contains index k. When the
    // pivot landed exactly on k, both bounds move and the loop ends.
    if (hi <= k) {
      left = hi + 1;
    }
    if (hi >= k) {
      right = hi - 1;
    }
  }
}
| 95 | |
| 96 | export function topKImpl<T extends Tensor, R extends Rank>( |
| 97 | x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number, |
no test coverage detected
searching dependent graphs…