(parameter: torch.nn.Module)
| 129 | |
| 130 | |
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """Return the device associated with a module.

    The lookup proceeds in three stages, stopping at the first that succeeds:

    1. Ask the group offloading hook for its onload device. A ``ValueError``
       from that helper is ignored and the next stage is tried.
    2. Use the device of the first tensor yielded by the module's
       parameters followed by its buffers.
    3. If the module exposes neither, scan module ``__dict__`` entries for
       tensor attributes (kept for ``torch.nn.DataParallel`` compatibility
       in PyTorch 1.5) and use the first one found.

    Args:
        parameter: The module whose device should be determined.

    Returns:
        The device of the first tensor located by the stages above.

    Raises:
        StopIteration: Propagates from the final stage when it yields no
            tensor attribute either.
    """
    from ..hooks.group_offloading import _get_group_onload_device

    # Stage 1: the onload device recorded by the group offloading hook.
    try:
        return _get_group_onload_device(parameter)
    except ValueError:
        pass

    # Stage 2: no usable hook, so take the first parameter or buffer.
    try:
        first_tensor = next(itertools.chain(parameter.parameters(), parameter.buffers()))
        return first_tensor.device
    except StopIteration:
        # Stage 3: for torch.nn.DataParallel compatibility in PyTorch 1.5,
        # fall back to tensors stored directly in a module's ``__dict__``.

        def _tensor_attributes(module: torch.nn.Module) -> list[tuple[str, Tensor]]:
            found = []
            for attr_name, attr_value in module.__dict__.items():
                if torch.is_tensor(attr_value):
                    found.append((attr_name, attr_value))
            return found

        members = parameter._named_members(get_members_fn=_tensor_attributes)
        # Each member is a (name, tensor) pair; only the tensor is needed.
        return next(members)[1].device
| 155 | |
| 156 | |
| 157 | def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: |
no test coverage detected
searching dependent graphs…