MCPcopy
hub / github.com/borisdayma/dalle-mini / get_opt_state_spec_and_shape

Function get_opt_state_spec_and_shape

tools/train/train.py:946–1000  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

944
945 # get PartitionSpec for optimizer state
def get_opt_state_spec_and_shape():
    """Build the partition spec and the abstract shape of the optimizer state.

    Relies on names from the enclosing scope: ``params_shape``, ``optimizer``,
    ``opt_fn``, ``training_args`` and ``statistics_partition_spec``.

    Returns:
        A ``(opt_state_spec, opt_state_shape)`` pair. Both are passed through
        ``freeze`` and are keyed like ``split_params(params_shape)``.

    Raises:
        NotImplementedError: if ``training_args.optim`` is not one of
            ``"adafactor"``, ``"adam"`` or ``"distributed_shampoo"``.
    """
    # get opt_state shape without actual init
    opt_state_shape = {}
    for k, p in split_params(params_shape).items():
        init_fn = optimizer[k].init
        if "scanned" in k:
            # scanned groups are initialised through a vmapped init
            init_fn = jax.vmap(init_fn)
        opt_state_shape[k] = jax.eval_shape(init_fn, p)

    if training_args.optim == "adafactor":
        # factorized state must be replicated (rank different than params)
        opt_state_spec = {k: None for k in split_params(params_shape)}

    elif training_args.optim in ["adam", "distributed_shampoo"]:

        def _opt_state_spec_per_leaf(x, spec):
            # FrozenDict entries have the same structure as params and take
            # the param spec; other variables such as count get None
            return spec if isinstance(x, FrozenDict) else None

        split_spec = split_params(set_partitions(params_shape, False))
        opt_state_spec = {}
        for k, p in split_params(params_shape).items():
            if "scanned" in k:
                # keep only the first slice along the leading axis
                p = jax.eval_shape(
                    lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p
                )
            if training_args.optim == "adam":
                # The spec is bound to the callback instead of being handed
                # to tree_map as a second tree: extra trees must share the
                # structure of the first one, and split_spec[k] follows the
                # params layout rather than the optimizer-state layout.
                opt_state_spec[k] = jax.tree_util.tree_map(
                    lambda x, spec=split_spec[k]: _opt_state_spec_per_leaf(x, spec),
                    opt_state_shape[k],
                    # return None spec for empty elements
                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
                )
            elif training_args.optim == "distributed_shampoo":
                opt_state_spec[k] = opt_fn[k].pspec_fn(
                    p,
                    split_spec[k],
                    statistics_partition_spec,
                )
            # add dimension for scanned params
            if "scanned" in k:
                opt_state_spec[k] = jax.tree_util.tree_map(
                    lambda x: PartitionSpec(*(None,) + x)
                    if x is not None
                    else None,
                    opt_state_spec[k],
                    is_leaf=lambda x: isinstance(x, PartitionSpec),
                )

    else:
        raise NotImplementedError
    return freeze(opt_state_spec), freeze(opt_state_shape)
1001
# Call the helper and unpack its (spec, shape) return pair.
1002 opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape()
1003

Callers 1

main (Function) · 0.85

Calls 2

set_partitions (Function) · 0.90
split_params (Function) · 0.85

Tested by

no test coverage detected