| 18 | import {BackendValues, DataType, TensorBuffer, TypedArray, util} from '@tensorflow/tfjs-core'; |
| 19 | |
| 20 | export function uniqueImpl( |
| 21 | values: BackendValues, axis: number, shape: number[], dtype: DataType): { |
| 22 | outputValues: BackendValues, |
| 23 | outputShape: number[], |
| 24 | indices: BackendValues |
| 25 | } { |
| 26 | // Normalize and validate axis. |
| 27 | const $axis = util.parseAxisParam(axis, shape)[0]; |
| 28 | |
| 29 | // Calculate the new shape that is suitable for extracting data along the |
| 30 | // given axis. |
| 31 | // |
| 32 | // The rank is 3. |
| 33 | // The 1st dimension is the product of the sizes of all axes < the given axis. |
| 34 | // The 2nd dimension is the size of the given axis itself. |
| 35 | // The 3rd dimension is the product of the sizes of all axes > the given axis. |
| 36 | // |
| 37 | // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the |
| 38 | // newShape would be: [2*3, 5, 4]. |
| 39 | // |
| 40 | // Note that this is not the final output shape. This will be the shape for an |
| 41 | // intermediate TensorBuffer (see inputBuffer below) to allow us to extract |
| 42 | // values along the given axis. To demonstrate how it works, consider the |
| 43 | // following example: |
| 44 | // |
| 45 | // Input: a 3D tensor, with shape [1, 2, 3] |
| 46 | // [ |
| 47 | // [ |
| 48 | // [1,2,3], |
| 49 | // [4,5,6] |
| 50 | // ] |
| 51 | // ] |
| 52 | // Axis: 2 (the last axis). |
| 53 | // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6]. |
| 54 | // |
| 55 | // For this example, newShape would be: [2, 3, 1], where 2 is calculated from |
| 56 | // 1*2. The re-shaped data would look like: |
| 57 | // |
| 58 | // [ |
| 59 | // [ |
| 60 | // [1], [2], [3] |
| 61 | // ], |
| 62 | // [ |
| 63 | // [4], [5], [6] |
| 64 | // ] |
| 65 | // ] |
| 66 | // |
| 67 | // Then, we can construct a 3-level nested loop in the following dimension |
| 68 | // order to extract the values along the axis (dimension1): |
| 69 | // i: dimension1 // 0,1,2 (newShape[1]) |
| 70 | // m: dimension0 // 0,1 (newShape[0]) |
| 71 | // n: dimension2 // 0 (newShape[2]) |
| 72 | // |
| 73 | // m, i, n |
| 74 | // --------- |
| 75 | // Iteration 0: data at [0, 0, 0] => "1" |
| 76 | // Iteration 1: data at [1, 0, 0] => "4" |
| 77 | // We got [1,4]. |