MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / __post_init__

Method __post_init__

triton_kernels/tensor.py:107–132  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

105 shape_max: list[int] | None = None
106
def __post_init__(self):
    """Normalize and validate ``storage``, ``dtype``, ``shape`` and ``shape_max``.

    Steps, in order:

    * A raw ``torch.Tensor`` passed as ``storage`` is wrapped in ``Storage``.
    * ``dtype`` defaults to the dtype of ``storage.data``.
    * ``shape`` defaults to the shape of ``storage.data``; this default is
      refused for dtypes narrower than 8 bits, which must supply ``shape``
      explicitly.
    * Every ``shape`` entry must be an ``int`` or an object exposing
      ``numel()`` that returns 1 (e.g. a single-element ``torch.Tensor``).
    * ``shape_max`` defaults, entry by entry, to the matching ``shape``
      entry; once defaults are filled in every entry must be an ``int``.
      Consequently a non-``int`` ``shape`` entry needs an explicit ``int``
      bound in ``shape_max``.

    Raises:
        ValueError: if ``shape`` is missing for a sub-byte dtype, if a
            ``shape`` entry is neither ``int`` nor single-element, or if a
            ``shape_max`` entry is not (or does not resolve to) an ``int``.
    """

    def is_int(value):
        return isinstance(value, int)

    def is_item(value):
        # duck-typed check: anything with `numel()` returning 1 qualifies
        return hasattr(value, "numel") and value.numel() == 1

    # set storage
    if isinstance(self.storage, torch.Tensor):
        self.storage = Storage(self.storage)
    # initialize dtype
    if self.dtype is None:
        self.dtype = self.storage.data.dtype
    if bitwidth(self.dtype) < 8 and self.shape is None:
        raise ValueError("shape must be provided for sub-byte types")
    # initialize shape
    if self.shape is None:
        self.shape = list(self.storage.data.shape)
    # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
    # (explicit raise rather than `assert`, so the check survives `python -O`)
    for i, s in enumerate(self.shape):
        if not (is_int(s) or is_item(s)):
            raise ValueError(f"shape[{i}] must be `int` or a single-element tensor; got {type(s)}")
    # initialize shape_max on a private copy, so the caller's sequence is
    # never mutated and immutable sequences (e.g. tuples) are accepted
    if self.shape_max is None:
        shape_max = [None] * len(self.shape)
    else:
        shape_max = list(self.shape_max)
    # NOTE: `zip` stops at the shorter sequence, so a length mismatch between
    # `shape` and `shape_max` is not detected by this loop.
    for i, (s, smax) in enumerate(zip(self.shape, shape_max)):
        if smax is None:
            shape_max[i] = s
        elif not is_int(smax):
            raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
    # validate shape_max: all elements must be `int`
    for i, smax in enumerate(shape_max):
        if not is_int(smax):
            raise ValueError(f"shape_max[{i}] must resolve to `int`; got {type(smax)} "
                             "(an explicit `int` bound is required when the shape entry is not an `int`)")
    self.shape_max = shape_max
133
134 # torch compatibility layer
135 @property

Callers

nothing calls this directly

Calls 3

Storage (Class) · 0.85
bitwidth (Function) · 0.85
numel (Method) · 0.45

Tested by

no test coverage detected