| 193 | return stats.xent() |
| 194 | |
def wait_and_validate(self):
    """Evaluate saved checkpoints found under ``self.args.model_path``.

    Checkpoints are the files matching ``model_step_*.pt``; the training
    step is parsed from the file name.

    * If ``self.args.test_all`` is truthy: validate checkpoints in order of
      modification time, stop early once the best (lowest) validation
      score is more than ``patience`` checkpoints behind the current one,
      then run ``self.test`` on the three best-scoring checkpoints.
    * Otherwise: poll forever, validating and testing each time the most
      recently modified checkpoint is newer than the last one evaluated.
      This branch never returns.
    """
    patience = 10            # checkpoints without improvement before stopping
    top_k = 3                # number of best checkpoints passed to self.test
    empty_file_wait = 60     # seconds to wait for a checkpoint still being written
    poll_interval = 300      # seconds between polls when nothing new is available

    pattern = os.path.join(self.args.model_path, 'model_step_*.pt')

    def list_checkpoints():
        # Name-sort first so that the (stable) mtime sort breaks ties
        # deterministically; result is oldest first, newest last.
        paths = sorted(glob.glob(pattern))
        paths.sort(key=os.path.getmtime)
        return paths

    def step_of(cp):
        # '.../model_step_1234.pt' -> 1234
        return int(cp.split('.')[-2].split('_')[-1])

    if self.args.test_all:
        xent_lst = []
        for i, cp in enumerate(list_checkpoints()):
            xent = self.validate(step=step_of(cp))
            xent_lst.append((xent, cp))
            # Position of the best score so far; tuples compare by score,
            # then by path.
            best_idx = xent_lst.index(min(xent_lst))
            if i - best_idx > patience:
                break
        xent_lst = sorted(xent_lst, key=lambda x: x[0])[:top_k]
        logger.info('PPL %s' % str(xent_lst))
        for xent, cp in xent_lst:
            self.test(step_of(cp))
        return

    time_step = 0  # mtime of the last checkpoint that was evaluated
    while True:
        cp_files = list_checkpoints()
        if not cp_files:
            time.sleep(poll_interval)
            continue

        cp = cp_files[-1]
        time_of_cp = os.path.getmtime(cp)
        if not os.path.getsize(cp) > 0:
            # The newest checkpoint exists but is still empty; give the
            # writer time to finish before looking again.
            time.sleep(empty_file_wait)
            continue

        if time_of_cp > time_step:
            time_step = time_of_cp
            step = step_of(cp)
            self.validate(step)
            self.test(step)
            # Evaluation can take a while; re-check right away in case a
            # newer checkpoint appeared in the meantime.
            continue

        # Nothing newer than what was already evaluated: sleep instead of
        # spinning on glob/stat calls.
        time.sleep(poll_interval)
| 238 | |
| 239 | def test(self, step=None): |
| 240 | if not step: |