Returns gradients of `f` with respect to each of the `xs`. The gradients returned are of the same length as `xs`, but some might be null if `f` was not a function of that `x`. It also takes an optional `dy` to multiply the gradient, which defaults to `1`.
(
f: () => T, xs: Tensor[], dy?: T,
allowNoGradients = false)
| 1103 | * gradient, which defaults to `1`. |
| 1104 | */ |
/**
 * Evaluates `f()` and computes the gradients of its result with respect to
 * each tensor in `xs`.
 *
 * Asserts (via `util.assert`) that `xs` is non-empty and that `f` returned a
 * `Tensor`.
 *
 * @param f Function whose returned tensor `y` is differentiated. It is run
 *     between `startTape()` and `endTape()` inside a 'forward' tidy scope.
 * @param xs Tensors to differentiate with respect to.
 * @param dy Optional gradient used to seed backpropagation at `y`. Must have
 *     dtype 'float32'. When omitted, `ones(y.shape)` is used instead.
 * @param allowNoGradients When false (the default), an error is thrown if no
 *     tape node connects `xs` to `y`.
 * @returns `value` is the tensor returned by `f`; `grads` holds one entry per
 *     element of `xs`, in the same order, looked up by tensor id. An entry is
 *     unset when no gradient was accumulated for that `x`.
 * @throws Error if `dy` is given with a dtype other than 'float32', or if no
 *     node connects `xs` to `y` and `allowNoGradients` is false.
 */
gradients<T extends Tensor>(
    f: () => T, xs: Tensor[], dy?: T,
    allowNoGradients = false): {value: T, grads: Tensor[]} {
  util.assert(
      xs.length > 0, () => 'gradients() received an empty list of xs.');
  if (dy != null) {
    if (dy.dtype !== 'float32') {
      throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
    }
  }

  // Forward pass: record operations on the tape while evaluating `f`.
  const beginRecording = () => this.startTape();
  const finishRecording = () => this.endTape();
  const runForward = () => this.tidy('forward', f);
  const y = this.scopedRun(beginRecording, finishRecording, runForward);

  util.assert(
      y instanceof Tensor,
      () => 'The result y returned by f() must be a tensor.');

  // Keep only the tape nodes that lie on a path from some x to y.
  const connectedNodes = getFilteredNodesXToY(this.state.activeTape, xs, y);
  const nothingConnects = connectedNodes.length === 0 && xs.length > 0;
  if (nothingConnects && !allowNoGradients) {
    throw new Error(
        'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
        'that the f you passed encloses all operations that lead from x ' +
        'to y.');
  }

  return this.tidy('backward', () => {
    // Gradients accumulated so far, keyed by tensor id. Seeded at y with the
    // caller-provided dy, or with ones of y's shape when dy is absent.
    const gradsByTensorId: {[tensorId: number]: Tensor} = {};
    if (dy == null) {
      gradsByTensorId[y.id] = ones(y.shape);
    } else {
      gradsByTensorId[y.id] = dy;
    }

    // Walk the connected nodes, accumulating gradients into the map. The
    // tidy wrapper and `add` are passed in as arguments to avoid a circular
    // dependency with `tape.ts`.
    backpropagateGradients(
        gradsByTensorId, connectedNodes,
        fn => this.tidy(fn as ScopeFn<Tensor>), add);

    const grads = xs.map(x => gradsByTensorId[x.id]);

    if (this.state.gradientDepth === 0) {
      // Not computing higher-order gradients, so the tape is no longer
      // needed: dispose every tensor saved on it and drop the tape.
      for (const node of this.state.activeTape) {
        for (const tensor of node.saved) {
          tensor.dispose();
        }
      }
      this.state.activeTape = null;
    }
    return {value: y, grads};
  });
}
| 1156 | |
| 1157 | customGrad<T extends Tensor>(f: CustomGradientFunc<T>): |
| 1158 | (...args: Array<Tensor|GradSaveFunc>) => T { |
no test coverage detected