Recursively apply a function to every element in a container. If the input data is a list or any sequence other than a string, returns a list (a tuple if the input is a tuple) whose elements are the results of applying the given function to the original elements. If the input data is a dict or any mapping, returns a dict with the same keys whose values are the results of applying the given function.
recursive_apply(data, fn, *args, **kwargs)
```python
from collections.abc import Mapping


def recursive_apply(data, fn, *args, **kwargs):
    """Recursively apply a function to every element in a container.

    If the input data is a list or any sequence other than a string, returns a list
    whose elements are the results of applying the given function to the original
    elements. A tuple input produces a tuple rather than a list.

    If the input data is a dict or any mapping, returns a dict with the same keys
    whose values are the results of applying the given function to the original values.

    If the input data is a nested container, the result has the same nested
    structure, with each element transformed recursively.

    Each individual element of the input data is passed as the first argument of
    the function, followed by the arguments in :attr:`args` and :attr:`kwargs`.

    Parameters
    ----------
    data : any
        Any object. Mappings and non-string sequences are traversed recursively;
        every other object (including a string) is passed to ``fn`` directly.
    fn : callable
        The function to apply to each leaf element.
    args, kwargs :
        Additional positional and keyword arguments passed to ``fn``.

    Examples
    --------
    Applying a ReLU function to a dictionary of tensors:

    >>> import torch
    >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}
    >>> h = recursive_apply(h, torch.nn.functional.relu)
    >>> assert all((v >= 0).all() for v in h.values())
    """
    if isinstance(data, Mapping):
        # Mappings: keep the keys and transform the values recursively.
        return {
            k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()
        }
    elif isinstance(data, tuple):
        # Tuples are rebuilt as tuples.
        return tuple(recursive_apply(v, fn, *args, **kwargs) for v in data)
    elif is_listlike(data):
        # Other non-string sequences are rebuilt as lists.
        # ``is_listlike`` is a helper defined elsewhere in the same module.
        return [recursive_apply(v, fn, *args, **kwargs) for v in data]
    else:
        # Leaf element: apply the function with the extra arguments.
        return fn(data, *args, **kwargs)
```
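The dispatch order in the listing can be exercised without PyTorch. The sketch below is a minimal standalone copy of `recursive_apply`, assuming that the module's `is_listlike` helper returns `True` for any `collections.abc.Sequence` that is not a string (its body is not shown here, so that definition is an assumption). The leaf function `tag` is hypothetical and only illustrates that each element arrives as the first argument, followed by the forwarded positional and keyword arguments.

```python
from collections.abc import Mapping, Sequence


def is_listlike(data):
    # Assumed behavior of the module's helper: any sequence except a string.
    return isinstance(data, Sequence) and not isinstance(data, str)


def recursive_apply(data, fn, *args, **kwargs):
    # Same dispatch order as the listing: mapping, tuple, other sequence, leaf.
    if isinstance(data, Mapping):
        return {k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()}
    elif isinstance(data, tuple):
        return tuple(recursive_apply(v, fn, *args, **kwargs) for v in data)
    elif is_listlike(data):
        return [recursive_apply(v, fn, *args, **kwargs) for v in data]
    else:
        return fn(data, *args, **kwargs)


def tag(x, prefix, suffix=""):
    # Hypothetical leaf function: wraps one element in a prefix and suffix.
    return f"{prefix}{x}{suffix}"


nested = {"a": (1, 2), "b": [3, {"c": 4}], "d": "hi", "e": range(2)}
result = recursive_apply(nested, tag, "<", suffix=">")
print(result)
# {'a': ('<1>', '<2>'), 'b': ['<3>', {'c': '<4>'}], 'd': '<hi>', 'e': ['<0>', '<1>']}
```

The string `"hi"` is tagged as a single leaf rather than split into characters, the tuple stays a tuple, and the `range` (a non-list sequence) comes back as a list. Because the code builds plain `dict` and `tuple` objects, a mapping subclass such as `collections.OrderedDict` or a `namedtuple` is returned as a plain `dict` or `tuple`.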