Repository: github.com/borisdayma/dalle-mini

Function split_params

tools/train/train.py, lines 560–575

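Splits a parameter tree into three groups: parameters of the scanned encoder layers (paths containing FlaxBartEncoderLayers), parameters of the scanned decoder layers (paths containing FlaxBartDecoderLayers), and all remaining "standard" parameters. Groups that receive no parameters are dropped, and each remaining group is returned as a nested FrozenDict.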

Signature: split_params(data)


```python
from flax import traverse_util
from flax.core.frozen_dict import freeze, unfreeze


def split_params(data):
    """Split params between scanned and non-scanned"""
    # Flatten the (possibly frozen) parameter tree into {path_tuple: leaf}.
    flat = traverse_util.flatten_dict(unfreeze(data))
    split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
    for k, v in flat.items():
        # k is a tuple of path components, so `in` tests for an exact component.
        if "FlaxBartEncoderLayers" in k:
            split["scanned_encoder"][k] = v
        elif "FlaxBartDecoderLayers" in k:
            split["scanned_decoder"][k] = v
        else:
            split["standard"][k] = v
    # remove groups that received no parameters
    split = {k: v for k, v in split.items() if v}
    # rebuild each group as a nested, immutable FrozenDict
    for k, v in split.items():
        split[k] = freeze(traverse_util.unflatten_dict(v))
    return split
```
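The following is a minimal, dependency-free sketch of the same partitioning, useful for checking the behavior without installing flax. It assumes parameters are plain nested dicts: flatten and unflatten are hypothetical stand-ins for flax's traverse_util.flatten_dict and unflatten_dict, FrozenDict freezing is omitted, the parameter paths and string leaves are made up, and merge_params_sketch is a hypothetical inverse, not the repository's unsplit_params.

```python
from typing import Any, Dict, Tuple

Tree = Dict[str, Any]
Flat = Dict[Tuple[str, ...], Any]


def flatten(tree: Tree, prefix: Tuple[str, ...] = ()) -> Flat:
    """Hypothetical stand-in for flatten_dict: map each leaf to its path tuple."""
    flat: Flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))  # empty dicts contribute nothing
        else:
            flat[path] = value
    return flat


def unflatten(flat: Flat) -> Tree:
    """Hypothetical stand-in for unflatten_dict: rebuild the nested tree."""
    tree: Tree = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


def split_params_sketch(params: Tree) -> Dict[str, Tree]:
    """Same grouping rule as split_params, on plain dicts."""
    split: Dict[str, Flat] = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
    for path, leaf in flatten(params).items():
        if "FlaxBartEncoderLayers" in path:
            split["scanned_encoder"][path] = leaf
        elif "FlaxBartDecoderLayers" in path:
            split["scanned_decoder"][path] = leaf
        else:
            split["standard"][path] = leaf
    # Drop empty groups, then rebuild each group as a nested tree.
    return {name: unflatten(group) for name, group in split.items() if group}


def merge_params_sketch(groups: Dict[str, Tree]) -> Tree:
    """Hypothetical inverse: union the flattened groups and rebuild one tree."""
    flat: Flat = {}
    for group in groups.values():
        flat.update(flatten(group))
    return unflatten(flat)


# Made-up parameter tree; real leaves would be arrays, strings are used here.
params = {
    "model": {
        "shared": {"embedding": "E"},
        "encoder": {"layers": {"FlaxBartEncoderLayers": {"fc1": {"kernel": "Wenc"}}}},
        "decoder": {"layers": {"FlaxBartDecoderLayers": {"fc1": {"kernel": "Wdec"}}}},
    },
    "lm_head": {"kernel": "H"},
}
split = split_params_sketch(params)
print(sorted(split))  # ['scanned_decoder', 'scanned_encoder', 'standard']
```

Because every leaf lands in exactly one group, merging the groups back recovers the original tree, and a tree with no scanned layers yields only the "standard" group.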

Callers 4

apply_gradients · method · 0.85
create · method · 0.85
main · function · 0.85

Calls

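no calls to other functions in this repository; split_params uses only flax.traverse_util.flatten_dict / unflatten_dict and flax.core.frozen_dict.freeze / unfreeze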
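Because traverse_util.flatten_dict returns tuple keys by default, the test "FlaxBartEncoderLayers" in k inside split_params compares whole path components, not substrings. The short sketch below illustrates this with made-up paths; the lookalike module name FlaxBartEncoderLayersNorm is hypothetical.

```python
# Tuple keys as produced by flatten_dict (paths are made up for illustration).
encoder_key = ("model", "encoder", "layers", "FlaxBartEncoderLayers", "fc1", "kernel")
lookalike_key = ("model", "FlaxBartEncoderLayersNorm", "scale")  # hypothetical name

# `str in tuple` checks element equality, so only the exact module name matches.
encoder_hit = "FlaxBartEncoderLayers" in encoder_key
lookalike_hit = "FlaxBartEncoderLayers" in lookalike_key
# A substring test on a joined path string would also match the lookalike.
substring_hit = "FlaxBartEncoderLayers" in "/".join(lookalike_key)

print(encoder_hit, lookalike_hit, substring_hit)  # True False True
```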

Tested by

no test coverage detected