hub / github.com/state-spaces/mamba / _run

Function _run

tests/test_determinism.py:232–243
()

Source from the content-addressed store, hash-verified

230  x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype)
231
232  def _run() -> dict[str, torch.Tensor]:
233      _set_deterministic(False)
234      model.zero_grad(set_to_none=True)
235      x = x_data.clone().requires_grad_(True)
236      y = model(x)
237      (y.float().square().mean()).backward()
238      torch.cuda.synchronize()
239      grads = {"input": x.grad.detach().float().clone()}
240      for name, p in model.named_parameters():
241          if p.grad is not None:
242              grads[name] = p.grad.detach().float().clone()
243      return grads
244
245  _run()  # warmup
246  ref = _run()
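
The listing stops right after `ref = _run()`, so how `ref` is consumed is not shown here. As a minimal sketch, assuming later `_run()` results are compared against `ref` key by key, a comparison could look like the following. `max_abs_grad_diff` is a hypothetical helper, not something from the repository, and the tensors are toy CPU stand-ins for real gradients.

```python
import torch


def max_abs_grad_diff(
    ref: dict[str, torch.Tensor], other: dict[str, torch.Tensor]
) -> dict[str, float]:
    """Hypothetical helper: per-key max |ref - other| for two gradient dicts."""
    # A parameter may receive a gradient in one run but not in another,
    # so require identical key sets before comparing values.
    if ref.keys() != other.keys():
        raise KeyError(f"gradient keys differ: {sorted(ref.keys() ^ other.keys())}")
    return {name: (ref[name] - other[name]).abs().max().item() for name in ref}


# Toy stand-ins for two _run() results (assumed shapes and values).
ref = {"input": torch.tensor([1.0, 2.0]), "weight": torch.tensor([0.5])}
other = {"input": torch.tensor([1.0, 2.5]), "weight": torch.tensor([0.5])}

diffs = max_abs_grad_diff(ref, other)

# torch.equal is True only when shapes and all element values match.
bitwise_equal = all(torch.equal(ref[k], other[k]) for k in ref)
```

Reporting a per-key maximum rather than a single boolean makes it easier to see which gradient diverged; `torch.equal` covers the stricter case where the runs are expected to match exactly.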

Calls 4

Mamba2 · Class · 0.90
_set_deterministic · Function · 0.85
_set_seeds · Function · 0.85
backward · Method · 0.45

Tested by

no test coverage detected