 * Overrides the gradient computation of a function `f`.
 *
 * Takes a function
 * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
 * and returns another function `g(...inputs)` which takes the same inputs as
 * `f`. When called, `g` returns `f().value`. In backward mode, c
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
    (...args: Tensor[]) => T {
  return ENGINE.customGrad(f);
}
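
// Hypothetical, self-contained sketch (not part of tfjs, and not how ENGINE
// implements customGrad): it mimics the contract documented above on plain
// numbers, simplified to a single scalar input. The forward pass returns
// `value` from `f`; the backward pass ignores the math inside `f` and calls
// the user-supplied `gradFunc(dy, saved)` instead.
// `ScalarCustomGradFunc` and `customGradScalarSketch` are illustrative names.
type ScalarCustomGradFunc =
    (x: number, save: (saved: number[]) => void) => {
      value: number,
      gradFunc: (dy: number, saved: number[]) => number
    };

function customGradScalarSketch(f: ScalarCustomGradFunc):
    {forward: (x: number) => number, backward: (dy: number) => number} {
  // Values captured through `save` during the most recent forward pass.
  let saved: number[] = [];
  // Gradient override returned by the most recent forward pass.
  let gradFunc: ((dy: number, saved: number[]) => number)|null = null;
  return {
    forward: (x: number) => {
      const out = f(x, s => {
        saved = s;
      });
      gradFunc = out.gradFunc;
      return out.value;
    },
    backward: (dy: number) => {
      if (gradFunc == null) {
        throw new Error('backward() called before forward()');
      }
      return gradFunc(dy, saved);
    }
  };
}

// Example usage: x^2 whose gradient is overridden to dy * |x|.
//   const sq = customGradScalarSketch((x, save) => {
//     save([x]);
//     return {value: x * x, gradFunc: (dy, s) => dy * Math.abs(s[0])};
//   });
//   sq.forward(-3);  // 9
//   sq.backward(1);  // 3, whereas the true derivative 2x would give -6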

function checkGrads(grads: Tensor[]) {
  const numNullGradients = grads.filter(g => g == null).length;