| 625 | * @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'} |
| 626 | */ |
export function zip<O extends tf.TensorContainer>(datasets: DatasetContainer):
    Dataset<O> {
  // manually type-check the argument for JS users
  if (!isIterable(datasets)) {
    throw new Error('The argument to zip() must be an object or array.');
  }

  // Appends every Dataset reachable from `node` to `out`, descending into
  // nested arrays/objects the same way the iterator function below does.
  // Dataset is tested before isIterable(), mirroring the mapper below.
  // Non-Dataset primitives are skipped here; the mapper below throws for them
  // when the dataset is iterated. `visited` ensures shared or circular
  // containers are walked at most once, so this eager pass cannot recurse
  // forever.
  function collectDatasetLeaves(
      node: unknown, visited: Set<unknown>, out: Array<Dataset<O>>): void {
    if (node instanceof Dataset) {
      out.push(node);
      return;
    }
    if (!isIterable(node) || visited.has(node)) {
      return;
    }
    visited.add(node);
    const container = node as {[key: string]: unknown};
    for (const key of Object.keys(container)) {
      collectDatasetLeaves(container[key], visited, out);
    }
  }

  // The zipped stream stops at its shortest member (ZipMismatchMode.SHORTEST
  // below), so the result's size is the minimum size over all Dataset leaves.
  const leaves: Array<Dataset<O>> = [];
  collectDatasetLeaves(datasets, new Set<unknown>(), leaves);

  // If any leaf reports a null/undefined size, the minimum cannot be known,
  // so the zipped size is null. This avoids Math.min coercing null to 0 or
  // undefined to NaN. An empty container also yields null.
  const sizeKnown = leaves.length > 0 && leaves.every(ds => ds.size != null);
  const size = sizeKnown ?
      leaves.reduce((min, ds) => Math.min(min, ds.size), Infinity) :
      null;

  return datasetFromIteratorFn<O>(async () => {
    // Replace every Dataset leaf with its iterator, keeping the container
    // shape so the zipped elements mirror the structure passed in.
    const streams = await deepMapAndAwaitAll(datasets, d => {
      if (d instanceof Dataset) {
        return {value: d.iterator(), recurse: false};
      } else if (isIterable(d)) {
        return {value: null, recurse: true};
      } else {
        throw new Error(
            'Leaves of the structure passed to zip() must be Datasets, ' +
            'not primitives.');
      }
    });
    return iteratorFromZipped<O>(streams, ZipMismatchMode.SHORTEST);
  }, size);
}
| 660 | |
| 661 | /** |
| 662 | * A zip function for use with deepZip, passed via the columnMajorBatch call. |