(x)
| 944 | import torch |
| 945 | |
def forward_t(x: torch.Tensor) -> torch.Tensor:
    """Quantize *x* to three levels scaled by its mean absolute value.

    The tensor is computed in float32, multiplied by the inverse of its
    mean absolute value, rounded, clamped to ``[-1, 1]`` and divided by
    the same factor again.  For finite input every output element is
    therefore one of ``-m``, ``0`` or ``+m``, where ``m`` is the mean
    absolute value of the input (floored at ``1e-5``).

    Args:
        x: Tensor to quantize.

    Returns:
        A tensor of the same shape as *x*, cast back to ``x.dtype``.
    """
    original_dtype = x.dtype
    # Do the arithmetic in float32 regardless of the incoming dtype.
    values = x.float()
    # Floor the mean at 1e-5 so an all-zero tensor does not divide by zero.
    mean_abs = values.abs().mean().clamp(min=1e-5)
    scale = 1.0 / mean_abs
    quantized = (values * scale).round().clamp(-1, 1) / scale
    return quantized.to(original_dtype)
| 952 | |
| 953 | def weight_quant(weight): |
| 954 | weight = torch.tensor(weight, dtype=torch.float32) |