(f: CustomGradientFunc<T>)
| 1155 | } |
| 1156 | |
| 1157 | customGrad<T extends Tensor>(f: CustomGradientFunc<T>): |
| 1158 | (...args: Array<Tensor|GradSaveFunc>) => T { |
| 1159 | util.assert( |
| 1160 | util.isFunction(f), |
| 1161 | () => 'The f passed in customGrad(f) must be a function.'); |
| 1162 | return (...inputs: Tensor[]): T => { |
| 1163 | util.assert( |
| 1164 | inputs.every(t => t instanceof Tensor), |
| 1165 | () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' + |
| 1166 | 'tensors'); |
| 1167 | |
| 1168 | let res: { |
| 1169 | value: T, |
| 1170 | gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[], |
| 1171 | }; |
| 1172 | const inputMap: NamedTensorMap = {}; |
| 1173 | inputs.forEach((input, i) => { |
| 1174 | inputMap[i] = input; |
| 1175 | }); |
| 1176 | |
| 1177 | const forwardFunc: ForwardFunc<T> = (_, save) => { |
| 1178 | res = f(...[...inputs, save]); |
| 1179 | util.assert( |
| 1180 | res.value instanceof Tensor, |
| 1181 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1182 | 'object where `obj.value` is a tensor'); |
| 1183 | util.assert( |
| 1184 | util.isFunction(res.gradFunc), |
| 1185 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1186 | 'object where `obj.gradFunc` is a function.'); |
| 1187 | return res.value; |
| 1188 | }; |
| 1189 | |
| 1190 | const backwardsFunc = (dy: T, saved: Tensor[]) => { |
| 1191 | const gradRes = res.gradFunc(dy, saved); |
| 1192 | const grads: Tensor[] = Array.isArray(gradRes) ? gradRes : [gradRes]; |
| 1193 | util.assert( |
| 1194 | grads.length === inputs.length, |
| 1195 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1196 | 'object where `obj.gradFunc` is a function that returns ' + |
| 1197 | 'the same number of tensors as inputs passed to f(...).'); |
| 1198 | util.assert( |
| 1199 | grads.every(t => t instanceof Tensor), |
| 1200 | () => 'The function f passed in customGrad(f) must return an ' + |
| 1201 | 'object where `obj.gradFunc` is a function that returns ' + |
| 1202 | 'a list of only tensors.'); |
| 1203 | const gradMap: {[key: string]: () => Tensor} = {}; |
| 1204 | grads.forEach((grad, i) => { |
| 1205 | gradMap[i] = () => grad; |
| 1206 | }); |
| 1207 | return gradMap; |
| 1208 | }; |
| 1209 | |
| 1210 | return this.runKernelFunc({ |
| 1211 | forwardFunc, |
| 1212 | backwardsFunc, |
| 1213 | inputs: inputMap, |
| 1214 | }); |
no test coverage detected