Add the gradient for operators in the net. Inputs: ys: a list or a dictionary specifying which blobs to compute derivatives of. If the input is a list, their gradients are generated automatically with all-one values; if the input is a dictionary, entries that are not None supply the gradient blobs, and entries that are None are auto-filled with 1.
(self, ys, skip=0)
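The two accepted shapes of `ys` can be illustrated with a small pure-Python sketch. It is not Caffe2 code: `normalize_ys` and the `ONES` marker are hypothetical names introduced here, and the sketch only mirrors the rule stated in the docstring (a list means every blob is seeded with ones; a dictionary supplies gradient blobs, with None entries seeded with ones).

```python
# Hypothetical marker standing in for "auto-fill this gradient with 1".
ONES = "<auto-filled with 1>"


def normalize_ys(ys):
    """Map each requested blob name to its gradient seed.

    ys may be a list of blob names or a dict of {blob name: gradient blob
    name or None}. This mirrors the docstring's description only; the real
    handling lives inside Caffe2's gradient registry.
    """
    if isinstance(ys, dict):
        # Entries that are not None name the gradient blob to use;
        # None entries fall back to the all-ones seed.
        return {name: (grad if grad is not None else ONES)
                for name, grad in ys.items()}
    # A plain list: every blob gets the all-ones seed.
    return {name: ONES for name in ys}


seeds_from_list = normalize_ys(["loss"])
seeds_from_dict = normalize_ys({"loss": None, "aux": "aux_grad"})
print(seeds_from_list)
print(seeds_from_dict)
```

The `skip` argument is simpler still: the listing passes `self._net.op[skip:]` to the backward pass, so the first `skip` operators are never considered for gradient generation.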
        self._recreate_lookup_tables = False

    def AddGradientOperators(self, ys, skip=0):
        """Add the gradient for operators in the net.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
            derivatives of. If the input is a list, we will automatically
            generate their gradients with all-one values; if the input is a
            dictionary, for any dictionary entries that are not None, we will
            take the corresponding blobs as their gradients; for all those
            that are None, we will auto-fill them with 1.
          skip: skips the first n operators. This is provided mainly because
            many nets use their first few operators for data generation and
            similar steps, which do not need gradients.

        Outputs:
          returns a map from each blob name in the input network to a blob
          containing its gradient, or to a GradientSlice in the case of a
          sparse gradient.

        Currently, this is hard-coded for float operators if there are branches
        (i.e. a blob is used as input to multiple operators). This is because
        the gradient accumulation (Sum) is float only right now.
        """

        grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
            self._net.op[skip:], ys)
        # Check if in immediate mode: the grad_ops are produced by C++ and
        # bypass the CreateOperator() call, so in immediate mode we have to
        # run them explicitly.
        if workspace.IsImmediate():
            for op in grad_ops:
                workspace.RunOperatorImmediate(op)
        self._ExtendOps(grad_ops)
        return input_to_grad

    def AddArgument(self, arg_name, arg_value):
        self._net.arg.extend([utils.MakeArgument(arg_name, arg_value)])
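The docstring of `AddGradientOperators` notes that when a blob feeds several operators, the per-operator gradients are accumulated with a Sum. The following is a minimal pure-Python sketch of that idea using plain floats. `accumulate_gradients` is a hypothetical helper written for illustration, not part of Caffe2, and the toy network (`a = 2 * x`, `b = 3 * x`, `loss = a + b`) is an assumption chosen so the result is easy to check by hand.

```python
from collections import defaultdict


def accumulate_gradients(contributions):
    """Sum the gradient contributions that arrive at each blob.

    contributions is a list of (blob name, gradient value) pairs, one pair
    per operator that consumes the blob. A blob consumed by a single
    operator simply keeps that operator's contribution.
    """
    totals = defaultdict(float)
    for blob, grad in contributions:
        totals[blob] += grad
    return dict(totals)


# Toy branch: x is an input to two operators, a = 2 * x and b = 3 * x,
# and loss = a + b. Each consumer contributes d(loss)/dx along its path.
contributions = [
    ("x", 2.0),  # through a = 2 * x
    ("x", 3.0),  # through b = 3 * x
]
grads = accumulate_gradients(contributions)
print(grads)
```

Here the two branches contribute 2.0 and 3.0, which sum to the expected derivative of `loss = 5 * x` with respect to `x`.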