Split parameters between scanned and non-scanned groups
(data)
| 558 | |
| 559 | |
def split_params(data):
    """Split params between scanned and non-scanned.

    The tree is unfrozen and flattened with ``traverse_util.flatten_dict``.
    Each leaf is then routed to one of three groups:

    * ``"scanned_encoder"`` when its flattened key contains
      ``"FlaxBartEncoderLayers"`` (this test is applied first, so it wins
      if both names appear),
    * ``"scanned_decoder"`` when its key contains ``"FlaxBartDecoderLayers"``,
    * ``"standard"`` otherwise.

    Membership is tested with ``in`` on the flattened key. With flax's default
    tuple keys, this means one path component must equal the module name
    exactly.

    Args:
        data: A parameter tree accepted by ``unfreeze``.

    Returns:
        A dict mapping group name to a frozen, re-nested parameter tree.
        Groups that received no leaves are omitted. The remaining groups
        keep the order ``standard``, ``scanned_encoder``, ``scanned_decoder``.
    """
    buckets = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
    for path, leaf in traverse_util.flatten_dict(unfreeze(data)).items():
        # Encoder is checked before decoder to keep the original precedence.
        if "FlaxBartEncoderLayers" in path:
            group = "scanned_encoder"
        elif "FlaxBartDecoderLayers" in path:
            group = "scanned_decoder"
        else:
            group = "standard"
        buckets[group][path] = leaf
    # Drop empty groups, then rebuild each surviving group as a frozen tree.
    return {
        group: freeze(traverse_util.unflatten_dict(leaves))
        for group, leaves in buckets.items()
        if leaves
    }
| 576 | |
| 577 | |
| 578 | def unsplit_params(data): |
no outgoing calls
no test coverage detected