MCPcopy
hub / github.com/shenweichen/DeepCTR-Torch / EarlyStopping

Class EarlyStopping

deepctr_torch/callbacks.py:84–155  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

82
83
84class EarlyStopping(Callback):
def __init__(
    self,
    monitor="val_loss",
    min_delta=0,
    patience=0,
    verbose=0,
    mode="auto",
    baseline=None,
    restore_best_weights=False,
):
    """Store early-stopping settings and choose the improvement direction.

    Args:
        monitor: Key looked up in the ``logs`` dict handed to
            ``on_epoch_end``.
        min_delta: Margin by which a new value must beat the best value to
            count as an improvement. Only its magnitude is kept
            (``abs(min_delta)``).
        patience: Stored as ``self.patience``. Presumably this is the number
            of non-improving epochs tolerated before stopping; the code that
            reads it is not visible here, so confirm.
        verbose: Stored as ``self.verbose``.
        mode: ``"min"`` means lower is better and ``"max"`` means higher is
            better. ``"auto"`` infers the direction from the ``monitor``
            name.
        baseline: Optional starting value for the best score. ``None`` means
            start from the worst possible value.
        restore_best_weights: Stored as ``self.restore_best_weights``.

    Raises:
        ValueError: If ``mode`` is not one of ``"auto"``, ``"min"``, ``"max"``.
    """
    super(EarlyStopping, self).__init__()
    self.monitor = monitor
    self.mode = mode
    self.min_delta = abs(min_delta)  # the direction is applied at comparison time
    self.baseline = baseline
    self.patience = patience
    self.verbose = verbose
    self.restore_best_weights = restore_best_weights

    if mode not in {"auto", "min", "max"}:
        raise ValueError("mode should be one of {'auto', 'min', 'max'}")

    if mode == "max":
        maximize = True
    elif mode == "min":
        maximize = False
    else:
        # Auto mode: accuracy-, AUC- and F-measure-like names are treated as
        # "higher is better"; every other metric is "lower is better".
        maximize = (
            "acc" in monitor
            or monitor.endswith("auc")
            or monitor.startswith("fmeasure")
        )
    self.monitor_op = np.greater if maximize else np.less
117
def on_train_begin(self, logs=None):
    """Reset the per-run early-stopping state before training starts.

    Args:
        logs: Part of the callback hook signature; not read here.
    """
    self.stopped_epoch = 0
    self.wait = 0
    self.best_weights = None

    if self.baseline is None:
        # With no baseline, seed ``best`` with the worst value for the chosen
        # direction so that the first observed value counts as an improvement.
        lower_is_better = self.monitor_op == np.less
        self.best = np.inf if lower_is_better else -np.inf
    else:
        self.best = self.baseline
126
def _is_improvement(self, current, best):
    """Return whether ``current`` beats ``best`` by more than ``min_delta``.

    When ``monitor_op`` is ``np.less``, ``current`` must be strictly below
    ``best - min_delta``. Otherwise ``current`` must be strictly above
    ``best + min_delta``. For float inputs a NaN ``current`` never counts as
    an improvement, because both comparisons evaluate to False.
    """
    lower_is_better = self.monitor_op == np.less
    if lower_is_better:
        threshold = best - self.min_delta
        return current < threshold
    threshold = best + self.min_delta
    return current > threshold
131
132 def on_epoch_end(self, epoch, logs=None):
133 logs = logs or {}
134 current = logs.get(self.monitor)
135 if current is None:
136 return
137
138 if self._is_improvement(current, self.best):
139 self.best = current
140 self.wait = 0
141 if self.restore_best_weights and self.model is not None:

Callers 4

check_model · Function · 0.90
check_mtl_model · Function · 0.90
test_AFM · Function · 0.90

Calls

no outgoing calls

Tested by 2

test_AFM · Function · 0.72