MCPcopy
hub / github.com/zai-org/CogVideo / __init__

Method __init__

sat/sgm/modules/ema.py:6–25  ·  view source on GitHub ↗
(self, model, decay=0.9999, use_num_upates=True)

Source from the content-addressed store, hash-verified

4
5class LitEma(nn.Module):
6 def __init__(self, model, decay=0.9999, use_num_upates=True):
7 super().__init__()
8 if decay < 0.0 or decay > 1.0:
9 raise ValueError("Decay must be between 0 and 1")
10
11 self.m_name2s_name = {}
12 self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 self.register_buffer(
14 "num_updates",
15 torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
16 )
17
18 for name, p in model.named_parameters():
19 if p.requires_grad:
20 # remove as '.'-character is not allowed in buffers
21 s_name = name.replace(".", "")
22 self.m_name2s_name.update({name: s_name})
23 self.register_buffer(s_name, p.clone().detach().data)
24
25 self.collected_params = []
26
27 def reset_num_updates(self):
28 del self.num_updates

Callers

nothing calls this directly

Calls 1

updateMethod · 0.45

Tested by

no test coverage detected