Parameters:
    x : Tensor (On GPU)
        The input tensor.
    scale : Tensor (On GPU)
        The block scale tensor.
    double_scale : Tensor (On GPU)
        The global per-tensor scaling factor. It should contain only 1 element.
    dtype : trt.DataType | str
        The data type for dequantized data. Default is 'float16'.
block_double_dequantize(x: Tensor,
                        scale: Tensor,
                        double_scale: Tensor,
                        dtype: trt.DataType | str = 'float16') -> Tensor
def block_double_dequantize(x: Tensor,
                            scale: Tensor,
                            double_scale: Tensor,
                            dtype: trt.DataType | str = 'float16') -> Tensor:
    '''
    Parameters:
        x : Tensor (On GPU)
            The input tensor.
        scale : Tensor (On GPU)
            The block scale tensor.
        double_scale : Tensor (On GPU)
            The global per-tensor scaling factor. It should contain only 1 element.
        dtype : trt.DataType | str
            The data type for dequantized data. Default is 'float16'.
    Returns:
        The dequantized tensor.
    '''
    # Accept the dtype either as a string or as a trt.DataType.
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)

    # Stage 1: dequantize the block scales with the global per-tensor factor.
    dequantize_scale_layer = default_trtnet().add_dequantize(
        scale.trt_tensor, double_scale.trt_tensor, dtype)
    scale = _create_tensor(dequantize_scale_layer.get_output(0),
                           dequantize_scale_layer)

    # Stage 2: dequantize the data with the dequantized block scales.
    dequantize_data_layer = default_trtnet().add_dequantize(
        x.trt_tensor, scale.trt_tensor, dtype)
    dequantize_data = _create_tensor(dequantize_data_layer.get_output(0),
                                     dequantize_data_layer)
    return dequantize_data
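
The two dequantize layers above can be read as a two-stage rescale: the block scales are first multiplied by the global factor, then each input element is multiplied by the scale of its block. The following is a minimal pure-Python sketch of that arithmetic, not the TensorRT implementation. The name `block_double_dequantize_ref`, the `block_size` parameter, and the layout (contiguous blocks over a flat list, with no zero point) are assumptions made for illustration; the source does not state the block layout.

```python
def block_double_dequantize_ref(x_q, scale_q, double_scale, block_size):
    """Reference arithmetic for two-stage (double) block dequantization.

    x_q          : flat list of quantized values
    scale_q      : flat list of quantized block scales, one per block
    double_scale : single global scaling factor
    block_size   : hypothetical block length (not given in the source)
    """
    if len(x_q) != len(scale_q) * block_size:
        raise ValueError("x_q must contain len(scale_q) * block_size elements")

    # Stage 1: dequantize the block scales with the global factor.
    scales = [s * double_scale for s in scale_q]

    # Stage 2: dequantize each element with the scale of its block.
    return [v * scales[i // block_size] for i, v in enumerate(x_q)]


x_q = [1, -2, 3, 4, 0, 5, -6, 7]   # two blocks of four values
scale_q = [2, 4]                   # one quantized scale per block
out = block_double_dequantize_ref(x_q, scale_q, 0.5, block_size=4)
print(out)  # [1.0, -2.0, 3.0, 4.0, 0.0, 10.0, -12.0, 14.0]
```

Here the effective block scales are 1.0 and 2.0, so the first block is returned unchanged (as floats) and the second block is doubled.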