| 330 | |
| 331 | |
| 332 | class Logger(object): |
| 333 | DEFAULT = None # A logger with no output files. (See right below class definition) |
| 334 | # So that you can still log to the terminal without setting up any output files |
| 335 | CURRENT = None # Current logger being used by the free functions above |
| 336 | |
def __init__(self, dir, output_formats, comm=None):
    """Set up an empty key/value buffer plus output configuration.

    :param dir: value later returned by ``get_dir`` (presumably a filesystem
        path for the run's outputs -- confirm at the call site).
    :param output_formats: output objects; those that are ``KVWriter``
        instances receive the key/value dict in ``dumpkvs``.
    :param comm: optional communicator; when not None, ``dumpkvs`` combines
        values across it with ``mpi_weighted_mean``.
    """
    # Per-iteration accumulators, both emptied at the end of dumpkvs():
    #   name2val -- latest value (logkv) or running mean (logkv_mean) per key
    #   name2cnt -- number of samples folded into each running mean
    self.name2val, self.name2cnt = defaultdict(float), defaultdict(int)
    self.level = INFO  # log() forwards only calls whose level is >= this
    self.dir, self.output_formats, self.comm = dir, output_formats, comm
| 344 | |
| 345 | # Logging API, forwarded |
| 346 | # ---------------------------------------- |
def logkv(self, key, val):
    """Store ``val`` under ``key`` for the next ``dumpkvs``, overwriting any prior value.

    ``name2cnt`` is not modified, so a sample count left by an earlier
    ``logkv_mean`` call on the same key is retained.
    """
    values = self.name2val
    values[key] = val
| 349 | |
def logkv_mean(self, key, val):
    """Fold ``val`` into the running mean kept under ``key``.

    Starting from a fresh key, after ``n`` calls ``name2val[key]`` holds the
    mean of the ``n`` values and ``name2cnt[key] == n``. Unseen keys read as 0
    because ``name2val``/``name2cnt`` are defaultdicts.
    """
    prev_mean = self.name2val[key]
    prev_count = self.name2cnt[key]
    new_count = prev_count + 1
    # Old mean keeps weight cnt/(cnt+1); the new sample contributes 1/(cnt+1).
    self.name2val[key] = prev_mean * prev_count / new_count + val / new_count
    self.name2cnt[key] = new_count
| 354 | |
def dumpkvs(self):
    """Write the buffered key/values to every ``KVWriter`` output, then reset.

    Without a communicator, the local ``name2val`` dict is written as-is.
    With one, values are combined by ``mpi_weighted_mean``, each weighted by
    its ``name2cnt`` entry (1 when the key has no count, i.e. it was set only
    via ``logkv``). On ranks other than 0 a ``"dummy"`` entry is added to the
    result to avoid a warning about an empty dict.

    Returns a copy of the dict handed to the writers (for unit tests). Both
    ``name2val`` and ``name2cnt`` are cleared before returning.
    """
    comm = self.comm
    if comm is not None:
        # Pair each value with its sample count so the reduction can weight it.
        weighted = {
            key: (value, self.name2cnt.get(key, 1))
            for key, value in self.name2val.items()
        }
        d = mpi_weighted_mean(comm, weighted)
        if comm.rank != 0:
            d["dummy"] = 1  # so we don't get a warning about empty dict
    else:
        d = self.name2val
    snapshot = d.copy()
    for fmt in self.output_formats:
        if not isinstance(fmt, KVWriter):
            continue
        fmt.writekvs(d)
    self.name2val.clear()
    self.name2cnt.clear()
    return snapshot
| 375 | |
def log(self, *args, level=INFO):
    """Pass ``args`` to ``_do_log`` when ``level`` meets this logger's threshold.

    Calls with ``level`` below ``self.level`` are silently dropped.
    """
    should_emit = self.level <= level
    if should_emit:
        self._do_log(args)
| 379 | |
| 380 | # Configuration |
| 381 | # ---------------------------------------- |
def set_level(self, level):
    """Change the threshold used by ``log``; lower-level calls are dropped."""
    self.level = level
| 384 | |
def set_comm(self, comm):
    """Replace the communicator ``dumpkvs`` reduces over (``None`` disables the reduction)."""
    self.comm = comm
| 387 | |
def get_dir(self):
    """Return this logger's ``dir`` attribute (initialised from the constructor's ``dir`` argument)."""
    return self.dir