Repository: github.com/InternLM/InternLM

Function get_tensor_norm

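internlm/utils/common.py, lines 83–88
get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor

Wraps a Python float in a one-element tensor and, if move_to_cuda is truthy, moves the result to the current CUDA device. A tensor input is returned as-is unless it is moved.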


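```python
# Imports assumed from the top of internlm/utils/common.py (not part of lines 83–88).
from typing import Union

import torch


def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    # Wrap a plain Python float in a one-element tensor.
    if isinstance(norm, float):
        norm = torch.Tensor([norm])
    # Optionally move the result to the current CUDA device.
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm
```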
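A quick way to see the function's two input paths is to call it with a float and with an existing tensor. The sketch below copies the function body verbatim so it runs on its own; it passes move_to_cuda=False (or torch.cuda.is_available()) so that it also works on a CPU-only machine. The variable names are illustrative only.

```python
from typing import Union

import torch


def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    # Copied from the listing so this example is self-contained.
    if isinstance(norm, float):
        norm = torch.Tensor([norm])
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm


# A Python float becomes a one-element tensor on the CPU (default dtype).
from_float = get_tensor_norm(2.5, move_to_cuda=False)
print(from_float, from_float.shape, from_float.dtype)

# An existing tensor comes back as the very same object when it is not moved.
existing = torch.tensor(3.0)
same = get_tensor_norm(existing, move_to_cuda=False)
print(same is existing)

# Request CUDA placement only when a GPU is actually present.
on_device = get_tensor_norm(1.0, move_to_cuda=torch.cuda.is_available())
print(on_device.device)
```

Note that the check is isinstance(norm, float), so an int such as 0 is not wrapped; it passes through as a plain int, and with move_to_cuda set it would fail at the .to(...) call.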

Callers (1): compute_norm (function, 0.90)
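The body of compute_norm is not shown on this page. As an assumption about why a Union[float, torch.Tensor] input is convenient, consider an accumulator that starts as the Python float 0.0 and only becomes a tensor once a gradient is added; get_tensor_norm then yields a tensor either way. The helper total_grad_norm below is hypothetical and is not InternLM's compute_norm; it handles only finite p.

```python
from typing import Iterable, Union

import torch


def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    # Same logic as the listing above.
    if isinstance(norm, float):
        norm = torch.Tensor([norm])
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm


def total_grad_norm(grads: Iterable[torch.Tensor], norm_type: float = 2.0) -> torch.Tensor:
    """Hypothetical caller: finite L_p norm over a collection of gradient tensors."""
    total: Union[float, torch.Tensor] = 0.0  # remains a float if grads is empty
    for g in grads:
        # Sum |g|_p^p over all tensors; float + tensor yields a tensor.
        total = total + g.float().norm(norm_type) ** norm_type
    # Normalize to a tensor whether or not any gradient was seen.
    total = get_tensor_norm(total, move_to_cuda=False)
    return total ** (1.0 / norm_type)


grads = [torch.tensor([3.0, 0.0]), torch.tensor([4.0])]
print(total_grad_norm(grads))
print(total_grad_norm([]))
```

In a distributed setting the tensor form is what collective operations such as an all-reduce expect, which is one plausible reason for the optional move to the current CUDA device.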

Calls: no outgoing calls

Tested by: no test coverage detected