github.com/tensorflow/tfjs

Function customGrad

Defined in tfjs-core/src/gradients.ts, lines 374–377.

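Overrides the gradient computation of a function `f`.

Takes a function `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}` and returns another function `g(...inputs)` that takes the same inputs as `f`. When called, `g` returns `f().value`. In backward mode, the gradients with respect to each input of `f` are computed by `f().gradFunc` rather than by differentiating `f`; tensors that `f` passes to `save` are handed back to `gradFunc` as `saved`.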

Signature: customGrad<T extends Tensor>(f: CustomGradientFunc<T>): (...args: Tensor[]) => T
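The contract described above can be modelled without TensorFlow.js. The sketch below is hypothetical: `customGradScalar`, `SaveFunc`, `CustomGradResult` and `CustomGradientFn` are illustrative names, plain numbers stand in for tensors, and `f` takes a single input. It shows both halves of the contract: the forward call returns `f().value`, and the gradient comes from `gradFunc`, fed with whatever `f` passed to `save`. For simplicity, `grad` reruns `f` to capture the saved values, whereas a real engine would keep them from the forward pass.

```typescript
// Hypothetical scalar model of the customGrad contract; plain numbers stand
// in for tensors, and none of these names are part of the tfjs API.
type SaveFunc = (toSave: number[]) => void;

interface CustomGradResult {
  value: number;
  // Maps the upstream gradient dy and the saved values to one gradient per input.
  gradFunc: (dy: number, saved: number[]) => number[];
}

type CustomGradientFn = (x: number, save: SaveFunc) => CustomGradResult;

function customGradScalar(f: CustomGradientFn) {
  // Forward: g(x) returns f(x).value; anything passed to save is ignored here.
  const g = (x: number): number => f(x, () => {}).value;

  // Backward: rerun f to capture the saved values, then let gradFunc
  // (not f itself) produce the gradient for the upstream dy.
  const grad = (x: number, dy = 1): number => {
    let saved: number[] = [];
    const { gradFunc } = f(x, (toSave) => {
      saved = toSave;
    });
    return gradFunc(dy, saved)[0];
  };

  return { g, grad };
}

// A square op whose gradient is overridden to dy * |x| instead of dy * 2x.
const customSquare = customGradScalar((x, save) => {
  save([x]);
  return {
    value: x * x,
    gradFunc: (dy, saved) => [dy * Math.abs(saved[0])],
  };
});

console.log(customSquare.g(-2)); // 4
console.log(customSquare.grad(-2)); // 2 (the true derivative 2x would be -4)
```

The overriding gradient `dy * |x|` deliberately differs from the true derivative `2x`, which makes it visible that `gradFunc`, not the body of `f`, drives the backward pass.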

Source

 * @doc {heading: 'Training', subheading: 'Gradients'}
 */
function customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
    (...args: Tensor[]) => T {
  return ENGINE.customGrad(f);
}

Callers (6 reported, 5 listed)

logSoftmax (function) · 0.90
logSigmoid (function) · 0.90
fusedDepthwiseConv2d (function) · 0.90
fusedConv2d (function) · 0.90
fusedMatMul (function) · 0.90
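A caller such as logSigmoid can use customGrad to pair a numerically stable forward computation with a closed-form gradient. As a hypothetical scalar illustration (not tfjs's `logSigmoid` source; `softplus`, `sigmoid` and `logSigmoidOp` are illustrative helpers), log σ(x) can be computed as -softplus(-x), and its derivative is simply σ(-x), which a `gradFunc` can return directly:

```typescript
// Hypothetical scalar sketch; plain numbers stand in for tensors.

// Stable softplus: log(1 + e^z) = max(z, 0) + log1p(e^(-|z|)).
function softplus(z: number): number {
  return Math.max(z, 0) + Math.log1p(Math.exp(-Math.abs(z)));
}

// Stable logistic sigmoid that never exponentiates a large positive number.
function sigmoid(z: number): number {
  if (z >= 0) return 1 / (1 + Math.exp(-z));
  const e = Math.exp(z);
  return e / (1 + e);
}

// Returns {value, gradFunc} in the shape a customGrad callback produces.
function logSigmoidOp(x: number) {
  return {
    value: -softplus(-x), // log σ(x) = -softplus(-x)
    gradFunc: (dy: number): number[] => [dy * sigmoid(-x)], // d/dx log σ(x) = σ(-x)
  };
}

// Naive formula for comparison: exp(1000) overflows to Infinity.
const naiveLogSigmoid = (x: number): number => Math.log(1 / (1 + Math.exp(-x)));

console.log(naiveLogSigmoid(-1000)); // -Infinity
console.log(logSigmoidOp(-1000).value); // -1000
console.log(logSigmoidOp(0).gradFunc(1)[0]); // 0.5
```

Writing the gradient by hand keeps the backward pass as stable as the forward one, instead of differentiating through the intermediate operations of the stable formula.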

Calls (1)

ENGINE.customGrad (method) · 0.80
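`customGrad` itself only forwards `f` to `ENGINE.customGrad`; conceptually, the engine runs `f` forward and records `gradFunc` as the gradient of the resulting op, so the reverse pass calls it instead of differentiating `f`. The sketch below is a hypothetical miniature of that idea, not tfjs's engine: `MiniEngine`, `Var`, `TapeNode` and `ForwardFn` are made-up names, numbers stand in for tensors, and `f` receives its inputs as an array (tfjs passes them as separate arguments followed by `save`). It shows a custom gradient composing with another op through the chain rule.

```typescript
// Hypothetical miniature engine: records customGrad ops on a tape.
// Names are illustrative and are not tfjs internals.
class Var {
  value: number;
  grad = 0; // accumulated gradient of the final output w.r.t. this value
  constructor(value: number) {
    this.value = value;
  }
}

interface TapeNode {
  inputs: Var[];
  output: Var;
  gradFunc: (dy: number) => number[]; // one gradient per input
}

type ForwardFn = (
  inputs: number[],
  save: (toSave: number[]) => void,
) => { value: number; gradFunc: (dy: number, saved: number[]) => number[] };

class MiniEngine {
  private readonly tape: TapeNode[] = [];

  // Runs f forward and records its gradFunc as the backward rule for the output.
  customGrad(f: ForwardFn): (...args: Var[]) => Var {
    return (...args: Var[]): Var => {
      let saved: number[] = [];
      const { value, gradFunc } = f(
        args.map((a) => a.value),
        (toSave) => {
          saved = toSave;
        },
      );
      const output = new Var(value);
      this.tape.push({ inputs: args, output, gradFunc: (dy) => gradFunc(dy, saved) });
      return output;
    };
  }

  // Reverse pass: visit nodes newest-first and accumulate input gradients.
  backward(y: Var): void {
    y.grad = 1;
    for (let i = this.tape.length - 1; i >= 0; i--) {
      const node = this.tape[i];
      const grads = node.gradFunc(node.output.grad);
      node.inputs.forEach((input, j) => {
        input.grad += grads[j];
      });
    }
  }
}

const engine = new MiniEngine();

// Square with an overridden gradient dy * |x|.
const squareAbsGrad = engine.customGrad(([x], save) => {
  save([x]);
  return { value: x * x, gradFunc: (dy, saved) => [dy * Math.abs(saved[0])] };
});

// Ordinary scaling op v -> 3v with its usual gradient.
const times3 = engine.customGrad(([v]) => ({ value: 3 * v, gradFunc: (dy) => [3 * dy] }));

const x = new Var(-2);
const y = times3(squareAbsGrad(x));
engine.backward(y);
console.log(y.value, x.grad); // 12 6
```

The upstream gradient 3 from `times3` reaches the custom op as `dy`, so `x.grad` is 3 · |-2| = 6 rather than the true derivative 6x = -12.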

Tested by

No test coverage detected.
