(self, module, device=None, scale=1.0)
| 69 | return m |
| 70 | |
| 71 | def compute_eigenvalue(self, module, device=None, scale=1.0): |
| 72 | block_eigenvalue = [] |
| 73 | param_keys = [] |
| 74 | layers = self.get_layers(module) |
| 75 | |
| 76 | for block in range(self.layer_num): |
| 77 | model_block = layers[block] |
| 78 | |
| 79 | # We found this randn() has obvious accuracy impact in some cases, save/recover random state here. |
| 80 | rng_state = torch.random.get_rng_state() |
| 81 | if device is None: |
| 82 | v = [ |
| 83 | torch.randn(p.size()) for p in model_block.parameters() |
| 84 | if p.grad is not None and p.grad.grad_fn is not None |
| 85 | ] |
| 86 | else: |
| 87 | v = [ |
| 88 | torch.randn(p.size(), device=device) for p in model_block.parameters() |
| 89 | if p.grad is not None and p.grad.grad_fn is not None |
| 90 | ] |
| 91 | torch.random.set_rng_state(rng_state) |
| 92 | |
| 93 | grads = [ |
| 94 | param.grad for param in model_block.parameters() |
| 95 | if param.grad is not None and param.grad.grad_fn is not None |
| 96 | ] |
| 97 | params = [ |
| 98 | param for param in model_block.parameters() |
| 99 | if param.grad is not None and param.grad.grad_fn is not None |
| 100 | ] |
| 101 | |
| 102 | layer_keys = [id(p) for p in model_block.parameters()] |
| 103 | param_keys.append(layer_keys) |
| 104 | |
| 105 | v = self.normalize(v) |
| 106 | |
| 107 | # Disable eigenvalue if the model doesn't support second order gradients computation, |
| 108 | # e.g. when enabling DS transformer kernel. |
| 109 | if len(grads) == 0 or len(params) == 0: |
| 110 | log_dist('The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING) |
| 111 | return [] |
| 112 | |
| 113 | i = 0 |
| 114 | eigenvalue_current, eigenvalue_previous = 1., 0. |
| 115 | |
| 116 | while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( |
| 117 | (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) |
| 118 | >= self.tol): # test convergence criteria |
| 119 | eigenvalue_previous = eigenvalue_current |
| 120 | |
| 121 | Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True) |
| 122 | #Hv = [hv.float() for hv in Hv] |
| 123 | Hv = [self.nan_to_num(hv).float() for hv in Hv] |
| 124 | |
| 125 | eigenvalue_current = self.inner_product(Hv, v).item() |
| 126 | |
| 127 | v = self.normalize(Hv) |
| 128 | v = [x / scale for x in v] |
no test coverage detected