(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case)
| 127 | @pytest.mark.parametrize("q_bits", [4, 8]) |
| 128 | @pytest.mark.parametrize("directed_case", ["all_zeros", None]) |
def test_float_quantize(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case):
    """Compare the round-trip error of the ``*_ds`` quantizer with the float reference.

    A ``(num_groups, num_elems)`` float16 input is quantized and dequantized
    through both the ``run_float_*`` helpers and the ``run_*_ds`` helpers. The
    test passes when the ``*_ds`` result is finite everywhere and its summed
    absolute error is at most 5% above the reference's.

    ``directed_case == "all_zeros"`` feeds an all-zero input; any other value
    draws the input from ``torch.randn``.
    """

    def summed_abs_error(source, round_tripped):
        # Accumulate in float64 so the sum of many float16 errors cannot overflow.
        return torch.sum(torch.abs((source - round_tripped).to(torch.float64)))

    # Seed from the element count so every parametrization is reproducible.
    torch.manual_seed(num_elems)

    make_input = torch.zeros if directed_case == "all_zeros" else torch.randn
    activations_ds = make_input((num_groups, num_elems),
                                dtype=torch.float16,
                                device=get_accelerator().device_name())
    # Independent copy, so the two code paths never share storage.
    activations_ref = activations_ds.clone().detach()

    # Reference round trip.
    ref_quantized, ref_params = run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups)
    ref_restored = run_float_dequantize(q_bits, is_symmetric_quant, ref_quantized, ref_params, num_groups)
    ref_quantization_error = summed_abs_error(activations_ref, ref_restored)

    # Round trip through the ``*_ds`` helpers.
    ds_quantized, ds_params = run_quantize_ds(activations_ds, num_groups, q_bits, is_symmetric_quant)
    ds_restored = run_dequantize_ds(ds_quantized, ds_params, num_groups, q_bits, is_symmetric_quant)
    assert torch.all(torch.isfinite(ds_restored))

    ds_quantization_error = summed_abs_error(activations_ds, ds_restored)

    # Allow the ``*_ds`` path up to 5% more total error than the reference.
    assert (ds_quantization_error <= ref_quantization_error * 1.05)
nothing calls this directly
no test coverage detected
searching dependent graphs…