MCPcopy
hub / github.com/tensorflow/tfjs / buckets

Method buckets

tfjs-node/src/nodejs_kernel_backend.ts:630–674  ·  view source on GitHub ↗

Group data into histogram buckets. `@param data`: a `Tensor` of any shape; must be castable to `float32`. `@param bucketCount`: optional positive `number`. `@returns`: a `Tensor` of shape `[k, 3]` and type `float32`; the `i`th row is a triple `[leftEdge, rightEdge, count]`.

(data: Tensor, bucketCount?: number)

Source from the content-addressed store, hash-verified

628 * of `k` is either `bucketCount`, `1` or `0`.
629 */
private buckets(data: Tensor, bucketCount?: number): Tensor<tf.Rank> {
  // No data means no buckets: return an empty [0, 3] tensor.
  if (data.size === 0) {
    return tf.tensor([], [0, 3], 'float32');
  }

  // 30 is the default number of buckets in the TensorFlow Python
  // implementation. See
  // https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/summary_v2.py
  const numBuckets = bucketCount !== undefined ? bucketCount : 30;
  util.assert(
      Number.isInteger(numBuckets) && numBuckets > 0,
      () =>
          `Expected bucket count to be a strictly positive integer, but it was ` +
          `${numBuckets}`);

  // Everything below runs inside tidy() so that the many intermediate tensors
  // are disposed as soon as the result is built; only the returned tensor
  // survives. The input `data` itself is not disposed.
  return tf.tidy((): Tensor<tf.Rank> => {
    const values = data.flatten().cast('float32');
    const min: Scalar = values.min();
    const max: Scalar = values.max();
    const range: Scalar = max.sub(min);
    const isSingular = range.equal(0).arraySync() !== 0;

    if (isSingular) {
      // All values are identical: emit a single unit-width bucket centred on
      // that value and holding every element. tf.concat cannot join rank-0
      // tensors, so stack the three scalars into a length-3 vector and
      // reshape it into one [1, 3] row.
      const bucketStart: Scalar = min.sub(0.5);
      const bucketEnd: Scalar = min.add(0.5);
      const count = tf.scalar(values.size, 'float32');
      return tf.stack([bucketStart, bucketEnd, count]).reshape([1, 3]);
    }

    const bucketWidth = range.div(numBuckets);
    // Bucket index of each value. The maximum value maps to `numBuckets`, one
    // past the last bucket, so indices are clamped into the last bucket.
    // minimum() against a JS number yields float32, hence the final cast.
    const bucketIndices =
        tf.minimum(
              values.sub(min).floorDiv(bucketWidth).cast('int32'),
              numBuckets - 1)
            .cast('int32');
    // Per-bucket counts: one-hot encode each index and sum over the values.
    const bucketCounts =
        tf.oneHot(bucketIndices, numBuckets).sum(0).cast('float32');

    // numBuckets + 1 evenly spaced edges from min to max. Ensure the last
    // value in edges is exactly max (TF's linspace op doesn't do this).
    const edges = tf.concat(
        [
          tf.linspace(min.arraySync(), max.arraySync(), numBuckets + 1)
              .slice(0, numBuckets),
          max.reshape([1]),
        ],
        0);
    const leftEdges = edges.slice(0, numBuckets);
    const rightEdges = edges.slice(1, numBuckets);
    // Stacking along axis 1 gives the [numBuckets, 3] layout of
    // [leftEdge, rightEdge, count] rows directly.
    return tf.stack([leftEdges, rightEdges, bucketCounts], 1);
  });
}
675
676 // ~ TensorBoard-related (tfjs-node-specific) backend kernels.
677 // ------------------------------------------------------------

Callers 1

writeHistogramSummaryMethod · 0.95

Calls 15

flattenMethod · 0.80
minMethod · 0.80
maxMethod · 0.80
subMethod · 0.80
arraySyncMethod · 0.80
equalMethod · 0.80
scalarMethod · 0.80
reshapeMethod · 0.80
divMethod · 0.80
floorDivMethod · 0.80
minimumMethod · 0.80
oneHotMethod · 0.80

Tested by

no test coverage detected