Main sample-generation program.
()
| 293 | |
| 294 | |
def main():
    """Run the sample-generation entry point.

    Disables cuDNN, parses the command-line arguments, initializes the
    distributed runtime, optionally pins an explicit CUDA device, seeds the
    RNGs, prepares the tokenizer, builds the model, and finally hands the
    model and arguments to ``generate_images_continually``.

    No training happens here: no optimizer or learning-rate schedule is
    created.
    """

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Arguments.
    args = get_args()

    # PyTorch distributed.
    initialize_distributed(args)

    # Optionally pin an explicit CUDA device. Per the original note,
    # args.device is only used at inference time.
    if args.device is not None:
        torch.cuda.set_device(int(args.device))

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # The returned tokenizer is not used in this function; the call is kept
    # for its side effects (presumably it configures tokenizer/global state
    # or fields on `args` -- confirm against prepare_tokenizer).
    prepare_tokenizer(args)

    # Build the model only; no optimizer or LR scheduler is created here.
    model = setup_model(args)

    generate_images_continually(model, args)
| 324 | |
if __name__ == "__main__":
    # Launch sample generation only when this file is executed as a script;
    # importing the module must not start generation.
    main()
no test coverage detected