MCP · copy
hub / github.com/yangchris11/samurai / NestedTensor

Class NestedTensor

lib/utils/misc.py:284–304  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

282
283
class NestedTensor(object):
    """Bundle a tensor together with an optional mask tensor.

    Attributes:
        tensors: The wrapped tensor. Only ``.to(device)`` and ``str()`` are
            used on it here.
        mask: Optional tensor carried alongside ``tensors``, or ``None``.
            NOTE(review): presumably marks padded positions produced by
            ``nested_tensor_from_tensor_list`` — confirm against that helper.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with ``tensors`` and ``mask`` moved to ``device``.

        ``self`` is not modified. Because ``Tensor.to`` may return the same
        object when no conversion is needed, the result can share tensors
        with ``self``.

        Args:
            device: Target passed straight through to ``Tensor.to``.

        Returns:
            A new NestedTensor. Its mask is ``None`` iff ``self.mask`` is ``None``.
        """
        cast_tensor = self.tensors.to(device)
        # Read the attribute into a local so the `is not None` test narrows
        # the Optional type (presumably for TorchScript, which cannot refine
        # attribute accesses — see the type comment above). The narrowing
        # makes an extra `assert mask is not None` in this branch unnecessary.
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair stored on this instance."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the wrapped tensor is shown; the mask is omitted.
        return str(self.tensors)
305
306
307def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):

Callers 4

process (method) · 0.90
to (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected