Get the broadcast map for the varlen tensor.
(self)
| 560 | |
@property
def batch_boardcast_map(self) -> torch.LongTensor:
    """
    Return the per-element batch index map for the varlen tensor.

    The map is built by repeating each index ``i`` in
    ``range(len(self.layout))`` ``self.seqlen[i]`` times, so it holds one
    entry per repeated element (presumably ``self.seqlen`` stores the
    length of each layout entry -- confirm against its definition).

    The result is memoized in the spatial cache under the key
    ``'batch_boardcast_map'``; later accesses return the cached tensor
    instead of rebuilding it.
    """
    cache_key = 'batch_boardcast_map'

    # Fast path: reuse the tensor registered by an earlier access.
    cached = self.get_spatial_cache(cache_key)
    if cached is not None:
        return cached

    # One index per layout entry, created on the same device as the tensor.
    batch_ids = torch.arange(len(self.layout), device=self.device)
    # Expand each index by the matching entry of ``self.seqlen``.
    mapping = torch.repeat_interleave(batch_ids, self.seqlen)

    self.register_spatial_cache(cache_key, mapping)
    return mapping
| 574 | |
| 575 | @overload |
| 576 | def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... |
nothing calls this directly
no test coverage detected