| 62 | } |
| 63 | |
| 64 | applyGradients(variableGradients: NamedVariableMap|NamedTensor[]) { |
| 65 | const varNames = Array.isArray(variableGradients) ? |
| 66 | variableGradients.map(v => v.name) : |
| 67 | Object.keys(variableGradients); |
| 68 | tidy(() => { |
| 69 | const oneMinusAccBeta1 = sub(1, this.accBeta1); |
| 70 | const oneMinusAccBeta2 = sub(1, this.accBeta2); |
| 71 | |
| 72 | varNames.forEach((name, i) => { |
| 73 | const value = ENGINE.registeredVariables[name]; |
| 74 | const trainable = false; |
| 75 | if (this.accumulatedFirstMoment[i] == null) { |
| 76 | this.accumulatedFirstMoment[i] = { |
| 77 | originalName: `${name}/m`, |
| 78 | variable: tidy(() => zerosLike(value).variable(trainable)) |
| 79 | }; |
| 80 | } |
| 81 | if (this.accumulatedSecondMoment[i] == null) { |
| 82 | this.accumulatedSecondMoment[i] = { |
| 83 | originalName: `${name}/v`, |
| 84 | variable: tidy(() => zerosLike(value).variable(trainable)) |
| 85 | }; |
| 86 | } |
| 87 | |
| 88 | const gradient = Array.isArray(variableGradients) ? |
| 89 | variableGradients[i].tensor : |
| 90 | variableGradients[name]; |
| 91 | if (gradient == null) { |
| 92 | return; |
| 93 | } |
| 94 | |
| 95 | const firstMoment = this.accumulatedFirstMoment[i].variable; |
| 96 | const secondMoment = this.accumulatedSecondMoment[i].variable; |
| 97 | |
| 98 | const newFirstMoment = |
| 99 | add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1)); |
| 100 | const newSecondMoment = |
| 101 | add(mul(secondMoment, this.beta2), |
| 102 | mul(square(gradient), 1 - this.beta2)); |
| 103 | |
| 104 | const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1); |
| 105 | const biasCorrectedSecondMoment = |
| 106 | div(newSecondMoment, oneMinusAccBeta2); |
| 107 | |
| 108 | firstMoment.assign(newFirstMoment); |
| 109 | secondMoment.assign(newSecondMoment); |
| 110 | |
| 111 | const newValue = |
| 112 | add(mul(div(biasCorrectedFirstMoment, |
| 113 | add(sqrt(biasCorrectedSecondMoment), this.epsilon)), |
| 114 | -this.learningRate), |
| 115 | value); |
| 116 | value.assign(newValue); |
| 117 | }); |
| 118 | |
| 119 | this.accBeta1.assign(mul(this.accBeta1, this.beta1)); |
| 120 | this.accBeta2.assign(mul(this.accBeta2, this.beta2)); |
| 121 | }); |