Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss.
(self, closure=None)
| 168 | ) |
| 169 | |
| 170 | def step(self, closure=None): |
| 171 | """Performs a single optimization step. |
| 172 | |
| 173 | Arguments: |
| 174 | closure (callable, optional): A closure that reevaluates the model |
| 175 | and returns the loss. |
| 176 | """ |
| 177 | loss = None |
| 178 | if closure is not None: |
| 179 | loss = closure() |
| 180 | check = 1#torch.norm(all_grads, 2) |
| 181 | |
| 182 | grad_list = [] |
| 183 | param_list = [] |
| 184 | per_param_decay = [] |
| 185 | momentum = [] |
| 186 | velocity = [] |
| 187 | |
| 188 | fp16_grad_list = [] |
| 189 | fp16_from_fp32_param_list = [] |
| 190 | fp32_param_list = [] |
| 191 | fp16_per_param_decay = [] |
| 192 | fp16_momentum = [] |
| 193 | fp16_velocity = [] |
| 194 | |
| 195 | if not self.updates_created: |
| 196 | self.update = [] |
| 197 | self.fp16_update = [] |
| 198 | for group in self.param_groups: |
| 199 | for p in group['params']: |
| 200 | if p.grad is None: |
| 201 | continue |
| 202 | grad = p.grad.data |
| 203 | if grad.is_sparse: |
| 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') |
| 205 | |
| 206 | state = self.state[p] |
| 207 | |
| 208 | # State initialization |
| 209 | if len(state) == 0: |
| 210 | # Keep step here for compatibility with earlier resume from checkpoint |
| 211 | state['step'] = 0 |
| 212 | # Exponential moving average of gradient values |
| 213 | state['momentum'] = torch.zeros_like(p.data, dtype=torch.float32) |
| 214 | # Exponential moving average of squared gradient values |
| 215 | state['velocity'] = torch.zeros_like(p.data, dtype=torch.float32) |
| 216 | # fp32 master weights |
| 217 | if 'master_param' not in state.keys() and p.type() == 'torch.cuda.HalfTensor': |
| 218 | state['master_param'] = p.detach().clone().float() |
| 219 | |
| 220 | # ensure these 3 are float tensors |
| 221 | if state['momentum'].type() != 'torch.cuda.FloatTensor': |
| 222 | state['momentum'] = state['momentum'].float() |
| 223 | if state['velocity'].type() != 'torch.cuda.FloatTensor': |
| 224 | state['velocity'] = state['velocity'].float() |
| 225 | if 'master_param' in state.keys() and state['master_param'].type() != 'torch.cuda.FloatTensor': |
| 226 | state['master_param'] = state['master_param'].float() |
| 227 |
no test coverage detected