| 253 | } |
| 254 | |
/**
 * Returns the transpose of a 2d matrix as a nested JS array.
 *
 * NOTE(review): judging by the name, callers hold heatmap values in the
 * opposite major order and need them row-major; confirm against the callers.
 *
 * The transposed tensor allocated here is always disposed before this
 * function settles, including when an assertion fails or the download
 * rejects. A tensor passed in by the caller is never disposed here.
 *
 * @param inputValues A 2d nested number array or a Tensor2D.
 * @returns The transposed values, i.e. `result[j][i] === input[i][j]`.
 * @throws If the input is not rank 2, if `tf.tensor2d` rejects the array
 *     input, or if the downloaded values do not have the transposed shape.
 */
async function convertToRowMajor(inputValues: number[][]|
                                 tf.Tensor2D): Promise<number[][]> {
  let originalShape: number[];
  let transposed: tf.Tensor2D;
  if (inputValues instanceof tf.Tensor) {
    originalShape = inputValues.shape;
    transposed = inputValues.transpose();
  } else {
    // Guard the empty outer array so that tf.tensor2d below, rather than a
    // TypeError on `inputValues[0]`, gets to judge the input.
    originalShape = [
      inputValues.length,
      inputValues.length > 0 ? inputValues[0].length : 0,
    ];
    // tidy() releases the un-transposed intermediate and keeps the result.
    transposed = tf.tidy(() => tf.tensor2d(inputValues).transpose());
  }

  // `transposed` is owned by this function from here on. Release it on every
  // exit path: success, failed assertion, or a rejected download.
  try {
    assert(
        transposed.rank === 2,
        'Input to renderHeatmap must be a 2d array or Tensor2d');

    // Download the intermediate tensor values.
    const transposedValues = await transposed.array();

    const transposedShape = [
      transposedValues.length,
      // An empty outer array carries no column count, so fall back to the
      // tensor's own shape instead of dereferencing transposedValues[0].
      transposedValues.length > 0 ? transposedValues[0].length :
                                    transposed.shape[1],
    ];
    assert(
        originalShape[0] === transposedShape[1] &&
            originalShape[1] === transposedShape[0],
        `Unexpected transposed shape. Original ${originalShape} : Transposed ${
            transposedShape}`);
    return transposedValues;
  } finally {
    transposed.dispose();
  }
}
| 284 | |
| 285 | function assertLabelsMatchShape( |
| 286 | inputValues: number[][]|tf.Tensor2D, labels: string[], dimension: 0|1) { |