| 236 | ###################### |
@classmethod
def start(cls):
    """Entry point: reserve a MASTER_PORT, seed the RNGs, then train or test.

    Steps, in order:

    1. Draw a random port in [10000, 11000], export it as the ``MASTER_PORT``
       environment variable, and redraw (sleeping one second after each
       redraw) for as long as something on localhost accepts a TCP
       connection on the chosen port.
    2. Seed ``random`` and ``np.random`` with ``hparams['seed']``. The port is
       drawn *before* seeding, so it is not fixed by the seed.
    3. Build a ``Trainer`` from ``hparams`` and call ``fit(cls)`` unless
       ``hparams['infer']`` is truthy, in which case ``test(cls)`` is called.
    """
    import socket

    def port_is_taken(candidate: int) -> bool:
        # connect_ex returns 0 when a connection to the port succeeds,
        # i.e. some process is already listening there.
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            return probe.connect_ex(('localhost', candidate)) == 0
        finally:
            probe.close()

    port = random.randint(10000, 11000)
    os.environ['MASTER_PORT'] = str(port)
    while port_is_taken(port):
        print(f"| Port {port} is in use. Change another port...")
        port = random.randint(10000, 11000)
        os.environ['MASTER_PORT'] = str(port)
        time.sleep(1)

    seed = hparams['seed']
    random.seed(seed)
    np.random.seed(seed)

    work_dir = hparams['work_dir']
    trainer_kwargs = {
        'work_dir': work_dir,
        'val_check_interval': hparams['val_check_interval'],
        'tb_log_interval': hparams['tb_log_interval'],
        'max_updates': hparams['max_updates'],
        # With hparams['validate'] set, the sanity-validation step count is
        # forced to 10000 (presumably to cover the whole validation set --
        # confirm against Trainer).
        'num_sanity_val_steps': 10000 if hparams['validate'] else hparams['num_sanity_val_steps'],
        'accumulate_grad_batches': hparams['accumulate_grad_batches'],
        'print_nan_grads': hparams['print_nan_grads'],
        # Optional key: falls back to 0 when absent.
        'resume_from_checkpoint': hparams.get('resume_from_checkpoint', 0),
        'amp': hparams['amp'],
        'monitor_key': hparams['valid_monitor_key'],
        'monitor_mode': hparams['valid_monitor_mode'],
        'num_ckpt_keep': hparams['num_ckpt_keep'],
        'save_best': hparams['save_best'],
        'seed': hparams['seed'],
        'debug': hparams['debug'],
    }
    trainer = Trainer(**trainer_kwargs)

    if hparams['infer']:
        trainer.test(cls)
    else:  # train
        trainer.fit(cls)
def on_keyboard_interrupt(self):
    """No-op hook; does nothing and returns ``None``.

    NOTE(review): presumably invoked by the trainer when training is
    interrupted from the keyboard -- confirm against the caller.
    """
    return None