* Returns a list of tensors to save for a given gradient calculation. * * @param kernelName name of kernel to look up gradient for. * @param inputs a map of input tensors. * @param outputs an array of output tensors from forward mode of kernel.
(
kernelName: string, inputs: NamedTensorMap,
outputs: Tensor[])
| 761 | * @param outputs an array of output tensors from forward mode of kernel. |
| 762 | */ |
private getTensorsForGradient(
    kernelName: string, inputs: NamedTensorMap,
    outputs: Tensor[]): Tensor[]|null {
  // Collect the tensors that the registered gradient for `kernelName` wants
  // kept alive: the selected inputs first, followed by the forward-mode
  // outputs whose flag is set in `outputsToSave`. In practice this never
  // returns null; a kernel with no registered gradient yields [].
  const gradConfig = getGradient(kernelName);
  if (gradConfig == null) {
    // Deliberately not an error: the kernel being looked up may not actually
    // be relevant to backpropagating through the overall function.
    //
    // See the 'does not error if irrelevant (pruned) ops are missing grads'
    // test in gradients_test.ts for an example.
    return [];
  }

  const savedInputNames: string[] = gradConfig.inputsToSave || [];
  const savedOutputMask: boolean[] = gradConfig.outputsToSave || [];

  // saveAllInputs keeps every input. Otherwise, only the inputs named in
  // inputsToSave are kept, in that order.
  let savedInputs: Tensor[];
  if (gradConfig.saveAllInputs) {
    // Kernels that save all inputs hand them over as an array at runtime,
    // even though the parameter is typed as a NamedTensorMap.
    util.assert(
        Array.isArray(inputs),
        () => 'saveAllInputs is true, expected inputs to be an array.');
    savedInputs = Object.keys(inputs).map((key) => inputs[key]);
  } else {
    savedInputs = savedInputNames.map((savedName) => inputs[savedName]);
  }

  // Keep output i only when savedOutputMask[i] is truthy.
  const savedOutputs: Tensor[] = [];
  outputs.forEach((tensor, index) => {
    if (savedOutputMask[index]) {
      savedOutputs.push(tensor);
    }
  });

  return [...savedInputs, ...savedOutputs];
}
| 797 | |
| 798 | /** |
| 799 | * Internal method used by public APIs for tensor creation. Makes a new |
no test coverage detected