| 54 | |
| 55 | |
| 56 | class Trainer(object): |
| 57 | def __init__( |
| 58 | self, |
| 59 | epochs: int, |
| 60 | max_epoch: int = None, |
| 61 | nstage: int = 1, |
| 62 | cpu: bool = False, |
| 63 | checkpoint_path: str = "./checkpoints/model.pt", |
| 64 | resume_training: str = False, |
| 65 | device_ids: Optional[list] = None, |
| 66 | distributed_training: bool = False, |
| 67 | distributed_inference: bool = False, |
| 68 | master_addr: str = "localhost", |
| 69 | master_port: int = 10086, |
| 70 | early_stopping: bool = True, |
| 71 | patience: int = 100, |
| 72 | eval_step: int = 1, |
| 73 | save_emb_path: Optional[str] = None, |
| 74 | load_emb_path: Optional[str] = None, |
| 75 | cpu_inference: bool = False, |
| 76 | progress_bar: str = "epoch", |
| 77 | clip_grad_norm: float = 5.0, |
| 78 | logger: str = None, |
| 79 | log_path: str = "./runs", |
| 80 | project: str = "cogdl-exp", |
| 81 | return_model: bool = False, |
| 82 | actnn: bool = False, |
| 83 | fp16: bool = False, |
| 84 | rp_ratio: int = 1, |
| 85 | attack=None, |
| 86 | attack_mode="injection", |
| 87 | do_test: bool = True, |
| 88 | do_valid: bool = True, |
| 89 | ): |
| 90 | self.epochs = epochs |
| 91 | self.nstage = nstage |
| 92 | self.patience = patience |
| 93 | self.early_stopping = early_stopping |
| 94 | self.eval_step = eval_step |
| 95 | self.monitor = None |
| 96 | self.evaluation_metric = None |
| 97 | self.progress_bar = progress_bar |
| 98 | |
| 99 | if max_epoch is not None: |
| 100 | warnings.warn("The max_epoch is deprecated and will be removed in the future, please use epochs instead!") |
| 101 | self.epochs = max_epoch |
| 102 | |
| 103 | self.cpu = cpu |
| 104 | self.devices, self.world_size = self.set_device(device_ids) |
| 105 | self.checkpoint_path = checkpoint_path |
| 106 | self.resume_training = resume_training |
| 107 | |
| 108 | self.distributed_training = distributed_training |
| 109 | self.distributed_inference = distributed_inference |
| 110 | |
| 111 | self.master_addr = master_addr |
| 112 | self.master_port = master_port |
| 113 |
no outgoing calls