recursive_apply(data, fn, *args, **kwargs)

Recursively apply a function to every element in a container.
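`recursive_apply` maps a function over the leaves of nested containers while keeping their shape. The minimal, self-contained sketch below re-implements that behavior under one assumption. The helper `is_listlike` is not shown in this listing, so it is taken to mean "a `collections.abc.Sequence` that is not a `str`", following the docstring's wording. The name `recursive_apply_sketch` is hypothetical and is not part of the module.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Callable


def recursive_apply_sketch(data: Any, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Hypothetical re-implementation mirroring the listed recursive_apply."""
    if isinstance(data, Mapping):
        # Any mapping becomes a plain dict with the same keys.
        return {k: recursive_apply_sketch(v, fn, *args, **kwargs) for k, v in data.items()}
    if isinstance(data, tuple):
        # Tuples are checked first so they are rebuilt as tuples, not lists.
        return tuple(recursive_apply_sketch(v, fn, *args, **kwargs) for v in data)
    if isinstance(data, Sequence) and not isinstance(data, str):
        # Assumed stand-in for is_listlike: other non-string sequences become lists.
        return [recursive_apply_sketch(v, fn, *args, **kwargs) for v in data]
    # Leaf (including strings): call fn with the extra arguments appended.
    return fn(data, *args, **kwargs)


nested = {"a": [1, 2], "b": (3, {"c": 4}), "d": "text"}
print(recursive_apply_sketch(nested, lambda x: x * 2))

# Extra keyword arguments are forwarded to every leaf call: round(x, ndigits=1).
print(recursive_apply_sketch([1.234, (5.678,)], round, ndigits=1))
```

The `"d"` entry is doubled as a whole string rather than iterated character by character, which shows that strings are treated as leaves. The `round` call shows `kwargs` being passed through to each leaf.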
    from collections.abc import Mapping

    # ``is_listlike`` is defined or imported elsewhere in this module (not shown).


    def recursive_apply(data, fn, *args, **kwargs):
        """Recursively apply a function to every element in a container.

        If the input data is a dict or any other mapping, returns a dict with the
        same keys whose values are the results of applying the function
        recursively to the original values.

        If the input data is a tuple, returns a tuple. If it is a list or any other
        sequence other than a string, returns a list. In both cases each element is
        the result of applying the function recursively to the corresponding input
        element.

        Any other object, including a string, is treated as a leaf and passed to
        the function directly.

        If the input data is a nested container, the result has the same nested
        structure, with every leaf element transformed.

        Each leaf element is passed as the first argument of the function,
        followed by the arguments in :attr:`args` and :attr:`kwargs`.

        Parameters
        ----------
        data : any
            The object to transform: a leaf value or a (possibly nested) container.
        fn : callable
            The function applied to each leaf element.
        args, kwargs :
            Additional positional and keyword arguments passed to the function.

        Returns
        -------
        any
            An object with the same nested structure as ``data`` in which every
            leaf has been replaced by the output of ``fn``.

        Examples
        --------
        Applying a ReLU function to a dictionary of tensors:

        >>> import torch
        >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}
        >>> h = recursive_apply(h, torch.nn.functional.relu)
        >>> assert all((v >= 0).all() for v in h.values())
        """
        if isinstance(data, Mapping):
            # Any mapping becomes a plain dict with the same keys.
            return {
                k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()
            }
        elif isinstance(data, tuple):
            # Checked before the list-like case so tuples stay tuples.
            return tuple(recursive_apply(v, fn, *args, **kwargs) for v in data)
        elif is_listlike(data):
            # Other non-string sequences become lists.
            return [recursive_apply(v, fn, *args, **kwargs) for v in data]
        else:
            # Leaf: apply the function with the extra arguments.
            return fn(data, *args, **kwargs)


    def recursive_apply_pair(data1, data2, fn, *args, **kwargs):
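The listing calls `is_listlike` without showing its definition. Based only on the docstring's wording ("any sequence other than a string"), one plausible definition is sketched below. It is a hypothetical helper, and the module's actual version may differ. The sketch also shows why the order of checks in `recursive_apply` matters. A tuple is itself a non-string sequence, so testing `tuple` first is what keeps tuples from being converted into lists.

```python
from collections.abc import Sequence
from typing import Any


def is_listlike(data: Any) -> bool:
    """Hypothetical helper: True for any sequence except a string."""
    # list, tuple, range and bytes are all registered as Sequence; dict is not.
    return isinstance(data, Sequence) and not isinstance(data, str)


for value in ([1, 2], (1, 2), range(3), "abc", b"ab", {"k": 1}, 42):
    print(type(value).__name__, is_listlike(value))
```

Under this assumed definition, `bytes` counts as list-like, so `recursive_apply` would iterate over its integer elements. Tensors are not `Sequence` instances, so they fall through to the leaf case. This is why the ReLU example in the docstring receives whole tensors.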