/**
 * Split the values of a Tensor into the TensorArray.
 *
 * The tensor is cut along its first dimension into `length.length` pieces.
 * Piece `i` covers `length[i]` consecutive rows starting right after piece
 * `i - 1`. Each piece is reshaped to `this.elementShape` and written to
 * TensorArray index `i` via `writeMany`.
 *
 * @param length number[] with the lengths to use when splitting value along
 *     its first dimension. Every entry must be a non-negative integer, and the
 *     entries must sum to `tensor.shape[0]`.
 * @param tensor Tensor, the tensor to split.
 * @throws Error if `tensor.dtype` differs from this TensorArray's dtype, if
 *     any length is negative or not an integer, if the lengths do not sum to
 *     `tensor.shape[0]`, or if the TensorArray is not dynamically sized and
 *     `length.length !== this.maxSize`.
 */
split(length: number[], tensor: Tensor) {
  if (tensor.dtype !== this.dtype) {
    throw new Error(`TensorArray dtype is ${
        this.dtype} but tensor has dtype ${tensor.dtype}`);
  }

  // Validate lengths before using them as slice sizes. `slice` interprets a
  // size of -1 as "all remaining elements", so a negative entry could pass the
  // sum check below and still yield wrong pieces.
  for (let i = 0; i < length.length; ++i) {
    const len = length[i];
    if (!Number.isInteger(len) || len < 0) {
      throw new Error(
          `Expected lengths to be non-negative integers, but length[${i}] ` +
          `is ${len}`);
    }
  }

  // cumulativeLengths[i] is the exclusive end row of piece i.
  let totalLength = 0;
  const cumulativeLengths = length.map(len => {
    totalLength += len;
    return totalLength;
  });

  if (totalLength !== tensor.shape[0]) {
    throw new Error(
        `Expected sum of lengths to be equal to tensor.shape[0], but sum ` +
        `of lengths is ${totalLength}, and tensor's shape is: ${
            tensor.shape}`);
  }

  if (!this.dynamicSize && length.length !== this.maxSize) {
    throw new Error(
        `TensorArray's size is not equal to the size of lengths (${
            this.maxSize} vs. ${length.length}), ` +
        'and the TensorArray is not marked as dynamically resizeable');
  }

  // Number of scalar elements per row along the first dimension. The guard
  // avoids 0 / 0 when there are no rows.
  const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
  const tensors: Tensor[] = [];
  // Intermediates created inside tidy() are cleaned up. The pieces are
  // returned from the callback so they are kept for writeMany below.
  tidy(() => {
    // View the input as [1, totalLength, elementPerRow] so each piece is a
    // contiguous run of rows along axis 1.
    const rows = reshape(tensor, [1, totalLength, elementPerRow]);
    for (let i = 0; i < length.length; ++i) {
      const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
      const begin = [0, previousLength, 0];
      const sizes = [1, length[i], elementPerRow];
      tensors[i] = reshape(slice(rows, begin, sizes), this.elementShape);
    }
    return tensors;
  });

  // Piece i is written to TensorArray index i.
  const indices: number[] = length.map((_, i) => i);
  this.writeMany(indices, tensors);
}
| 316 | } |