MCPcopy
hub / github.com/deepspeedai/DeepSpeed / compute_eigenvalue

Method compute_eigenvalue

deepspeed/runtime/eigenvalue.py:71–147  ·  view source on GitHub ↗
(self, module, device=None, scale=1.0)

Source from the content-addressed store, hash-verified

69 return m
70
71 def compute_eigenvalue(self, module, device=None, scale=1.0):
72 block_eigenvalue = []
73 param_keys = []
74 layers = self.get_layers(module)
75
76 for block in range(self.layer_num):
77 model_block = layers[block]
78
79 # We found this randn() has obvious accuracy impact in some cases, save/recover random state here.
80 rng_state = torch.random.get_rng_state()
81 if device is None:
82 v = [
83 torch.randn(p.size()) for p in model_block.parameters()
84 if p.grad is not None and p.grad.grad_fn is not None
85 ]
86 else:
87 v = [
88 torch.randn(p.size(), device=device) for p in model_block.parameters()
89 if p.grad is not None and p.grad.grad_fn is not None
90 ]
91 torch.random.set_rng_state(rng_state)
92
93 grads = [
94 param.grad for param in model_block.parameters()
95 if param.grad is not None and param.grad.grad_fn is not None
96 ]
97 params = [
98 param for param in model_block.parameters()
99 if param.grad is not None and param.grad.grad_fn is not None
100 ]
101
102 layer_keys = [id(p) for p in model_block.parameters()]
103 param_keys.append(layer_keys)
104
105 v = self.normalize(v)
106
107 # Disable eigenvalue if the model doesn't support second order gradients computation,
108 # e.g. when enabling DS transformer kernel.
109 if len(grads) == 0 or len(params) == 0:
110 log_dist('The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING)
111 return []
112
113 i = 0
114 eigenvalue_current, eigenvalue_previous = 1., 0.
115
116 while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs(
117 (eigenvalue_current - eigenvalue_previous) / eigenvalue_current)
118 >= self.tol): # test convergence criteria
119 eigenvalue_previous = eigenvalue_current
120
121 Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True)
122 #Hv = [hv.float() for hv in Hv]
123 Hv = [self.nan_to_num(hv).float() for hv in Hv]
124
125 eigenvalue_current = self.inner_product(Hv, v).item()
126
127 v = self.normalize(Hv)
128 v = [x / scale for x in v]

Callers 1

step (Method) · 0.80

Calls 12

get_layers (Method) · 0.95
normalize (Method) · 0.95
nan_to_num (Method) · 0.95
inner_product (Method) · 0.95
post_process (Method) · 0.95
log_dist (Function) · 0.90
append (Method) · 0.80
get_rng_state (Method) · 0.45
size (Method) · 0.45
parameters (Method) · 0.45
set_rng_state (Method) · 0.45
update (Method) · 0.45

Tested by

no test coverage detected