Repository: github.com/borisdayma/dalle-mini

Method create

tools/train/train.py:619–633
Signature: create(cls, *, apply_fn, params, tx, **kwargs)


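# Excerpt: a classmethod of the train-state class in tools/train/train.py.
# Assumes module-level `import jax`, `freeze` (from flax.core.frozen_dict),
# and the `split_params` helper defined in the same file.

@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
    # Initialize one optimizer state per parameter group returned by
    # split_params; `tx` maps each group name to its own transformation.
    opt_state = {}
    for k, p in split_params(params).items():
        init_fn = tx[k].init
        if "scanned" in k:
            # Scanned groups hold per-layer weights stacked on a leading
            # axis, so vmap runs init once per layer and stacks the results.
            init_fn = jax.vmap(init_fn)
        opt_state[k] = init_fn(p)
    return cls(
        step=0,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=freeze(opt_state),  # immutable FrozenDict of per-group states
        **kwargs,
    )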
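The vmap branch is the non-obvious part. Parameters in a "scanned" group are stacked along a leading layer axis, so calling an optimizer's init on the whole stacked tensor would size any shape-dependent state for one 3-D array rather than for each layer's matrix; presumably this matters for optimizers whose state depends on weight shapes. The sketch below illustrates the pattern in plain JAX under stated assumptions: toy_precond_init is a hypothetical stand-in for an optimizer's init (it is neither optax nor the optimizer dalle-mini actually configures), SimpleNamespace substitutes for an optax GradientTransformation, and init_opt_state and the group names "standard" and "scanned_decoder" are illustrative only.

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


def toy_precond_init(params):
    """Hypothetical shape-dependent optimizer init (stand-in for `tx[k].init`).

    For every parameter leaf it allocates an identity "preconditioner" sized by
    the leaf's first dimension, loosely mimicking optimizers whose state
    depends on each weight matrix's shape.
    """
    return jax.tree_util.tree_map(lambda w: jnp.eye(w.shape[0]), params)


def init_opt_state(param_groups, tx):
    """Illustrative mirror of the loop in `create`: one state per group."""
    opt_state = {}
    for k, p in param_groups.items():
        init_fn = tx[k].init
        if "scanned" in k:
            # Leaves in this group are stacked along a leading layer axis;
            # vmap runs init once per layer and stacks the results.
            init_fn = jax.vmap(init_fn)
        opt_state[k] = init_fn(p)
    return opt_state


n_layers, d = 4, 8
param_groups = {
    "standard": {"embed": jnp.ones((16, d))},
    "scanned_decoder": {"kernel": jnp.ones((n_layers, d, d))},
}
tx = {k: SimpleNamespace(init=toy_precond_init) for k in param_groups}

opt_state = init_opt_state(param_groups, tx)
print(opt_state["standard"]["embed"].shape)          # (16, 16)
print(opt_state["scanned_decoder"]["kernel"].shape)  # (4, 8, 8)
# Without vmap, the stacked kernel gets a single (4, 4) matrix instead.
print(toy_precond_init(param_groups["scanned_decoder"])["kernel"].shape)  # (4, 4)
```

Vmapping init also replicates any state that does not depend on the parameters (a scalar step counter, for instance) once per layer; whether that is desirable depends on the optimizer.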

Callers (1)

init_state · function · 0.80

Calls (2)

split_params · function · 0.85
init_fn · function · 0.50 (the local variable bound to tx[k].init, or to its jax.vmap for scanned groups)

Tested by

no test coverage detected