MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / HumanOutputFormat

Class HumanOutputFormat

modules/commons/improved_diffusion/logger.py:36–95  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

34
35
36class HumanOutputFormat(KVWriter, SeqWriter):
37 def __init__(self, filename_or_file):
38 if isinstance(filename_or_file, str):
39 self.file = open(filename_or_file, "wt")
40 self.own_file = True
41 else:
42 assert hasattr(filename_or_file, "read"), (
43 "expected file or str, got %s" % filename_or_file
44 )
45 self.file = filename_or_file
46 self.own_file = False
47
48 def writekvs(self, kvs):
49 # Create strings for printing
50 key2str = {}
51 for (key, val) in sorted(kvs.items()):
52 if hasattr(val, "__float__"):
53 valstr = "%-8.3g" % val
54 else:
55 valstr = str(val)
56 key2str[self._truncate(key)] = self._truncate(valstr)
57
58 # Find max widths
59 if len(key2str) == 0:
60 print("WARNING: tried to write empty key-value dict")
61 return
62 else:
63 keywidth = max(map(len, key2str.keys()))
64 valwidth = max(map(len, key2str.values()))
65
66 # Write out the data
67 dashes = "-" * (keywidth + valwidth + 7)
68 lines = [dashes]
69 for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 lines.append(
71 "| %s%s | %s%s |"
72 % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 )
74 lines.append(dashes)
75 self.file.write("\n".join(lines) + "\n")
76
77 # Flush the output to the file
78 self.file.flush()
79
80 def _truncate(self, s):
81 maxlen = 30
82 return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83
84 def writeseq(self, seq):
85 seq = list(seq)
86 for (i, elem) in enumerate(seq):
87 self.file.write(elem)
88 if i < len(seq) - 1: # add space unless this is the last one
89 self.file.write(" ")
90 self.file.write("\n")
91 self.file.flush()
92
93 def close(self):

Callers 1

make_output_formatFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected