| 380 | |
| 381 | |
| 382 | class NestedTensor(object): |
| 383 | def __init__(self, tensors, mask: Optional[Tensor]): |
| 384 | self.tensors = tensors |
| 385 | self.mask = mask |
| 386 | if mask == 'auto': |
| 387 | self.mask = torch.zeros_like(tensors).to(tensors.device) |
| 388 | if self.mask.dim() == 3: |
| 389 | self.mask = self.mask.sum(0).to(bool) |
| 390 | elif self.mask.dim() == 4: |
| 391 | self.mask = self.mask.sum(1).to(bool) |
| 392 | else: |
| 393 | raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape)) |
| 394 | |
| 395 | def imgsize(self): |
| 396 | res = [] |
| 397 | for i in range(self.tensors.shape[0]): |
| 398 | mask = self.mask[i] |
| 399 | maxH = (~mask).sum(0).max() |
| 400 | maxW = (~mask).sum(1).max() |
| 401 | res.append(torch.Tensor([maxH, maxW])) |
| 402 | return res |
| 403 | |
| 404 | def to(self, device): |
| 405 | # type: (Device) -> NestedTensor # noqa |
| 406 | cast_tensor = self.tensors.to(device) |
| 407 | mask = self.mask |
| 408 | if mask is not None: |
| 409 | assert mask is not None |
| 410 | cast_mask = mask.to(device) |
| 411 | else: |
| 412 | cast_mask = None |
| 413 | return NestedTensor(cast_tensor, cast_mask) |
| 414 | |
| 415 | def to_img_list_single(self, tensor, mask): |
| 416 | assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim()) |
| 417 | maxH = (~mask).sum(0).max() |
| 418 | maxW = (~mask).sum(1).max() |
| 419 | img = tensor[:, :maxH, :maxW] |
| 420 | return img |
| 421 | |
| 422 | def to_img_list(self): |
| 423 | """remove the padding and convert to img list |
| 424 | |
| 425 | Returns: |
| 426 | [type]: [description] |
| 427 | """ |
| 428 | if self.tensors.dim() == 3: |
| 429 | return self.to_img_list_single(self.tensors, self.mask) |
| 430 | else: |
| 431 | res = [] |
| 432 | for i in range(self.tensors.shape[0]): |
| 433 | tensor_i = self.tensors[i] |
| 434 | mask_i = self.mask[i] |
| 435 | res.append(self.to_img_list_single(tensor_i, mask_i)) |
| 436 | return res |
| 437 | |
| 438 | @property |
| 439 | def device(self): |
no outgoing calls
no test coverage detected