(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool)
| 64 | |
| 65 | class FusedAdamKernel(AdamKernel): |
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
    """Set up the fused Adam kernel.

    Forwards every hyperparameter unchanged to the parent ``AdamKernel``
    initializer, then loads the fused optimizer extension and keeps two
    things on the instance:

    * ``fused_adam`` -- the ``multi_tensor_adam`` attribute of the loaded
      extension.
    * ``dummy_overflow_buf`` -- a one-element ``torch.int`` tensor holding
      ``0``, created on the accelerator's current device.

    Args:
        lr: Forwarded to ``AdamKernel.__init__``.
        beta1: Forwarded to ``AdamKernel.__init__``.
        beta2: Forwarded to ``AdamKernel.__init__``.
        eps: Forwarded to ``AdamKernel.__init__``.
        weight_decay: Forwarded to ``AdamKernel.__init__``.
        use_adamw: Forwarded to ``AdamKernel.__init__``.
    """
    super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)

    # Function-scope import: the loader is only resolved when a fused
    # kernel is actually constructed.
    from colossalai.kernel.kernel_loader import FusedOptimizerLoader

    self.fused_adam = FusedOptimizerLoader().load().multi_tensor_adam

    # NOTE(review): presumably the overflow flag buffer handed to the fused
    # kernel on each update -- confirm against ``update``.
    current_device = get_accelerator().get_current_device()
    self.dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=current_device)
| 73 | |
| 74 | def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): |
| 75 | multi_tensor_applier( |
nothing calls this directly
no test coverage detected