| 265 | return new_tensor |
| 266 | |
def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
    """Reduce ``self.feats`` with ``op``; per segment unless dimension 0 is reduced.

    Args:
        op: Reduction to apply: ``'mean'``, ``'sum'`` or ``'prod'``.
        dim: Dimension(s) of ``self.feats`` to reduce. ``None`` reduces every
            dimension. Negative indices count from the last dimension of
            ``self.feats``.
        keepdim: Keep reduced dimensions with size 1.

    Returns:
        If ``dim`` is ``None`` or includes dimension 0, the plain reduction of
        ``self.feats`` over ``dim``. Otherwise, ``self.feats`` is first reduced
        over ``dim``. The result is then reduced along dimension 0 with
        ``torch.segment_reduce``, using ``self.seqlen`` as the segment
        lengths (one output row per segment).

    Raises:
        ValueError: If ``op`` is unsupported, or (for ``'prod'``) ``dim``
            contains the same dimension twice.
    """
    feats = self.feats

    if dim is not None:
        dims = (dim,) if isinstance(dim, int) else tuple(dim)
        # Map in-range negative indices to non-negative ones so that the
        # ``0 in dim`` check below also recognises e.g. ``dim=-feats.dim()``.
        # Out-of-range values are left untouched so torch still reports them.
        ndim = feats.dim()
        dim = tuple(d + ndim if -ndim <= d < 0 else d for d in dims)

    if op == 'mean':
        red = feats.mean(dim=dim, keepdim=keepdim)
    elif op == 'sum':
        red = feats.sum(dim=dim, keepdim=keepdim)
    elif op == 'prod':
        # torch.prod only accepts a single int ``dim`` (neither None nor a
        # tuple), so emulate a multi-dim reduction.
        if dim is None:
            red = feats.prod()
            if keepdim:
                # Match the all-ones shape that mean/sum give for keepdim=True.
                red = red.reshape((1,) * feats.dim())
        else:
            if len(set(dim)) != len(dim):
                raise ValueError(f"dim {dim} contains duplicate entries")
            red = feats
            # Highest index first, so lower indices stay valid when
            # keepdim=False removes dimensions.
            for d in sorted(dim, reverse=True):
                red = red.prod(dim=d, keepdim=keepdim)
    else:
        raise ValueError(f"Unsupported reduce operation: {op}")

    # Reducing over dimension 0 already mixes all segments; nothing left to do.
    if dim is None or 0 in dim:
        return red

    # segment_reduce needs a tensor of lengths on the same device as the data.
    lengths = torch.as_tensor(self.seqlen, device=red.device)
    return torch.segment_reduce(red, reduce=op, lengths=lengths)
| 285 | |
def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
    """Mean reduction; shorthand for ``self.reduce`` with ``op='mean'``.

    Args:
        dim: Dimension(s) to reduce, forwarded unchanged to ``reduce``.
        keepdim: Forwarded unchanged to ``reduce``.

    Returns:
        Whatever ``self.reduce`` returns for the ``'mean'`` operation.
    """
    return self.reduce(dim=dim, keepdim=keepdim, op='mean')