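Make a TensorProto with the specified arguments. If raw is False, this function chooses the proto field used to store the values based on data_type. If raw is True, the values are stored in the "raw_data" proto field. In that case vals should be of type bytes; a NumPy array is also accepted and is serialized to little-endian bytes, with 4-bit and 2-bit types packed first. Using raw=True with the STRING data type raises a TypeError. Args: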
(
name: str,
data_type: int,
dims: Sequence[int],
vals: Sequence[int | float] | bytes | np.ndarray,
raw: bool = False,
)
| 361 | |
| 362 | |
| 363 | def make_tensor( |
| 364 | name: str, |
| 365 | data_type: int, |
| 366 | dims: Sequence[int], |
| 367 | vals: Sequence[int | float] | bytes | np.ndarray, |
| 368 | raw: bool = False, |
| 369 | ) -> TensorProto: |
| 370 | """Make a TensorProto with specified arguments. If raw is False, this |
| 371 | function will choose the corresponding proto field to store the |
| 372 | values based on data_type. If raw is True, use "raw_data" proto |
| 373 | field to store the values, and values should be of type bytes in |
| 374 | this case. |
| 375 | |
| 376 | Args: |
| 377 | name: tensor name |
| 378 | data_type: a value such as onnx.TensorProto.FLOAT |
| 379 | dims: shape |
| 380 | vals: values |
| 381 | raw: if True, vals contains the serialized content of the tensor, |
| 382 | otherwise, vals should be a list of values of the type defined by ``data_type``. |
| 383 | |
| 384 | Returns: |
| 385 | TensorProto |
| 386 | """ |
| 387 | tensor = TensorProto() |
| 388 | tensor.data_type = data_type |
| 389 | tensor.name = name |
| 390 | tensor.dims.extend(dims) |
| 391 | |
| 392 | if data_type == TensorProto.STRING and raw: |
| 393 | raise TypeError("Can not use raw_data to store string type.") |
| 394 | |
| 395 | np_dtype = tensor_dtype_to_np_dtype(data_type) |
| 396 | |
| 397 | if raw: |
| 398 | # NumPy doesn't have INT2/INT4/FP4. It is packed in couples to UINT8 buffers. |
| 399 | if data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}: |
| 400 | expected_size_bytes = 0.5 |
| 401 | elif data_type in {TensorProto.UINT2, TensorProto.INT2}: |
| 402 | expected_size_bytes = 0.25 |
| 403 | else: |
| 404 | expected_size_bytes = np_dtype.itemsize |
| 405 | expected_size_bytes *= math.prod(dims) |
| 406 | expected_size_bytes = math.ceil(expected_size_bytes) |
| 407 | if isinstance(vals, np.ndarray): |
| 408 | if data_type in { |
| 409 | TensorProto.INT4, |
| 410 | TensorProto.UINT4, |
| 411 | TensorProto.FLOAT4E2M1, |
| 412 | }: |
| 413 | vals = onnx.numpy_helper._pack_4bitx2(vals) |
| 414 | elif data_type in {TensorProto.UINT2, TensorProto.INT2}: |
| 415 | vals = onnx.numpy_helper._pack_2bitx4(vals) |
| 416 | |
| 417 | raw_data = onnx.numpy_helper.tobytes_little_endian(vals) |
| 418 | elif isinstance(vals, bytes): |
| 419 | raw_data = vals |
| 420 | else: |
searching dependent graphs…