(config_path: str)
| 52 | |
| 53 | |
| 54 | def main(config_path: str): |
| 55 | with open(config_path, "r") as f: |
| 56 | config = yaml.safe_load(f) |
| 57 | |
| 58 | pretrained_model_path = Path(config["model"]["pretrained_model_path"]) |
| 59 | device = torch.device(config["model"]["device"]) |
| 60 | dtype_map = { |
| 61 | "bfloat16": torch.bfloat16, |
| 62 | "float16": torch.float16, |
| 63 | "float32": torch.float32, |
| 64 | } |
| 65 | dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16) |
| 66 | torch.set_default_device(device) |
| 67 | torch.random.manual_seed(config["training"]["random_seed"]) |
| 68 | BATCH_SIZE = config["training"]["batch_size"] |
| 69 | NUM_QUESTIONS_PER_BATCH = config["training"]["num_questions_per_batch"] |
| 70 | NUM_ANSWERS_PER_QUESTION = BATCH_SIZE // NUM_QUESTIONS_PER_BATCH |
| 71 | |
| 72 | current_time = datetime.now().strftime(r"%Y%m%d-%H%M%S") |
| 73 | tb_writer = SummaryWriter(log_dir=f"{config['training']['log_dir']}/{current_time}") |
| 74 | tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json")) |
| 75 | |
| 76 | train_dataset = CountdownTasksDataset( |
| 77 | data_path=config["data"]["path"], |
| 78 | tokenizer=tokenizer, |
| 79 | split="train", |
| 80 | test_size=config["data"]["test_size"], |
| 81 | ) |
| 82 | generator = torch.Generator(device=device) |
| 83 | train_dataloader = DataLoader( |
| 84 | train_dataset, |
| 85 | shuffle=True, |
| 86 | collate_fn=CountdownTasksDataset.collate_fn, |
| 87 | generator=generator, |
| 88 | batch_size=NUM_QUESTIONS_PER_BATCH, |
| 89 | ) |
| 90 | |
| 91 | model = Transformer.from_pretrained(pretrained_model_path, device=device).train() |
| 92 | |
| 93 | optimizer = MemoryEfficientAdamW( |
| 94 | model.parameters(), |
| 95 | lr=config["training"]["learning_rate"], |
| 96 | weight_decay=config["training"]["weight_decay"], |
| 97 | betas=config["training"]["betas"], |
| 98 | enabled=config["training"]["memory_efficient_adamw"], |
| 99 | ) |
| 100 | |
| 101 | start_time = time.time() |
| 102 | ckpt_dir = Path(config["training"]["ckpt_dir"]) |
| 103 | ckpt_dir.mkdir(parents=True, exist_ok=True) |
| 104 | |
| 105 | for step, batch in enumerate(train_dataloader, start=1): |
| 106 | episodes = rollout( |
| 107 | model=model, |
| 108 | tokenizer=tokenizer, |
| 109 | batch=batch, |
| 110 | max_gen_len=config["training"]["max_gen_len"], |
| 111 | num_answer_per_question=NUM_ANSWERS_PER_QUESTION, |
no test coverage detected