MCPcopy
hub / github.com/tensorflow/tfjs / applyGradients

Method applyGradients

tfjs-core/src/optimizers/adam_optimizer.ts:64–123  ·  view source on GitHub ↗
(variableGradients: NamedVariableMap|NamedTensor[])

Source retrieved from the content-addressed store (hash-verified).

62 }
63
64 applyGradients(variableGradients: NamedVariableMap|NamedTensor[]) {
65 const varNames = Array.isArray(variableGradients) ?
66 variableGradients.map(v => v.name) :
67 Object.keys(variableGradients);
68 tidy(() => {
69 const oneMinusAccBeta1 = sub(1, this.accBeta1);
70 const oneMinusAccBeta2 = sub(1, this.accBeta2);
71
72 varNames.forEach((name, i) => {
73 const value = ENGINE.registeredVariables[name];
74 const trainable = false;
75 if (this.accumulatedFirstMoment[i] == null) {
76 this.accumulatedFirstMoment[i] = {
77 originalName: `${name}/m`,
78 variable: tidy(() => zerosLike(value).variable(trainable))
79 };
80 }
81 if (this.accumulatedSecondMoment[i] == null) {
82 this.accumulatedSecondMoment[i] = {
83 originalName: `${name}/v`,
84 variable: tidy(() => zerosLike(value).variable(trainable))
85 };
86 }
87
88 const gradient = Array.isArray(variableGradients) ?
89 variableGradients[i].tensor :
90 variableGradients[name];
91 if (gradient == null) {
92 return;
93 }
94
95 const firstMoment = this.accumulatedFirstMoment[i].variable;
96 const secondMoment = this.accumulatedSecondMoment[i].variable;
97
98 const newFirstMoment =
99 add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
100 const newSecondMoment =
101 add(mul(secondMoment, this.beta2),
102 mul(square(gradient), 1 - this.beta2));
103
104 const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
105 const biasCorrectedSecondMoment =
106 div(newSecondMoment, oneMinusAccBeta2);
107
108 firstMoment.assign(newFirstMoment);
109 secondMoment.assign(newSecondMoment);
110
111 const newValue =
112 add(mul(div(biasCorrectedFirstMoment,
113 add(sqrt(biasCorrectedSecondMoment), this.epsilon)),
114 -this.learningRate),
115 value);
116 value.assign(newValue);
117 });
118
119 this.accBeta1.assign(mul(this.accBeta1, this.beta1));
120 this.accBeta2.assign(mul(this.accBeta2, this.beta2));
121 });

Callers

nothing calls this directly

Calls (6)

tidy (function) · 0.90
square (function) · 0.85
variable (method) · 0.80
assign (method) · 0.80
zerosLike (function) · 0.50
add (function) · 0.50

Tested by

no test coverage detected