| 1175 | }); |
| 1176 | |
| 1177 | const forwardFunc: ForwardFunc<T> = (_, save) => { |
| 1178 | res = f(...[...inputs, save]); |
| 1179 | util.assert( |
| 1180 | res.value instanceof Tensor, |
| 1181 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1182 | 'object where `obj.value` is a tensor'); |
| 1183 | util.assert( |
| 1184 | util.isFunction(res.gradFunc), |
| 1185 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1186 | 'object where `obj.gradFunc` is a function.'); |
| 1187 | return res.value; |
| 1188 | }; |
| 1189 | |
| 1190 | const backwardsFunc = (dy: T, saved: Tensor[]) => { |
| 1191 | const gradRes = res.gradFunc(dy, saved); |