(path: tuple)
| 108 | |
| 109 | |
def path_tuple_to_string(path: tuple) -> str:
    """Join the dict keys and attribute names of a JAX key path with "/".

    Args:
        path: A key path such as those produced by
            ``jax.tree_util.tree_flatten_with_path``, i.e. a tuple whose
            entries are ``DictKey``, ``GetAttrKey``, ``SequenceKey`` or
            ``FlattenedIndexKey`` objects.

    Returns:
        The ``DictKey`` keys and ``GetAttrKey`` names, in order, each
        converted to ``str`` and joined with ``"/"``. ``SequenceKey`` and
        ``FlattenedIndexKey`` entries contribute nothing, so e.g. both leaves
        of ``{"c": [x, y]}`` map to ``"c"``. An empty path yields ``""``.

    Raises:
        TypeError: If ``path`` contains any other kind of key entry.
    """
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            # Dict keys are not necessarily strings (e.g. int keys); convert
            # them so the join below cannot fail.
            pieces.append(str(elem.key))
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(elem.name)
        elif isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey)):
            # Positional indices are omitted from the joined name.
            continue
        else:
            # Explicit raise rather than `assert` so the check survives `python -O`.
            raise TypeError(
                f"Unsupported key path entry {elem!r} of type {type(elem).__name__}"
            )
    return "/".join(pieces)
| 120 | |
| 121 | |
| 122 | def get_load_path_str( |
no outgoing calls
no test coverage detected