MCPcopy
hub / github.com/Vchitect/Latte / FeatureStats

Class FeatureStats

tools/metrics/metric_utils.py:64–140  ·  view source on GitHub ↗  (the excerpt below is truncated and ends at line 121)

Source retrieved from the content-addressed store (hash-verified).

62#----------------------------------------------------------------------------
63
64class FeatureStats:
65 def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
66 self.capture_all = capture_all
67 self.capture_mean_cov = capture_mean_cov
68 self.max_items = max_items
69 self.num_items = 0
70 self.num_features = None
71 self.all_features = None
72 self.raw_mean = None
73 self.raw_cov = None
74
75 def set_num_features(self, num_features):
76 if self.num_features is not None:
77 assert num_features == self.num_features
78 else:
79 self.num_features = num_features
80 self.all_features = []
81 self.raw_mean = np.zeros([num_features], dtype=np.float64)
82 self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
83
84 def is_full(self):
85 return (self.max_items is not None) and (self.num_items >= self.max_items)
86
87 def append(self, x):
88 x = np.asarray(x, dtype=np.float32)
89 assert x.ndim == 2
90 if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
91 if self.num_items >= self.max_items:
92 return
93 x = x[:self.max_items - self.num_items]
94
95 self.set_num_features(x.shape[1])
96 self.num_items += x.shape[0]
97 if self.capture_all:
98 self.all_features.append(x)
99 if self.capture_mean_cov:
100 x64 = x.astype(np.float64)
101 self.raw_mean += x64.sum(axis=0)
102 self.raw_cov += x64.T @ x64
103
104 def append_torch(self, x, num_gpus=1, rank=0):
105 assert isinstance(x, torch.Tensor) and x.ndim == 2
106 assert 0 <= rank < num_gpus
107 if num_gpus > 1:
108 ys = []
109 for src in range(num_gpus):
110 y = x.clone()
111 torch.distributed.broadcast(y, src=src)
112 ys.append(y)
113 x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
114 self.append(x.cpu().numpy())
115
116 def get_all(self):
117 assert self.capture_all
118 return np.concatenate(self.all_features, axis=0)
119
120 def get_all_torch(self):
121 return torch.from_numpy(self.get_all())

Callers 1

loadMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected