MCPcopy
hub / github.com/KlingAIResearch/LivePortrait / NestedTensor

Class NestedTensor

src/utils/dependencies/XPose/util/misc.py:382–453  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

380
381
382class NestedTensor(object):
383 def __init__(self, tensors, mask: Optional[Tensor]):
384 self.tensors = tensors
385 self.mask = mask
386 if mask == 'auto':
387 self.mask = torch.zeros_like(tensors).to(tensors.device)
388 if self.mask.dim() == 3:
389 self.mask = self.mask.sum(0).to(bool)
390 elif self.mask.dim() == 4:
391 self.mask = self.mask.sum(1).to(bool)
392 else:
393 raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))
394
395 def imgsize(self):
396 res = []
397 for i in range(self.tensors.shape[0]):
398 mask = self.mask[i]
399 maxH = (~mask).sum(0).max()
400 maxW = (~mask).sum(1).max()
401 res.append(torch.Tensor([maxH, maxW]))
402 return res
403
404 def to(self, device):
405 # type: (Device) -> NestedTensor # noqa
406 cast_tensor = self.tensors.to(device)
407 mask = self.mask
408 if mask is not None:
409 assert mask is not None
410 cast_mask = mask.to(device)
411 else:
412 cast_mask = None
413 return NestedTensor(cast_tensor, cast_mask)
414
415 def to_img_list_single(self, tensor, mask):
416 assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
417 maxH = (~mask).sum(0).max()
418 maxW = (~mask).sum(1).max()
419 img = tensor[:, :maxH, :maxW]
420 return img
421
422 def to_img_list(self):
423 """remove the padding and convert to img list
424
425 Returns:
426 [type]: [description]
427 """
428 if self.tensors.dim() == 3:
429 return self.to_img_list_single(self.tensors, self.mask)
430 else:
431 res = []
432 for i in range(self.tensors.shape[0]):
433 tensor_i = self.tensors[i]
434 mask_i = self.mask[i]
435 res.append(self.to_img_list_single(tensor_i, mask_i))
436 return res
437
438 @property
439 def device(self):

Callers 6

forwardMethod · 0.90
forwardMethod · 0.90
forwardMethod · 0.90
toMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected