MCPcopy
hub / github.com/tensorflow/tfjs / binaryKernelFunc

Function binaryKernelFunc

tfjs-backend-cpu/src/utils/binary_utils.ts:38–129  ·  view source on GitHub ↗
(
    name: string, simpleImpl: SimpleBinaryKernelImpl,
    complexImpl?: ComplexBinaryKernelImpl, dtype?: DataType)

Source retrieved from the content-addressed store (hash-verified)

36 * comparison kernels, such as Equal, Less, Greater, etc.
37 */
38export function binaryKernelFunc(
39 name: string, simpleImpl: SimpleBinaryKernelImpl,
40 complexImpl?: ComplexBinaryKernelImpl, dtype?: DataType): KernelFunc {
41 if (complexImpl == null) {
42 return ({inputs, backend}) => {
43 const {a, b} = inputs as BinaryInputs;
44 const cpuBackend = backend as MathBackendCPU;
45
46 assertNotComplex([a, b], name);
47
48 const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
49 const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;
50
51 const decodedAVals = a.dtype === 'string' ?
52 // tslint:disable-next-line: no-any
53 backend_util.fromUint8ToStringArray(aVals as any as Uint8Array[]) :
54 aVals;
55 const decodedBVals = a.dtype === 'string' ?
56 // tslint:disable-next-line: no-any
57 backend_util.fromUint8ToStringArray(bVals as any as Uint8Array[]) :
58 bVals;
59 const $dtype = dtype || a.dtype;
60
61 const [resultData, resultShape] =
62 simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
63
64 return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
65 };
66 }
67
68 return ({inputs, backend}) => {
69 const {a, b} = inputs as BinaryInputs;
70 const cpuBackend = backend as MathBackendCPU;
71
72 if (a.dtype === 'complex64' || b.dtype === 'complex64') {
73 const $aComplex = cast(
74 {inputs: {x: a}, backend: cpuBackend, attrs: {dtype: 'complex64'}});
75
76 const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
77
78 const aReal = $aComplexVals.complexTensorInfos.real;
79 const aImag = $aComplexVals.complexTensorInfos.imag;
80
81 const aRealVals =
82 cpuBackend.data.get(aReal.dataId).values as Float32Array;
83 const aImagVals =
84 cpuBackend.data.get(aImag.dataId).values as Float32Array;
85
86 const $bComplex = cast(
87 {inputs: {x: b}, backend: cpuBackend, attrs: {dtype: 'complex64'}});
88
89 const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
90
91 const bReal = $bComplexVals.complexTensorInfos.real;
92 const bImag = $bComplexVals.complexTensorInfos.imag;
93
94 const bRealVals =
95 cpuBackend.data.get(bReal.dataId).values as Float32Array;

Callers 15

Atan2.ts · File · 0.90
LogicalAnd.ts · File · 0.90
Greater.ts · File · 0.90
BitwiseAnd.ts · File · 0.90
Pow.ts · File · 0.90
Less.ts · File · 0.90
Equal.ts · File · 0.90
Multiply.ts · File · 0.90
Mod.ts · File · 0.90
Maximum.ts · File · 0.90
Add.ts · File · 0.90
GreaterEqual.ts · File · 0.90

Calls 6

assertNotComplex · Function · 0.90
cast · Function · 0.90
complex · Function · 0.90
get · Method · 0.45
makeTensorInfo · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…