(
kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap)
| 1002 | } |
| 1003 | |
/**
 * Records a kernel invocation as a node on the active gradient tape.
 *
 * A gradient function registered for `kernelName` (looked up via
 * `getGradient`) takes precedence over the `gradientsFunc` argument. If
 * neither yields a function, the node is still pushed onto the tape, just
 * without a `gradient` closure.
 *
 * @param kernelName Name of the kernel that produced `outputs`.
 * @param inputs Named input tensors of the kernel; stored on the tape node.
 * @param outputs Tensors produced by the kernel; also used to build zero
 *     gradients for outputs whose upstream gradient is missing.
 * @param gradientsFunc Fallback gradient function, used only when no
 *     gradient is registered for `kernelName`.
 * @param saved Tensors stored on the node and forwarded to the gradient
 *     function.
 * @param attrs Kernel attributes forwarded to the gradient function.
 */
private addTapeNode(
    kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
    gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
  const tapeNode: TapeNode =
      {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

  // A registered gradient overrides the caller-supplied one. Bind the choice
  // to a const instead of reassigning the parameter, so the closure below
  // captures a single never-reassigned value.
  const gradConfig = getGradient(kernelName);
  const gradFunc = gradConfig != null ? gradConfig.gradFunc : gradientsFunc;

  if (gradFunc != null) {
    tapeNode.gradient = (dys: Tensor[]) => {
      // TODO(smilkov): To optimize back-prop, pass dys that are not used in
      // the backprop graph to the user as null instead of zeros
      const filledDys = dys.map((dy, i) => {
        if (dy != null) {
          return dy;
        }
        // No upstream gradient for output i: substitute zeros with that
        // output's size, shape and dtype.
        const output = outputs[i];
        const vals = util.makeZerosTypedArray(output.size, output.dtype);
        return this.makeTensor(vals, output.shape, output.dtype);
      });
      // Grad functions of ops with single outputs expect a dy, while ops
      // with multiple outputs expect dys (array of dy).
      return gradFunc(
          filledDys.length > 1 ? filledDys : filledDys[0], saved, attrs);
    };
  }
  this.state.activeTape.push(tapeNode);
}
| 1033 | |
| 1034 | keep<T extends Tensor>(result: T): T { |
| 1035 | result.kept = true; |
no test coverage detected