MCPcopy
hub / github.com/tensorflow/tfjs / addTapeNode

Method addTapeNode

tfjs-core/src/engine.ts:1004–1032  ·  view source on GitHub ↗
(
      kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
      gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap)

Source from the content-addressed store, hash-verified

1002 }
1003
/**
 * Records a kernel invocation as a node on the active gradient tape.
 *
 * The node always receives the next tape-node id and is always pushed onto
 * `this.state.activeTape`. A `gradient` closure is attached only when a
 * gradient function is available. A gradient registered for `kernelName`
 * (looked up via `getGradient`) takes precedence over the caller-supplied
 * `gradientsFunc`.
 *
 * @param kernelName Kernel name; also the key used to look up a registered
 *     gradient.
 * @param inputs Named input tensors, stored on the tape node.
 * @param outputs Output tensors, stored on the tape node. Also used to build
 *     zero-filled upstream gradients for outputs that received no `dy`.
 * @param gradientsFunc Fallback gradient function, used only when no gradient
 *     is registered for `kernelName`.
 * @param saved Tensors stored on the node and forwarded to the gradient
 *     function.
 * @param attrs Attributes forwarded to the gradient function.
 */
private addTapeNode(
    kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
    gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
  const tapeNode: TapeNode =
      {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

  // The registry's gradient wins; the caller's function is only a fallback.
  // Use a local const rather than reassigning the parameter.
  const registered = getGradient(kernelName);
  const gradFunc = registered != null ? registered.gradFunc : gradientsFunc;

  if (gradFunc != null) {
    tapeNode.gradient = (dys: Tensor[]) => {
      // TODO(smilkov): To optimize back-prop, pass dys that are not used in
      // the backprop graph to the user as null instead of zeros
      const upstream = dys.map((dy, i) => {
        if (dy != null) {
          return dy;
        }
        // No gradient reached output i: substitute zeros with its shape/dtype.
        const out = outputs[i];
        const zeros = util.makeZerosTypedArray(out.size, out.dtype);
        return this.makeTensor(zeros, out.shape, out.dtype);
      });
      // Gradient functions of single-output kernels take a lone dy; those of
      // multi-output kernels take the whole array of dys.
      const dyArg = upstream.length > 1 ? upstream : upstream[0];
      return gradFunc(dyArg, saved, attrs);
    };
  }

  this.state.activeTape.push(tapeNode);
}
1033
1034 keep<T extends Tensor>(result: T): T {
1035 result.kept = true;

Callers (2)

clone (method) · 0.95
runKernelFunc (method) · 0.95

Calls (3)

makeTensor (method) · 0.95
getGradient (function) · 0.90
push (method) · 0.45

Tested by

No test coverage detected.