| 62 | #---------------------------------------------------------------------------- |
| 63 | |
| 64 | class FeatureStats: |
def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
    """Create an empty feature accumulator.

    Args:
        capture_all: Keep every appended batch so that `get_all()` can
            return the rows later.
        capture_mean_cov: Accumulate a float64 running sum and running sum
            of outer products of the appended rows.
        max_items: Optional cap on how many rows are accepted; None means
            no cap.
    """
    # Behavior switches, stored verbatim.
    self.capture_all = capture_all
    self.capture_mean_cov = capture_mean_cov
    self.max_items = max_items

    # Number of rows accepted so far.
    self.num_items = 0

    # Width-dependent state stays unset until set_num_features() allocates it.
    self.num_features = self.all_features = self.raw_mean = self.raw_cov = None
| 74 | |
def set_num_features(self, num_features):
    """Record the feature width, allocating the accumulators on first use.

    On the first call, the width is stored and the method creates an empty
    batch list plus zero-filled float64 buffers of shape [F] (sum) and
    [F, F] (sum of outer products). These are allocated regardless of the
    capture flags. Later calls only check that the width has not changed.
    """
    if self.num_features is None:
        self.num_features = num_features
        self.all_features = []
        self.raw_mean = np.zeros([num_features], dtype=np.float64)
        self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
        return
    # Every later batch must match the width seen first.
    assert num_features == self.num_features
| 83 | |
def is_full(self):
    """Return True once a cap is set and at least `max_items` rows were accepted."""
    if self.max_items is None:
        # Uncapped accumulators never fill up.
        return False
    return self.num_items >= self.max_items
| 86 | |
def append(self, x):
    """Add a batch of feature rows, honoring `max_items`.

    Args:
        x: Array-like with 2 dimensions, [N, F]; converted to float32.

    Rows beyond `max_items` are dropped, and once the cap has been reached
    further batches are ignored. Depending on the capture flags, the
    (possibly truncated) batch is stored for `get_all()` and/or folded into
    the float64 running sum and running sum of outer products.
    """
    # np.array copies (copy=True by default). np.asarray would return the
    # caller's own buffer when it is already a float32 ndarray, and
    # append_torch passes `.cpu().numpy()`, which shares memory with a CPU
    # tensor. Storing such an alias in all_features would let later in-place
    # writes by the caller silently change the captured features.
    x = np.array(x, dtype=np.float32)
    assert x.ndim == 2
    if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
        if self.num_items >= self.max_items:
            return
        # Keep only as many rows as still fit under the cap.
        x = x[:self.max_items - self.num_items]

    self.set_num_features(x.shape[1])
    self.num_items += x.shape[0]
    if self.capture_all:
        self.all_features.append(x)
    if self.capture_mean_cov:
        # Accumulate in float64 to limit round-off over many batches.
        x64 = x.astype(np.float64)
        self.raw_mean += x64.sum(axis=0)
        self.raw_cov += x64.T @ x64
| 103 | |
def append_torch(self, x, num_gpus=1, rank=0):
    """Append a 2-D torch tensor, first collecting it from every process.

    With `num_gpus > 1`, each rank's `x` is broadcast in turn (src = 0, 1,
    ...), so every process ends up holding all ranks' batches. The batches
    are interleaved row by row (row 0 of rank 0, row 0 of rank 1, ...) and
    the result goes to `append()` as a NumPy array. Each rank's own clone of
    `x` serves as the receive buffer, so presumably all ranks must pass
    tensors of the same shape — confirm with callers.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 2
    assert 0 <= rank < num_gpus
    batch = x
    if num_gpus > 1:
        def _receive_from(src):
            # Collective call: every rank issues it for each src in order.
            buf = x.clone()
            torch.distributed.broadcast(buf, src=src)
            return buf

        per_rank = [_receive_from(src) for src in range(num_gpus)]
        # [N, num_gpus, F] -> [N * num_gpus, F] interleaves the samples.
        batch = torch.stack(per_rank, dim=1).flatten(0, 1)
    self.append(batch.cpu().numpy())
| 115 | |
def get_all(self):
    """Return every captured feature row as a single float32 [N, F] array.

    Requires `capture_all=True`. If nothing has been captured yet (no batch
    appended, or `max_items=0` so append() returned before allocating the
    batch list), an empty [0, F] float32 array is returned. Without this,
    np.concatenate would fail on a None/empty list. F falls back to 0 while
    the feature width is still unknown.
    """
    assert self.capture_all
    if not self.all_features:
        width = 0 if self.num_features is None else self.num_features
        return np.zeros([0, width], dtype=np.float32)
    return np.concatenate(self.all_features, axis=0)
| 119 | |
def get_all_torch(self):
    """Return the result of `get_all()` wrapped as a torch tensor (memory is shared via from_numpy)."""
    features = self.get_all()
    return torch.from_numpy(features)