(self, keep: torch.Tensor)
| 44 | return self._stats.items() |
| 45 | |
def filter(self, keep: torch.Tensor) -> None:
    """Subset every stored stat in place, retaining only the selected entries.

    Args:
        keep: Selection applied to every value in ``self._stats``. A boolean
            tensor is treated as a mask; any other dtype is treated as
            indices. Anything accepted by ``torch.as_tensor`` (tensor,
            ndarray, list) is accepted.

    Raises:
        TypeError: If a stored value is not ``None``, a ``torch.Tensor``,
            an ``np.ndarray``, or a ``list``.
    """
    # Normalize once so every branch can rely on tensor attributes
    # (.dtype/.detach/.tolist). This is a no-op when ``keep`` is already a tensor.
    keep = torch.as_tensor(keep)
    # Host-side copies of ``keep`` are built lazily, at most once each. This
    # avoids repeated device->host transfers and per-element tensor indexing.
    keep_np = None
    keep_list = None
    for k, v in self._stats.items():
        if v is None:
            # Nothing to filter; the entry stays None.
            continue
        if isinstance(v, torch.Tensor):
            self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
        elif isinstance(v, np.ndarray):
            if keep_np is None:
                keep_np = keep.detach().cpu().numpy()
            self._stats[k] = v[keep_np]
        elif isinstance(v, list):
            if keep_list is None:
                keep_list = keep.tolist()
            if keep.dtype == torch.bool:
                # Indexing keep_list (rather than zipping) preserves the
                # IndexError raised when the mask is shorter than the list.
                self._stats[k] = [a for i, a in enumerate(v) if keep_list[i]]
            else:
                self._stats[k] = [v[i] for i in keep_list]
        else:
            raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
| 60 | |
| 61 | def cat(self, new_stats: "MaskData") -> None: |
| 62 | for k, v in new_stats.items(): |
no test coverage detected