`append_torch(self, x, num_gpus=1, rank=0)`
| 102 | self.raw_cov += x64.T @ x64 |
| 103 | |
| 104 | def append_torch(self, x, num_gpus=1, rank=0): |
| 105 | assert isinstance(x, torch.Tensor) and x.ndim == 2 |
| 106 | assert 0 <= rank < num_gpus |
| 107 | if num_gpus > 1: |
| 108 | ys = [] |
| 109 | for src in range(num_gpus): |
| 110 | y = x.clone() |
| 111 | torch.distributed.broadcast(y, src=src) |
| 112 | ys.append(y) |
| 113 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples |
| 114 | self.append(x.cpu().numpy()) |
| 115 | |
| 116 | def get_all(self): |
| 117 | assert self.capture_all |
No test coverage detected.