MCPcopy
hub / github.com/OpenPPL/ppq / TensorMeta

Class TensorMeta

ppq/core/data.py:115–172  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

113
114
class TensorMeta:
    def __init__(
        self, dtype: DataType, shape: List[int],
        tensor_name: str = None) -> None:
        """TensorMeta structure describes the metadata of a tensor.

        It records the tensor's data type, shape and (optionally) its name.
        TensorMeta is necessary to initialize quantization configuration and hooks,
        and is needed to compute the number of input channels.

        Args:
            dtype (DataType):
                A DataType enumeration describing the tensor type.
            shape (List[int]):
                An int list containing the size of each dimension.
                An empty list describes a scalar (0-dim) tensor; None is also
                tolerated by copy().
            tensor_name (str, optional):
                Name of the described tensor. It is shown by __str__ and
                carried over by copy().

        Raises:
            TypeError: if dtype is not a ppq.core.DataType instance.
        """
        if not isinstance(dtype, DataType):
            raise TypeError(f'Can not create Tensor Meta with dtype {type(dtype)}, '
                            'only ppq.core.DataType instance is acceptable here.')
        self.dtype = dtype
        self.shape = shape
        self.name = tensor_name

    @classmethod
    def parsing_from_numpy_ndarray(cls, numpy_array: ndarray, name: str = None):
        """Build a TensorMeta from a numpy array's shape and dtype.

        Args:
            numpy_array (ndarray): array to describe.
            name (str, optional): tensor name stored in the meta.

        Returns:
            TensorMeta: meta with list(numpy_array.shape) and the DataType
            converted from numpy_array.dtype.
        """
        shape = list(numpy_array.shape)
        dtype = DataType.convert_from_numpy(numpy_array.dtype)
        return TensorMeta(dtype=dtype, shape=shape, tensor_name=name)

    @classmethod
    def parsing_from_torch_tensor(cls, torch_tensor: Tensor, name: str = None):
        """Build a TensorMeta from a torch tensor's shape and dtype.

        Args:
            torch_tensor (Tensor): tensor to describe.
            name (str, optional): tensor name stored in the meta.

        Returns:
            TensorMeta: meta with list(torch_tensor.shape) and the DataType
            converted from torch_tensor.dtype.

        Raises:
            TypeError: if torch_tensor is not a torch.Tensor.
        """
        if not isinstance(torch_tensor, Tensor):
            raise TypeError(f'Can not parse meta data for {type(torch_tensor)} instance, '
                            'it should be torch.Tensor object.')
        # A scalar (0-dim) tensor has torch.Size([]); list() of it is already [],
        # so scalars naturally get an empty shape list.
        shape = list(torch_tensor.shape)
        dtype = DataType.convert_from_torch(torch_tensor.dtype)
        return TensorMeta(dtype=dtype, shape=shape, tensor_name=name)

    def create_tensor(self, device: str, fill_value: Any = 0):
        """Create a torch tensor of this meta's shape and dtype, filled with fill_value.

        Args:
            device (str): device the tensor is created on.
            fill_value (Any): scalar value written into every element.

        Returns:
            torch.Tensor: the new tensor.
        """
        # Allocate directly with the target dtype on the target device. Building a
        # float32 tensor first and casting afterwards would round fill values that
        # float32 cannot represent exactly (large int64 values, float64 values).
        return torch.full(
            self.shape, fill_value,
            dtype=DataType.to_torch(self.dtype), device=device)

    def create_ndarray(self, fill_value: Any = 0):
        """Create a numpy array of this meta's shape and dtype, filled with fill_value.

        Args:
            fill_value (Any): scalar value written into every element.

        Returns:
            ndarray: the new array.
        """
        array = ndarray(shape=self.shape, dtype=DataType.to_numpy(self.dtype))
        # ndarray.fill works in place and returns None, so return the array itself.
        array.fill(fill_value)
        return array

    def __str__(self) -> str:
        return f'Tensor({self.name}) meta: dtype({self.dtype}), shape({self.shape})'

    def copy(self):
        """Return a new TensorMeta with the same dtype and name.

        The shape list is copied so the two metas do not share it; a None shape
        stays None.
        """
        shape = None if self.shape is None else self.shape.copy()
        return TensorMeta(dtype=self.dtype, shape=shape, tensor_name=self.name)

Callers 3

copyMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected