MCPcopy
hub / github.com/tensorflow/tfjs / customGrad

Method customGrad

tfjs-core/src/engine.ts:1157–1216  ·  view source on GitHub ↗
(f: CustomGradientFunc<T>)

Source from the content-addressed store, hash-verified

1155 }
1156
1157 customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
1158 (...args: Array<Tensor|GradSaveFunc>) => T {
1159 util.assert(
1160 util.isFunction(f),
1161 () => 'The f passed in customGrad(f) must be a function.');
1162 return (...inputs: Tensor[]): T => {
1163 util.assert(
1164 inputs.every(t => t instanceof Tensor),
1165 () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
1166 'tensors');
1167
1168 let res: {
1169 value: T,
1170 gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
1171 };
1172 const inputMap: NamedTensorMap = {};
1173 inputs.forEach((input, i) => {
1174 inputMap[i] = input;
1175 });
1176
1177 const forwardFunc: ForwardFunc<T> = (_, save) => {
1178 res = f(...[...inputs, save]);
1179 util.assert(
1180 res.value instanceof Tensor,
1181 () => 'The function f passed in customGrad(f) must return an ' +
1182 'object where `obj.value` is a tensor');
1183 util.assert(
1184 util.isFunction(res.gradFunc),
1185 () => 'The function f passed in customGrad(f) must return an ' +
1186 'object where `obj.gradFunc` is a function.');
1187 return res.value;
1188 };
1189
1190 const backwardsFunc = (dy: T, saved: Tensor[]) => {
1191 const gradRes = res.gradFunc(dy, saved);
1192 const grads: Tensor[] = Array.isArray(gradRes) ? gradRes : [gradRes];
1193 util.assert(
1194 grads.length === inputs.length,
1195 () => 'The function f passed in customGrad(f) must return an ' +
1196 'object where `obj.gradFunc` is a function that returns ' +
1197 'the same number of tensors as inputs passed to f(...).');
1198 util.assert(
1199 grads.every(t => t instanceof Tensor),
1200 () => 'The function f passed in customGrad(f) must return an ' +
1201 'object where `obj.gradFunc` is a function that returns ' +
1202 'a list of only tensors.');
1203 const gradMap: {[key: string]: () => Tensor} = {};
1204 grads.forEach((grad, i) => {
1205 gradMap[i] = () => grad;
1206 });
1207 return gradMap;
1208 };
1209
1210 return this.runKernelFunc({
1211 forwardFunc,
1212 backwardsFunc,
1213 inputs: inputMap,
1214 });

Callers 3

gradients_test.ts (File) · 0.80
customGrad (Function) · 0.80
squareAndAdd (Function) · 0.80

Calls 1

runKernelFunc (Method) · 0.95

Tested by

no test coverage detected