(norm: Union[float, torch.Tensor], move_to_cuda)
| 81 | |
| 82 | |
def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    """Return ``norm`` as a tensor, optionally placed on the current CUDA device.

    Args:
        norm: Either a Python scalar (``float`` or ``int``) or an existing
            tensor. A scalar is wrapped into a one-element tensor; anything
            else is passed through unchanged.
        move_to_cuda: When truthy, the result is moved to the device reported
            by ``torch.cuda.current_device()``.

    Returns:
        The one-element tensor built from a scalar ``norm``, or the ``norm``
        that was passed in, after the optional device move.
    """
    # Python ints are valid wherever a float is annotated, so treat them as
    # scalars too; otherwise an int would be returned un-wrapped (or fail on
    # `.to` below) instead of honouring the tensor return type.
    if isinstance(norm, (int, float)):
        # Legacy constructor kept on purpose so float inputs produce the same
        # tensor as before.
        norm = torch.Tensor([norm])
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm
| 89 | |
| 90 | |
| 91 | def get_current_device() -> torch.device: |