MCPcopy
hub / github.com/tensorflow/tfjs / getTensorsForGradient

Method getTensorsForGradient

tfjs-core/src/engine.ts:763–796  ·  view source on GitHub ↗

Returns a list of tensors to save for a given gradient calculation. Parameters: `kernelName` — name of the kernel to look up the gradient for; `inputs` — a map of input tensors; `outputs` — an array of output tensors from the forward mode of the kernel.

(
      kernelName: string, inputs: NamedTensorMap,
      outputs: Tensor[])

Source from the content-addressed store, hash-verified

761 * @param outputs an array of output tensors from forward mode of kernel.
762 */
private getTensorsForGradient(
    kernelName: string, inputs: NamedTensorMap,
    outputs: Tensor[]): Tensor[]|null {
  // Look up the gradient config registered for this kernel.
  const config = getGradient(kernelName);

  // No registered gradient is not treated as an error: the kernel may not be
  // relevant to backpropagating through the overall function, so an empty
  // list is returned instead of throwing.
  //
  // See 'does not error if irrelevant (pruned) ops are missing grads' test
  // in gradients_test.ts for an example.
  //
  // NOTE(review): the declared return type admits null, but this method
  // never returns null.
  if (config == null) {
    return [];
  }

  // Names of inputs to keep, and a positional mask selecting outputs to keep.
  // Absent entries default to "save nothing".
  const savedInputNames: string[] = config.inputsToSave || [];
  const savedOutputMask: boolean[] = config.outputsToSave || [];

  let savedInputs: Tensor[];
  if (!config.saveAllInputs) {
    // Save only the inputs named in inputsToSave, in that order.
    savedInputs = savedInputNames.map((name) => inputs[name]);
  } else {
    // Save every input. This mode requires the caller to have passed the
    // inputs as an array (the assert below enforces it).
    util.assert(
        Array.isArray(inputs),
        () => 'saveAllInputs is true, expected inputs to be an array.');
    savedInputs = Object.keys(inputs).map((k) => inputs[k]);
  }

  // Keep output i only when the mask entry at index i is truthy.
  const savedOutputs: Tensor[] =
      outputs.filter((_, idx) => savedOutputMask[idx]);

  // Saved inputs come first, followed by saved outputs.
  return savedInputs.concat(savedOutputs);
}
797
798 /**
799 * Internal method used by public APIs for tensor creation. Makes a new

Callers 1

runKernelFuncMethod · 0.95

Calls 2

getGradientFunction · 0.90
concatMethod · 0.65

Tested by

no test coverage detected