()
| 944 | |
| 945 | # get PartitionSpec for optimizer state |
def get_opt_state_spec_and_shape():
    """Build the partition spec and the abstract shape of the optimizer state.

    Shapes are derived with ``jax.eval_shape`` so no optimizer state is
    actually initialised. The function takes no arguments; it reads
    ``params_shape``, ``optimizer``, ``opt_fn``, ``training_args`` and
    ``statistics_partition_spec`` from the surrounding scope.

    Returns:
        A ``(opt_state_spec, opt_state_shape)`` tuple. Both mappings are keyed
        like ``split_params(params_shape)`` and are passed through ``freeze``.

    Raises:
        NotImplementedError: if ``training_args.optim`` is not one of
            ``"adafactor"``, ``"adam"`` or ``"distributed_shampoo"``.
    """
    # ``jax.tree_map`` is a deprecated alias; ``jax.tree_util.tree_map`` is the
    # stable spelling with identical semantics (including ``is_leaf``).
    tree_map = jax.tree_util.tree_map

    # Split once and reuse the result below.
    # NOTE(review): assumes split_params is a pure function of its argument — confirm.
    param_groups = split_params(params_shape)

    # get opt_state shape without actual init
    opt_state_shape = {}
    for k, p in param_groups.items():
        if "scanned" not in k:
            opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
        else:
            # scanned groups use a vmapped init
            opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)

    if training_args.optim == "adafactor":
        # factorized state must be replicated (rank different than params)
        opt_state_spec = {k: None for k in param_groups}

    elif training_args.optim in ["adam", "distributed_shampoo"]:

        def _opt_state_spec_per_leaf(x, spec):
            if isinstance(x, FrozenDict):
                # variables with same structure as params
                return spec
            # other variables such as count
            return None

        def _prepend_unsharded_axis(spec):
            # Equivalent to PartitionSpec(*((None,) + spec)): add a leading
            # ``None`` entry in front of the existing spec entries.
            if spec is None:
                return None
            return PartitionSpec(None, *spec)

        split_spec = split_params(set_partitions(params_shape, False))
        opt_state_spec = {}
        for k, p in param_groups.items():
            is_scanned = "scanned" in k
            if is_scanned:
                # abstract shape of every leaf indexed at 0 along its first axis
                p = jax.eval_shape(lambda x: tree_map(lambda y: y[0], x), p)
            if training_args.optim == "adam":
                opt_state_spec[k] = tree_map(
                    _opt_state_spec_per_leaf,
                    opt_state_shape[k],
                    split_spec[k],
                    # return None spec for empty elements
                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
                )
            elif training_args.optim == "distributed_shampoo":
                opt_state_spec[k] = opt_fn[k].pspec_fn(
                    p,
                    split_spec[k],
                    statistics_partition_spec,
                )
            # add dimension for scanned params
            if is_scanned:
                opt_state_spec[k] = tree_map(
                    _prepend_unsharded_axis,
                    opt_state_spec[k],
                    is_leaf=lambda x: isinstance(x, PartitionSpec),
                )

    else:
        raise NotImplementedError(
            f"No optimizer state partition spec for optim={training_args.optim!r}"
        )
    return freeze(opt_state_spec), freeze(opt_state_shape)
| 1001 | |
# Compute the optimizer-state partition spec and abstract shape in one call.
(opt_state_spec, opt_state_shape) = get_opt_state_spec_and_shape()
| 1003 |
no test coverage detected