Repository: github.com/Vchitect/Latte

Function update_ema

utils.py:191–200

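Steps the EMA model towards the current model. Each parameter of `ema_model` is updated in place to `decay * ema_param + (1 - decay) * param`, where `param` is the same-named parameter of `model`. The update runs under `torch.no_grad()` and touches parameters only; buffers are not averaged.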

Signature: update_ema(ema_model, model, decay=0.9999)
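Per parameter, the update is the recurrence `ema <- decay * ema + (1 - decay) * param`. The scalar sketch below uses plain Python floats in place of tensors; `ema_step` and `run` are hypothetical helpers, not part of utils.py. It shows two properties of the rule: `decay=0` copies the current value in a single step, and when chasing a fixed target the remaining gap after n steps is `decay**n`, so with the default 0.9999 about 10,000 steps close only roughly 63% of the gap.

```python
def ema_step(ema, param, decay=0.9999):
    """Scalar analogue of the in-place update in update_ema (hypothetical helper)."""
    return decay * ema + (1 - decay) * param


def run(steps, decay, target=1.0, start=0.0):
    """Apply ema_step `steps` times toward a constant target (hypothetical helper)."""
    ema = start
    for _ in range(steps):
        ema = ema_step(ema, target, decay)
    return ema


# decay=0 discards the old EMA value and copies the target in one step.
print(run(1, decay=0.0))  # 1.0

# decay=1 never moves the EMA.
print(run(5, decay=1.0, start=0.25))  # 0.25

# Toward a constant target, the gap left after n steps is decay**n.
after = run(10_000, decay=0.9999)
print(round(after, 3))  # 0.632
```

A useful rule of thumb that follows from this: the EMA effectively averages over roughly `1 / (1 - decay)` recent steps, which is 10,000 for the default.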


```python
from collections import OrderedDict

import torch


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        # In place: ema = decay * ema + (1 - decay) * param
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
```
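In training code this function is usually paired with a frozen deep copy of the network. The sketch below assumes a common DiT-style pattern rather than reproducing Latte's own training script: the copy is seeded with `decay=0` so it starts identical to the model, then stepped after every optimizer update. A toy `nn.Linear` stands in for the real model, and `update_ema` is restated so the block runs on its own.

```python
import copy
from collections import OrderedDict

import torch
from torch import nn


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # Same update as utils.py: ema = decay * ema + (1 - decay) * param, in place.
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


torch.manual_seed(0)
model = nn.Linear(4, 2)              # toy stand-in for the real network
ema = copy.deepcopy(model)           # EMA copy with identical parameter names
for p in ema.parameters():
    p.requires_grad_(False)          # the EMA copy is never trained directly
update_ema(ema, model, decay=0)      # decay=0 copies the current weights

opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)
for _ in range(3):
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    update_ema(ema, model, decay=0.9)  # small decay only so the lag is visible

# The EMA weights trail the trained weights.
print((ema.weight - model.weight).abs().max().item())
```

The decay of 0.9 is chosen only to make the lag visible within three steps; with the default 0.9999 each call moves the EMA 0.01% of the way toward the live weights.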

Callers (6)

main · function · 0.90
main · function · 0.90
__init__ · method · 0.90
on_train_batch_end · method · 0.90
__init__ · method · 0.90
on_train_batch_end · method · 0.90

Calls

no outgoing calls to other functions in the repository (only PyTorch tensor methods and OrderedDict)

Tested by

no test coverage detected
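Since no tests are detected, the following is a minimal sketch of checks that would pin down the update rule: `decay=0` copies the model, `decay=1` leaves the EMA unchanged, and an intermediate decay yields the expected linear blend. The fixture `make_pair` and the test names are hypothetical, written in pytest style (plain `assert`, `test_` prefix) but also runnable directly; `update_ema` is inlined so the block stands alone.

```python
import copy
from collections import OrderedDict

import torch
from torch import nn


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # Inlined copy of the function under test.
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def make_pair():
    """Hypothetical fixture: an EMA copy and a model whose weights differ from it."""
    torch.manual_seed(0)
    model = nn.Linear(3, 3)
    ema = copy.deepcopy(model)
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)  # shift the live model so the two differ
    return ema, model


def test_decay_zero_copies_weights():
    ema, model = make_pair()
    update_ema(ema, model, decay=0.0)
    for e, m in zip(ema.parameters(), model.parameters()):
        assert torch.equal(e, m)


def test_decay_one_keeps_ema():
    ema, model = make_pair()
    before = [p.clone() for p in ema.parameters()]
    update_ema(ema, model, decay=1.0)
    for e, b in zip(ema.parameters(), before):
        assert torch.equal(e, b)


def test_intermediate_decay_interpolates():
    ema, model = make_pair()
    before = [p.clone() for p in ema.parameters()]
    update_ema(ema, model, decay=0.75)
    for e, b, m in zip(ema.parameters(), before, model.parameters()):
        assert torch.allclose(e, 0.75 * b + 0.25 * m)


if __name__ == "__main__":
    test_decay_zero_copies_weights()
    test_decay_one_keeps_ema()
    test_intermediate_decay_interpolates()
    print("all EMA checks passed")
```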