| 13 | |
| 14 | |
class NestedTensor(object):
    """Pair a batch tensor with an optional companion mask tensor.

    The two members travel together: ``to`` moves both to a device and
    ``decompose`` hands them back as a tuple.

    NOTE(review): the mask presumably marks padded positions of ``tensors``.
    Confirm this against the code that builds these objects.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        # Both references are stored as given; nothing is copied or checked.
        self.tensors, self.mask = tensors, mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor whose members have been moved to ``device``.

        ``tensors`` is always transferred. The mask is transferred only when
        it is present; otherwise the returned object's mask is ``None``.
        """
        moved_tensors = self.tensors.to(device)
        # Read the attribute into a local so that the ``is None`` guard narrows
        # the Optional. This is presumably for TorchScript (see the ``# type:``
        # comment), which only refines local variables.
        own_mask = self.mask
        if own_mask is None:
            return NestedTensor(moved_tensors, None)
        return NestedTensor(moved_tensors, own_mask.to(device))

    def decompose(self):
        """Return the ``(tensors, mask)`` pair; ``mask`` may be ``None``."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # Only the payload tensor is rendered; the mask does not appear.
        return str(self.tensors)
| 37 | |
| 38 | |
| 39 | class PositionEmbeddingSine(nn.Module): |