MCPcopy
hub / github.com/tinygrad/tinygrad / train_llama3

Function train_llama3

examples/mlperf/model_train.py:1284–1654  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1282 previous_step = i
1283
1284def train_llama3():
1285 from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
1286 from examples.llama3 import MODEL_PARAMS
1287 from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
1288 from examples.mlperf.optim import GradAccClipAdamW
1289
1290 INITMLPERF = getenv("INITMLPERF")
1291 RUNMLPERF = getenv("RUNMLPERF")
1292 LOGMLPERF = getenv("LOGMLPERF")
1293 BENCHMARK = getenv("BENCHMARK")
1294
1295 config = {}
1296 BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
1297 BS = config["BS"] = getenv("BS", 16)
1298 grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
1299 GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
1300 SEED = config["SEED"] = getenv("SEED", 5760)
1301 DATA_SEED = config["DATA_SEED"] = getenv("DATA_SEED", SEED)
1302 SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
1303 TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
1304 SMALL = config["SMALL"] = getenv("SMALL", 0)
1305 SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
1306 EVAL_SAMPLES = config["EVAL_SAMPLES"] = getenv("EVAL_SAMPLES", 5760 if not SMALL else 1024)
1307 MAX_STEPS = config["MAX_STEPS"] = getenv("MAX_STEPS", math.ceil(1_200_000 * 1152 / GBS))
1308 WARMUP_STEPS = config["WARMUP_STEPS"] = getenv("WARMUP_STEPS", math.ceil(8000 * 1152 / GBS))
1309 LR = config["LR"] = getenv("LR", 8e-5 * GBS / 1152)
1310 END_LR = config["END_LR"] = getenv("END_LR", 8e-7)
1311 EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
1312 EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
1313 EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
1314
1315 if LOGMLPERF:
1316 from mlperf_logging import mllog
1317 import mlperf_logging.mllog.constants as mllog_constants
1318
1319 mllog.config(filename=f"result_llama31_{SEED}.log")
1320 mllog.config(root_dir=Path(__file__).parents[3].as_posix())
1321 MLLOGGER = mllog.get_mllogger()
1322 MLLOGGER.logger.propagate = False
1323
1324 LLAMA_BENCHMARK = mllog_constants.LLAMA31_405B if getenv("LLAMA3_SIZE", "8B") == "405B" else mllog_constants.LLAMA31_8B
1325
1326 if INITMLPERF:
1327 assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
1328 MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
1329 MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
1330 MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
1331 MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
1332
1333 MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=LLAMA_BENCHMARK)
1334
1335 diskcache_clear()
1336 MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
1337 MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
1338
1339 if RUNMLPERF:
1340 MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
1341 MLLOGGER.event(key=mllog_constants.SEED, value=SEED)

Callers

nothing calls this directly

Calls 15

shard · Method · 0.95
getenv · Function · 0.90
diskcache_clear · Function · 0.90
round_up · Function · 0.90
FlatTransformer · Class · 0.90
get_parameters · Function · 0.90
GradAccClipAdamW · Class · 0.90
load_state_dict · Function · 0.90
safe_load · Function · 0.90
get_state_dict · Function · 0.90
get_llama3_dataset · Function · 0.90

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…