 * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grad: The gradient of `f(x)` w.r.t. `x` (result of `tf.grad`).
 * - value: The value returned by `f(x)`.
 *
 * ```js
 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function valueAndGrad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (
    x: I, dy?: O) => {
  value: O;
  grad: I;
} {
  util.assert(
      util.isFunction(f),
      () => 'The f passed in valueAndGrad(f) must be a function');
  return (x: I, dy?: O) => {
    util.assert(
        x instanceof Tensor,
        () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
    util.assert(
        dy == null || dy instanceof Tensor,
        () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
    const {grads, value} = ENGINE.gradients(() => f(x), [x], dy);
    checkGrads(grads);
    return {grad: grads[0] as I, value};
  };
}

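To make the contract of `valueAndGrad` concrete without depending on tfjs, here is a minimal, self-contained sketch. It is not the tfjs implementation: plain `number[]` arrays stand in for tensors, the names `Dual`, `VecFn` and `valueAndGradSketch` are hypothetical, and it swaps in forward-mode dual numbers (one pass of `f` per input element) rather than the tape-based reverse mode that tfjs uses internally. It does reproduce the two behaviors the function above exposes: a single call returns `{value, grad}` together, and an optional `dy` weights each output's contribution to the gradient (the sketch defaults `dy` to all ones, i.e. the gradient of the sum of the outputs).

```typescript
// Hypothetical sketch (not the tfjs API): plain number arrays stand in for
// tensors, and derivatives come from forward-mode dual numbers.

/** A dual number val + tan*eps with eps^2 = 0; `tan` carries the derivative. */
class Dual {
  readonly val: number;
  readonly tan: number;

  constructor(val: number, tan = 0) {
    this.val = val;
    this.tan = tan;
  }

  add(other: Dual): Dual {
    return new Dual(this.val + other.val, this.tan + other.tan);
  }

  mul(other: Dual): Dual {
    // Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps.
    return new Dual(this.val * other.val,
                    this.tan * other.val + this.val * other.tan);
  }
}

type VecFn = (x: Dual[]) => Dual[];

/**
 * Same call shape as valueAndGrad: f -> (x, dy?) -> {value, grad}.
 * grad[i] = sum_j dy[j] * d f(x)[j] / d x[i], with dy defaulting to ones.
 */
function valueAndGradSketch(f: VecFn) {
  if (typeof f !== 'function') {
    throw new Error('The f passed in valueAndGradSketch(f) must be a function');
  }
  return (x: number[], dy?: number[]): {value: number[]; grad: number[]} => {
    let value: number[] = [];
    const grad: number[] = [];
    // One forward pass per input element, seeding x[i] with tangent 1.
    for (let i = 0; i < x.length; i++) {
      const seeded = x.map((v, k) => new Dual(v, k === i ? 1 : 0));
      const out = f(seeded);
      if (dy != null && dy.length !== out.length) {
        throw new Error('dy must have the same length as f(x)');
      }
      value = out.map((d) => d.val);
      // Vector-Jacobian product entry for input i.
      grad.push(out.reduce(
          (acc, d, j) => acc + (dy == null ? 1 : dy[j]) * d.tan, 0));
    }
    return {value, grad};
  };
}

// Usage: f(x) = x * x + x elementwise, so f'(x) = 2x + 1.
const g = valueAndGradSketch((x) => x.map((d) => d.mul(d).add(d)));
const {value, grad} = g([2, 3]);      // value = [6, 12], grad = [5, 7]
const weighted = g([2, 3], [10, 1]);  // grad = [50, 7]
console.log(value, grad, weighted.grad);
```

Forward mode costs one evaluation of `f` per input element, which is why reverse mode is the usual choice for training, where there are many inputs and the differentiated quantity is typically a small loss.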
/**
 * Like `tf.grads`, but also returns the value of `f()`. Useful when `f()`