Apply a function to every element in a container. If the input data is a list or any sequence other than a string, returns a list containing the result of applying the function to each element. If the input data is a dict or any mapping, returns a dict with the same keys whose values are the results of applying the function to the original values.
(data, fn, *args, **kwargs)
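The sketch below shows the three dispatch branches and how extra arguments are forwarded, without needing PyTorch. The body of `apply_each` is copied from the source listing. `is_listlike` is not shown in this excerpt, so the version here is an assumption: it treats any non-string `Sequence` as list-like. The helper `scale` is hypothetical and exists only for illustration.

```python
from collections.abc import Mapping, Sequence


def is_listlike(data):
    # Assumption: "list-like" means any Sequence except a string.
    return isinstance(data, Sequence) and not isinstance(data, str)


def apply_each(data, fn, *args, **kwargs):
    # Mappings: keep the keys and transform each value.
    if isinstance(data, Mapping):
        return {k: fn(v, *args, **kwargs) for k, v in data.items()}
    # Non-string sequences: the result is always a list, even for a tuple input.
    elif is_listlike(data):
        return [fn(v, *args, **kwargs) for v in data]
    # Anything else, including a string, is passed to fn as a single element.
    else:
        return fn(data, *args, **kwargs)


def scale(x, factor, offset=0):
    # Hypothetical helper: each element arrives first, then args and kwargs.
    return x * factor + offset


by_type = apply_each({"user": 1, "item": 2}, scale, 10)
tupled = apply_each((1, 2, 3), scale, 2, offset=1)
scalar = apply_each(5, scale, 3)
text = apply_each("ab", str.upper)

print(by_type)  # {'user': 10, 'item': 20}
print(tupled)   # [3, 5, 7]
print(scalar)   # 15
print(text)     # AB
```

A string falls through to the last branch, so `fn` receives the whole string rather than one character at a time.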
def apply_each(data, fn, *args, **kwargs):
    """Apply a function to every element in a container.

    If the input data is a list or any sequence other than a string, returns a list
    containing the result of applying the function to each element.

    If the input data is a dict or any mapping, returns a dict with the same keys
    whose values are the results of applying the function to the original values.

    Otherwise, the function is applied to the input data itself.

    Each element of the input data is passed as the first argument of the function,
    followed by the arguments in :attr:`args` and :attr:`kwargs`.

    Parameters
    ----------
    data : any
        The input: a mapping, a non-string sequence, or any other object.
    fn : callable
        The function to apply to each element.
    args, kwargs :
        Additional positional and keyword arguments passed to the function.

    Examples
    --------
    Applying a ReLU function to a dictionary of tensors:

    >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}
    >>> h = apply_each(h, torch.nn.functional.relu)
    >>> assert all((v >= 0).all() for v in h.values())
    """
    if isinstance(data, Mapping):
        return {k: fn(v, *args, **kwargs) for k, v in data.items()}
    elif is_listlike(data):
        return [fn(v, *args, **kwargs) for v in data]
    else:
        return fn(data, *args, **kwargs)


def recursive_apply(data, fn, *args, **kwargs):