| 282 | |
| 283 | |
class NestedTensor(object):
    """Pair a tensor with an optional companion mask tensor.

    The two values are stored as-is; this class only moves them together
    and hands them back. (NOTE(review): in DETR-style code the mask usually
    marks padded positions, but that is not established here — confirm
    against callers.)
    """

    def __init__(self, tensors: Tensor, mask: Optional[Tensor]):
        """Store ``tensors`` and ``mask`` without copying or validation.

        Args:
            tensors: The main tensor payload.
            mask: Optional companion tensor; may be ``None``.
        """
        self.tensors = tensors
        self.mask = mask

    def to(self, device, non_blocking=False):
        # type: (Device, bool) -> NestedTensor # noqa
        """Return a new NestedTensor with both members moved via ``Tensor.to``.

        Args:
            device: Target passed straight through to ``Tensor.to``.
            non_blocking: Forwarded to ``Tensor.to``. It lets the copy run
                asynchronously with respect to the host when the source is
                in pinned memory. Defaults to ``False``, which matches the
                previous blocking behavior.

        Returns:
            A new NestedTensor. A ``None`` mask stays ``None``.
        """
        cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
        # Copy to a local so the Optional check below narrows a plain
        # variable rather than an attribute.
        mask = self.mask
        if mask is not None:
            # Redundant at runtime; presumably kept to help TorchScript /
            # type checkers refine Optional[Tensor] to Tensor.
            assert mask is not None
            cast_mask = mask.to(device, non_blocking=non_blocking)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        """Return the stored ``(tensors, mask)`` pair (same objects, no copy)."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the tensor payload is shown; the mask is omitted.
        return str(self.tensors)
| 306 | |
| 307 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): |
no outgoing calls
no test coverage detected