
Method AddGradientOperators

caffe2/python/core.py:2025–2057 (github.com/pytorch/pytorch)


(self, ys, skip=0)
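The docstring below accepts `ys` in two shapes and uses `skip` to leave leading operators out of the backward pass. As a rough illustration of that contract only, this sketch uses a hypothetical `normalize_ys` helper and a hypothetical `FILL_ONES` sentinel (neither is part of Caffe2) to turn either shape into one mapping from output blob to gradient seed. It also shows that `skip` amounts to slicing the operator list, as `self._net.op[skip:]` does. The blob and operator names are made up.

```python
from typing import Dict, List, Optional, Sequence, Union

# Hypothetical sentinel standing in for "auto-fill this gradient with 1";
# Caffe2 itself does not use this name.
FILL_ONES = "<fill-with-ones>"


def normalize_ys(
    ys: Union[Sequence[str], Dict[str, Optional[str]]],
) -> Dict[str, str]:
    """Map each blob in ys to the blob (or sentinel) that seeds its gradient."""
    if isinstance(ys, dict):
        # Non-None entries name an explicit gradient blob; None is auto-filled.
        return {y: (g if g is not None else FILL_ONES) for y, g in ys.items()}
    # A plain list: every listed blob gets an all-ones gradient.
    return {y: FILL_ONES for y in ys}


def ops_to_differentiate(ops: List[str], skip: int = 0) -> List[str]:
    """Drop the first `skip` operators, mirroring self._net.op[skip:]."""
    return ops[skip:]


print(normalize_ys(["loss"]))
# {'loss': '<fill-with-ones>'}
print(normalize_ys({"loss": None, "aux": "aux_grad"}))
# {'loss': '<fill-with-ones>', 'aux': 'aux_grad'}
print(ops_to_differentiate(["data_gen", "fc", "loss_op"], skip=1))
# ['fc', 'loss_op']
```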


```python
def AddGradientOperators(self, ys, skip=0):
    """Add the gradient for operators in the net.

    Inputs:
      ys: a list or a dictionary specifying what blobs we want to compute
          derivatives of. If the input is a list, we will automatically
          generate their gradients with all-one values; if the input is a
          dictionary, for any dictionary entries that are not None, we will
          take the corresponding blobs as their gradients; for all those
          that are None, we will auto-fill them with 1.
      skip: skips the first n operators. This is provided mainly because a
          lot of nets may use the first few operators for data generation
          like stuff which really do not need to have gradients.

    Outputs:
      returns a map from the blob name in the input network to a blob
      containing gradient or a GradientSlice in case of sparse gradient

    Currently, this is hard-coded for float operators if there are branches
    (i.e. a blob is used as input to multiple operators). This is because
    the gradient accumulation (Sum) is float only right now.
    """

    grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
        self._net.op[skip:], ys)
    # Check if in immediate mode: the grad_ops are actually being produced
    # by C++ and bypasses the CreateOperator() call, so in immediate mode
    # we will have to explicitly run them.
    if workspace.IsImmediate():
        for op in grad_ops:
            workspace.RunOperatorImmediate(op)
    self._ExtendOps(grad_ops)
    return input_to_grad
```
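The branch caveat in the docstring comes from gradient accumulation. When a blob feeds several operators, each consumer contributes a partial gradient, and the parts must be summed; the docstring names `Sum` as the operator that does this. The toy reverse pass below is a sketch on plain Python floats. Its `forward` and `backward` helpers and the tuple-based operator format are hypothetical, not Caffe2's `GradientRegistry` or `OperatorDef`. It walks the operator list in reverse and sums each blob's parts before propagating further. It assumes the list is topologically ordered and that no blob is written twice (no in-place ops).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# Toy operator: (type, input blob names, output blob name).
# Hypothetical; this is not Caffe2's OperatorDef.
Op = Tuple[str, Tuple[str, ...], str]


def forward(ops: List[Op], inputs: Dict[str, float]) -> Dict[str, float]:
    """Evaluate the ops in order and return every blob's value."""
    vals = dict(inputs)
    for typ, ins, out in ops:
        if typ == "Mul":
            vals[out] = vals[ins[0]] * vals[ins[1]]
        elif typ == "Add":
            vals[out] = vals[ins[0]] + vals[ins[1]]
        else:
            raise ValueError(f"unsupported op {typ}")
    return vals


def backward(
    ops: List[Op], vals: Dict[str, float], ys: Sequence[str]
) -> Dict[str, float]:
    """Reverse pass; assumes topological order and no in-place ops."""
    # One partial gradient per consumer of each blob.
    parts: Dict[str, List[float]] = defaultdict(list)
    for y in ys:
        parts[y].append(1.0)  # all-ones seed, as for a list-valued ys
    grads: Dict[str, float] = {}
    for typ, ins, out in reversed(ops):
        if out not in parts:
            continue  # this op does not influence any requested output
        # Every consumer of `out` has already been visited, so summing its
        # parts here is the accumulation step the docstring attributes to Sum.
        g = sum(parts[out])
        grads[out] = g
        if typ == "Mul":
            a, b = ins
            parts[a].append(g * vals[b])
            parts[b].append(g * vals[a])
        elif typ == "Add":
            for x in ins:
                parts[x].append(g)
    # Net inputs are never produced by an op; accumulate their parts last.
    for blob, contributions in parts.items():
        grads.setdefault(blob, sum(contributions))
    return grads


# x feeds two ops (Mul and Add), so its gradient is the sum of two parts.
ops: List[Op] = [("Mul", ("x", "w"), "h"), ("Add", ("h", "x"), "y")]
vals = forward(ops, {"x": 3.0, "w": 2.0})
grads = backward(ops, vals, ["y"])
print(grads["x"], grads["w"])  # 3.0 3.0
```

Here y = x·w + x, so dy/dx = w + 1 = 3 and dy/dw = x = 3. Dropping either of the two parts for `x` would give a wrong result, which is why the branch case needs an explicit accumulation step.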

Calls (2)

_ExtendOps · method
GetBackwardPass · method
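`AddGradientOperators` returns the `input_to_grad` map it receives from `GetBackwardPass`. Per the docstring, each value is either a blob holding a dense gradient or a `GradientSlice` for a sparse gradient, so code that consumes the map typically branches on the type. This sketch defines a local namedtuple with `indices` and `values` fields as a stand-in for Caffe2's `GradientSlice`, uses strings and made-up blob names in place of real blob references, and adds a hypothetical `densify` helper. That helper scatter-adds a 1-D slice into a dense list to show what the sparse form encodes.

```python
from collections import namedtuple
from typing import Dict, List, Tuple, Union

# Local stand-in with the same two fields as Caffe2's GradientSlice.
GradientSlice = namedtuple("GradientSlice", ["indices", "values"])
Grad = Union[str, GradientSlice]


def split_gradients(
    input_to_grad: Dict[str, Grad],
) -> Tuple[Dict[str, str], Dict[str, GradientSlice]]:
    """Separate dense gradient blobs from sparse GradientSlice entries."""
    dense: Dict[str, str] = {}
    sparse: Dict[str, GradientSlice] = {}
    for blob, grad in input_to_grad.items():
        if isinstance(grad, GradientSlice):
            sparse[blob] = grad
        else:
            dense[blob] = grad
    return dense, sparse


def densify(indices: List[int], values: List[float], size: int) -> List[float]:
    """Scatter-add a 1-D sparse gradient into a dense list (hypothetical helper)."""
    out = [0.0] * size
    for i, v in zip(indices, values):
        out[i] += v  # a repeated index accumulates, like a row looked up twice
    return out


# Hypothetical blob names; a real map would hold Caffe2 blob references.
input_to_grad = {
    "fc_w": "fc_w_grad",
    "emb": GradientSlice("emb_grad_indices", "emb_grad_values"),
}
dense, sparse = split_gradients(input_to_grad)
print(sorted(dense), sorted(sparse))  # ['fc_w'] ['emb']
print(densify([0, 2, 2], [1.0, 0.5, 0.25], 4))  # [1.0, 0.0, 0.75, 0.0]
```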