MCPcopy
hub / github.com/tinygrad/tinygrad / train_bert

Function train_bert

examples/mlperf/model_train.py:935–1282  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

933 return masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss
934
935def train_bert():
936 # NOTE: pip install tensorflow, wandb required
937 from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
938 from examples.mlperf.helpers import get_mlperf_bert_model, get_fake_data_bert
939 from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
940
941 config = {}
942 BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")
943
944 GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
945 print(f"training on {GPUS}")
946 for x in GPUS: Device[x]
947 seed = config["seed"] = getenv("SEED", 12345)
948
949 INITMLPERF = getenv("INITMLPERF")
950 RUNMLPERF = getenv("RUNMLPERF")
951 BENCHMARK = getenv("BENCHMARK")
952 if getenv("LOGMLPERF"):
953 from mlperf_logging import mllog
954 import mlperf_logging.mllog.constants as mllog_constants
955
956 mllog.config(filename=f"result_bert_{seed}.log")
957 mllog.config(root_dir=Path(__file__).parents[3].as_posix())
958 MLLOGGER = mllog.get_mllogger()
959 MLLOGGER.logger.propagate = False
960
961 if INITMLPERF:
962 assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
963 MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
964 MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
965 MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
966 MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
967
968 MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
969
970 diskcache_clear()
971 MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
972 MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
973
974 if RUNMLPERF:
975 MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
976 MLLOGGER.event(key=mllog_constants.SEED, value=seed)
977 else:
978 MLLOGGER = None
979
980 # ** hyperparameters **
981 BS = config["BS"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
982 grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
983 # TODO: implement grad accumulation + mlperf logging
984 assert grad_acc == 1
985 GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
986 EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
987 max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(GBS/96))
988 opt_lamb_beta_1 = config["OPT_LAMB_BETA_1"] = getenv("OPT_LAMB_BETA_1", 0.9)
989 opt_lamb_beta_2 = config["OPT_LAMB_BETA_2"] = getenv("OPT_LAMB_BETA_2", 0.999)
990
991 train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3600000 // GBS)
992 warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)

Callers

nothing calls this directly

Calls 15

item · Method · 0.95
getenv · Function · 0.90
diskcache_clear · Function · 0.90
get_mlperf_bert_model · Function · 0.90
get_parameters · Function · 0.90
get_state_dict · Function · 0.90
LAMB · Class · 0.90
OptimizerGroup · Class · 0.90
LRSchedulerGroup · Class · 0.90
load_training_state · Function · 0.90
safe_load · Function · 0.90

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…