MCPcopy
hub / github.com/thu-coai/CDial-GPT / train

Function train

train.py:48–233  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

46
47
48def train():
49 parser = ArgumentParser()
50 parser.add_argument('--gpt2', action='store_true', help="use gpt2")
51 parser.add_argument("--model_checkpoint", type=str, default="config/cgpt/", help="Path or URL of the model")
52 parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step")
53 parser.add_argument('--pretrained', action='store_true', help="If False train from scratch")
54 parser.add_argument("--data_path", type=str, default="",
55 help="Path or url of the dataset. ")
56 parser.add_argument("--train_path", type=str, default="data/toy_train.txt",
57 help="Path of the train dataset for dist dataset. ")
58 parser.add_argument("--valid_path", type=str, default="data/toy_valid.txt",
59 help="Path of the valid dataset for dist dataset. ")
60 parser.add_argument("--dataset_cache", type=str, default="dataset_cache",
61 help="Path or url of the dataset cache")
62 parser.add_argument('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path")
63 parser.add_argument("--num_workers", type=int, default=8, help="Number of subprocesses for data loading")
64 parser.add_argument("--n_epochs", type=int, default=70, help="Number of training epochs")
65 parser.add_argument("--train_batch_size", type=int, default=2, help="Batch size for training")
66 parser.add_argument("--valid_batch_size", type=int, default=2, help="Batch size for validation")
67 parser.add_argument("--max_history", type=int, default=15, help="Number of previous exchanges to keep in history")
68 parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear'], help="method of optim")
69 parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)")
70 parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
71 parser.add_argument("--eval_before_start", action='store_true',
72 help="If true start with a first evaluation before training")
73 parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps")
74 parser.add_argument("--valid_steps", type=int, default=5000, help="Perfom validation every X steps")
75 parser.add_argument("--gradient_accumulation_steps", type=int, default=64,
76 help="Accumulate gradients on several steps")
77 parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
78 parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
79 help="Device (cuda or cpu)")
80 parser.add_argument("--fp16", type=str, default="",
81 help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
82 parser.add_argument("--local_rank", type=int, default=-1,
83 help="Local rank for distributed training (-1: not distributed)")
84 args = parser.parse_args()
85
86 # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
87 # logger.info => log main process only, logger.warning => log all processes
88 logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
89 logger.warning("Running process %d", args.local_rank)
90 logger.info("Arguments: %s", pformat(args))
91
92 # Initialize distributed training if needed
93 args.distributed = (args.local_rank != -1)
94 if args.distributed:
95 torch.cuda.set_device(args.local_rank)
96 args.device = torch.device("cuda", args.local_rank)
97 torch.distributed.init_process_group(backend='nccl', init_method='env://')
98
99 logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
100 model_class = OpenAIGPTLMHeadModel if not args.gpt2 else GPT2LMHeadModel
101 config_class = OpenAIGPTConfig if not args.gpt2 else GPT2Config
102 tokenizer_class = BertTokenizer
103 if args.pretrained:
104 tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, do_lower_case=True, never_split=["[speaker1]", "[speaker2]"])
105 model = model_class.from_pretrained(args.model_checkpoint)

Callers 1

train.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected