MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / TensorInfo

Class TensorInfo

tensorrt_llm/runtime/session.py:46–80  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

44
@dataclass
class TensorInfo:
    """Name, dtype and shape describing a tensor.

    ``view`` and ``squeeze`` do not modify the receiver; each returns a new
    ``TensorInfo`` carrying the same ``name`` and ``dtype`` with a new shape.
    """
    name: str
    dtype: trt.DataType
    shape: tuple

    # add more info like strides, formats if needed

    def numel(self):
        """Return the number of elements, i.e. the product of all dims.

        An empty shape yields 1 (the empty product).
        """
        return prod(self.shape)

    def view(self, *shape):
        """Return a ``TensorInfo`` with the same element count and a new shape.

        Args:
            *shape: The target dimensions, each a plain ``int``. At most one
                entry may be ``-1``; it is replaced by the size that makes
                the element count match ``numel()``.

        Returns:
            A new ``TensorInfo`` with this object's name and dtype and the
            resolved shape as a tuple.

        Raises:
            TypeError: If any dimension is not exactly of type ``int``.
            ValueError: If a dimension is below ``-1``, if more than one
                dimension is ``-1``, or if the requested shape cannot hold
                exactly ``numel()`` elements.
        """
        # Exact type match (not isinstance) so that bool stays rejected.
        if any(type(dim) is not int for dim in shape):
            raise TypeError(f"view() dimensions must be int, got {shape}")
        # -1 is the only meaningful negative value; anything lower used to
        # slip through unresolved into the returned shape.
        if any(dim < -1 for dim in shape):
            raise ValueError(
                f"Invalid shape {shape}: only -1 may be negative")
        if shape.count(-1) > 1:
            raise ValueError('More than one dimensions need to be inferred!')

        total = self.numel()
        if -1 not in shape:
            # Explicit errors instead of assert: asserts vanish under -O.
            if prod(shape) != total:
                raise ValueError(
                    f"Shape {shape} is invalid for {total} elements")
            return TensorInfo(self.name, self.dtype, shape)

        n_known_elements = prod(dim for dim in shape if dim != -1)
        # A zero-sized known dim makes the inferred dim undefined (and would
        # divide by zero below), so reject it together with uneven splits.
        if n_known_elements == 0 or total % n_known_elements != 0:
            raise ValueError(
                f"Cannot infer -1 in shape {shape} for {total} elements")
        inferred = total // n_known_elements
        new_shape = tuple(inferred if dim == -1 else dim for dim in shape)
        return TensorInfo(self.name, self.dtype, new_shape)

    def __len__(self):
        """Return the size of the first dimension.

        Raises:
            IndexError: If the shape is empty.
        """
        return self.shape[0]

    def squeeze(self, dim=0):
        """Return a ``TensorInfo`` with the size-1 dimension ``dim`` removed.

        Args:
            dim: Index of the dimension to drop. Negative values count from
                the end, as with ordinary sequence indexing.

        Returns:
            A new ``TensorInfo`` whose shape omits dimension ``dim``.

        Raises:
            ValueError: If the selected dimension is not 1.
            IndexError: If ``dim`` is out of range for the shape.
        """
        if self.shape[dim] != 1:
            raise ValueError(f"dim {dim} is {self.shape[dim]} instead of 1!")
        # Normalize negative indices: with dim == -1 the slice
        # shape[dim + 1:] would be shape[0:] and duplicate the whole shape.
        if dim < 0:
            dim += len(self.shape)
        return TensorInfo(self.name, self.dtype,
                          self.shape[:dim] + self.shape[dim + 1:])
81
82
83class Session(object):

Callers 15

run_session · Function · 0.90
run_vision_encoder · Method · 0.90
run_engine · Function · 0.90
run · Function · 0.90
run · Function · 0.90
prepare · Method · 0.90
_setup · Method · 0.90
vae_decode · Function · 0.90
run.py · File · 0.90
vit_process · Function · 0.90
get_audio_features · Method · 0.90
audio_tower · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected