Repository: github.com/hpcaitech/ColossalAI

Function _cast_float

colossalai/utils/common.py, lines 66–73
(args, dtype: torch.dtype)


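```python
import torch


def _cast_float(args, dtype: torch.dtype):
    # Floating-point tensors are cast to `dtype`; integer and bool tensors are left as-is
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    # Lists and tuples are rebuilt with the same container type, recursing into each element.
    # Note: a namedtuple's constructor does not accept a single iterable, so namedtuples
    # are not rebuilt correctly by this branch.
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    # Dicts (including subclasses such as OrderedDict) come back as plain dicts;
    # keys are untouched, values are processed recursively
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    # Anything else (None, numbers, strings, custom objects) is returned unchanged
    return args
```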

Callers: 3 (all are forward methods; match confidence 0.90 each)
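The page lists three forward methods as callers but does not show them. The following is a hypothetical sketch of how a module wrapper's forward might use _cast_float to bring its positional and keyword inputs to the wrapped module's dtype before delegating. CastInputsWrapper is an invented name, not a ColossalAI class, and float64 is used only so the example runs on any CPU; a mixed-precision wrapper would more likely pass torch.float16 or torch.bfloat16.

```python
import torch
import torch.nn as nn


def _cast_float(args, dtype: torch.dtype):
    # Same logic as colossalai/utils/common.py, duplicated so the sketch is standalone
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args


class CastInputsWrapper(nn.Module):
    """Hypothetical wrapper: casts floating-point inputs to `dtype`,
    then calls the wrapped module."""

    def __init__(self, module: nn.Module, dtype: torch.dtype):
        super().__init__()
        self.module = module
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        # args is a tuple and kwargs a dict, so both container branches apply
        args = _cast_float(args, self.dtype)
        kwargs = _cast_float(kwargs, self.dtype)
        return self.module(*args, **kwargs)


model = CastInputsWrapper(nn.Linear(4, 2).to(torch.float64), torch.float64)
x = torch.randn(3, 4)  # float32 input
y = model(x)
print(y.dtype)         # torch.float64
```

Without the cast, nn.Linear with float64 weights would reject the float32 input with a dtype-mismatch RuntimeError; the wrapper lets callers pass inputs in whatever floating-point dtype they have.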

Calls: 1 (to, i.e. torch.Tensor.to, used to cast floating-point tensors; match confidence 0.45)

Tested by: no test coverage detected
