(args, dtype: torch.dtype)
| 64 | |
| 65 | |
| 66 | def _cast_float(args, dtype: torch.dtype): |
| 67 | if isinstance(args, torch.Tensor) and torch.is_floating_point(args): |
| 68 | args = args.to(dtype) |
| 69 | elif isinstance(args, (list, tuple)): |
| 70 | args = type(args)(_cast_float(t, dtype) for t in args) |
| 71 | elif isinstance(args, dict): |
| 72 | args = {k: _cast_float(v, dtype) for k, v in args.items()} |
| 73 | return args |
| 74 | |
| 75 | |
| 76 | def set_seed(seed): |