Main training program.
main()
| 287 | |
| 288 | |
def main():
    """Entry point for sample generation (not training).

    Steps, in order:
      1. announce the run and switch CuDNN off;
      2. parse arguments and extend ``mem_length`` by ``seq_length - 1``;
      3. initialize torch.distributed and seed the RNGs;
      4. build the tokenizer and the model;
      5. force ``batch_size`` to 1 and call ``generate_samples`` on the
         current CUDA device.
    """
    print('Generate Samples')

    # CuDNN stays disabled for the entire generation run.
    torch.backends.cudnn.enabled = False

    opts = get_args()
    # Grow the configured memory length by seq_length - 1 positions.
    # NOTE(review): presumably so memory can also hold the in-progress
    # sequence during generation -- confirm against generate_samples.
    opts.mem_length += opts.seq_length - 1

    initialize_distributed(opts)
    # Seed every RNG from the CLI value so runs are reproducible.
    set_random_seed(opts.seed)

    tok = prepare_tokenizer(opts)
    # Only the model is constructed here; no optimizer or LR schedule.
    net = setup_model(opts)

    # Generation uses a batch size of 1; assigned after setup_model has run.
    opts.batch_size = 1

    device = torch.cuda.current_device()
    generate_samples(net, tok, opts, device)
| 318 | |
| 319 | |
| 320 | if __name__ == "__main__": |
no test coverage detected