MCPcopy
hub / github.com/THUDM/CogDL / Trainer

Class Trainer

cogdl/trainer/trainer.py:56–562  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

54
55
56class Trainer(object):
57 def __init__(
58 self,
59 epochs: int,
60 max_epoch: int = None,
61 nstage: int = 1,
62 cpu: bool = False,
63 checkpoint_path: str = "./checkpoints/model.pt",
64 resume_training: str = False,
65 device_ids: Optional[list] = None,
66 distributed_training: bool = False,
67 distributed_inference: bool = False,
68 master_addr: str = "localhost",
69 master_port: int = 10086,
70 early_stopping: bool = True,
71 patience: int = 100,
72 eval_step: int = 1,
73 save_emb_path: Optional[str] = None,
74 load_emb_path: Optional[str] = None,
75 cpu_inference: bool = False,
76 progress_bar: str = "epoch",
77 clip_grad_norm: float = 5.0,
78 logger: str = None,
79 log_path: str = "./runs",
80 project: str = "cogdl-exp",
81 return_model: bool = False,
82 actnn: bool = False,
83 fp16: bool = False,
84 rp_ratio: int = 1,
85 attack=None,
86 attack_mode="injection",
87 do_test: bool = True,
88 do_valid: bool = True,
89 ):
90 self.epochs = epochs
91 self.nstage = nstage
92 self.patience = patience
93 self.early_stopping = early_stopping
94 self.eval_step = eval_step
95 self.monitor = None
96 self.evaluation_metric = None
97 self.progress_bar = progress_bar
98
99 if max_epoch is not None:
100 warnings.warn("The max_epoch is deprecated and will be removed in the future, please use epochs instead!")
101 self.epochs = max_epoch
102
103 self.cpu = cpu
104 self.devices, self.world_size = self.set_device(device_ids)
105 self.checkpoint_path = checkpoint_path
106 self.resume_training = resume_training
107
108 self.distributed_training = distributed_training
109 self.distributed_inference = distributed_inference
110
111 self.master_addr = master_addr
112 self.master_port = master_port
113

Callers 8

test_adv.pyFile · 0.90
test_defense.pyFile · 0.90
test_injection.pyFile · 0.90
train_modelFunction · 0.90
test_adversarial_trainFunction · 0.90
trainFunction · 0.90
trainFunction · 0.90

Calls

no outgoing calls

Tested by 2

train_modelFunction · 0.72
test_adversarial_trainFunction · 0.72