| 46 | |
| 47 | |
| 48 | def train(): |
| 49 | parser = ArgumentParser() |
| 50 | parser.add_argument('--gpt2', action='store_true', help="use gpt2") |
| 51 | parser.add_argument("--model_checkpoint", type=str, default="config/cgpt/", help="Path or URL of the model") |
| 52 | parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step") |
| 53 | parser.add_argument('--pretrained', action='store_true', help="If False train from scratch") |
| 54 | parser.add_argument("--data_path", type=str, default="", |
| 55 | help="Path or url of the dataset. ") |
| 56 | parser.add_argument("--train_path", type=str, default="data/toy_train.txt", |
| 57 | help="Path of the train dataset for dist dataset. ") |
| 58 | parser.add_argument("--valid_path", type=str, default="data/toy_valid.txt", |
| 59 | help="Path of the valid dataset for dist dataset. ") |
| 60 | parser.add_argument("--dataset_cache", type=str, default="dataset_cache", |
| 61 | help="Path or url of the dataset cache") |
| 62 | parser.add_argument('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path") |
| 63 | parser.add_argument("--num_workers", type=int, default=8, help="Number of subprocesses for data loading") |
| 64 | parser.add_argument("--n_epochs", type=int, default=70, help="Number of training epochs") |
| 65 | parser.add_argument("--train_batch_size", type=int, default=2, help="Batch size for training") |
| 66 | parser.add_argument("--valid_batch_size", type=int, default=2, help="Batch size for validation") |
| 67 | parser.add_argument("--max_history", type=int, default=15, help="Number of previous exchanges to keep in history") |
| 68 | parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear'], help="method of optim") |
| 69 | parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)") |
| 70 | parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate") |
| 71 | parser.add_argument("--eval_before_start", action='store_true', |
| 72 | help="If true start with a first evaluation before training") |
| 73 | parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps") |
| 74 | parser.add_argument("--valid_steps", type=int, default=5000, help="Perfom validation every X steps") |
| 75 | parser.add_argument("--gradient_accumulation_steps", type=int, default=64, |
| 76 | help="Accumulate gradients on several steps") |
| 77 | parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") |
| 78 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", |
| 79 | help="Device (cuda or cpu)") |
| 80 | parser.add_argument("--fp16", type=str, default="", |
| 81 | help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") |
| 82 | parser.add_argument("--local_rank", type=int, default=-1, |
| 83 | help="Local rank for distributed training (-1: not distributed)") |
| 84 | args = parser.parse_args() |
| 85 | |
| 86 | # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. |
| 87 | # logger.info => log main process only, logger.warning => log all processes |
| 88 | logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) |
| 89 | logger.warning("Running process %d", args.local_rank) |
| 90 | logger.info("Arguments: %s", pformat(args)) |
| 91 | |
| 92 | # Initialize distributed training if needed |
| 93 | args.distributed = (args.local_rank != -1) |
| 94 | if args.distributed: |
| 95 | torch.cuda.set_device(args.local_rank) |
| 96 | args.device = torch.device("cuda", args.local_rank) |
| 97 | torch.distributed.init_process_group(backend='nccl', init_method='env://') |
| 98 | |
| 99 | logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning") |
| 100 | model_class = OpenAIGPTLMHeadModel if not args.gpt2 else GPT2LMHeadModel |
| 101 | config_class = OpenAIGPTConfig if not args.gpt2 else GPT2Config |
| 102 | tokenizer_class = BertTokenizer |
| 103 | if args.pretrained: |
| 104 | tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, do_lower_case=True, never_split=["[speaker1]", "[speaker2]"]) |
| 105 | model = model_class.from_pretrained(args.model_checkpoint) |