MCPcopy
hub / github.com/policy-gradient/GRPO-Zero / main

Function main

train.py:54–184  ·  view source on GitHub ↗
(config_path: str)

Source from the content-addressed store, hash-verified

52
53
54def main(config_path: str):
55 with open(config_path, "r") as f:
56 config = yaml.safe_load(f)
57
58 pretrained_model_path = Path(config["model"]["pretrained_model_path"])
59 device = torch.device(config["model"]["device"])
60 dtype_map = {
61 "bfloat16": torch.bfloat16,
62 "float16": torch.float16,
63 "float32": torch.float32,
64 }
65 dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16)
66 torch.set_default_device(device)
67 torch.random.manual_seed(config["training"]["random_seed"])
68 BATCH_SIZE = config["training"]["batch_size"]
69 NUM_QUESTIONS_PER_BATCH = config["training"]["num_questions_per_batch"]
70 NUM_ANSWERS_PER_QUESTION = BATCH_SIZE // NUM_QUESTIONS_PER_BATCH
71
72 current_time = datetime.now().strftime(r"%Y%m%d-%H%M%S")
73 tb_writer = SummaryWriter(log_dir=f"{config['training']['log_dir']}/{current_time}")
74 tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json"))
75
76 train_dataset = CountdownTasksDataset(
77 data_path=config["data"]["path"],
78 tokenizer=tokenizer,
79 split="train",
80 test_size=config["data"]["test_size"],
81 )
82 generator = torch.Generator(device=device)
83 train_dataloader = DataLoader(
84 train_dataset,
85 shuffle=True,
86 collate_fn=CountdownTasksDataset.collate_fn,
87 generator=generator,
88 batch_size=NUM_QUESTIONS_PER_BATCH,
89 )
90
91 model = Transformer.from_pretrained(pretrained_model_path, device=device).train()
92
93 optimizer = MemoryEfficientAdamW(
94 model.parameters(),
95 lr=config["training"]["learning_rate"],
96 weight_decay=config["training"]["weight_decay"],
97 betas=config["training"]["betas"],
98 enabled=config["training"]["memory_efficient_adamw"],
99 )
100
101 start_time = time.time()
102 ckpt_dir = Path(config["training"]["ckpt_dir"])
103 ckpt_dir.mkdir(parents=True, exist_ok=True)
104
105 for step, batch in enumerate(train_dataloader, start=1):
106 episodes = rollout(
107 model=model,
108 tokenizer=tokenizer,
109 batch=batch,
110 max_gen_len=config["training"]["max_gen_len"],
111 num_answer_per_question=NUM_ANSWERS_PER_QUESTION,

Callers 1

train.py · File · 0.85

Calls 7

Tokenizer · Class · 0.90
rollout · Function · 0.90
update_policy · Function · 0.90
evaluate · Function · 0.85
from_pretrained · Method · 0.80

Tested by

no test coverage detected