Repository: github.com/deepspeedai/DeepSpeed

Function test_float_quantize

tests/unit/ops/quantizer/test_quantize.py, lines 129–154
(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case)


import pytest
import torch
from deepspeed.accelerator import get_accelerator

# run_float_quantize, run_float_dequantize, run_quantize_ds and run_dequantize_ds
# are helpers defined elsewhere in the same test file. num_elems, num_groups and
# is_symmetric_quant are supplied by decorators not shown in this excerpt.


@pytest.mark.parametrize("q_bits", [4, 8])
@pytest.mark.parametrize("directed_case", ["all_zeros", None])
def test_float_quantize(num_elems, num_groups, is_symmetric_quant, q_bits, directed_case):
    # fix the seed so each parametrization is reproducible
    torch.manual_seed(num_elems)

    if directed_case == "all_zeros":
        activations_ds = torch.zeros((num_groups, num_elems),
                                     dtype=torch.float16,
                                     device=get_accelerator().device_name())
    else:
        activations_ds = torch.randn((num_groups, num_elems),
                                     dtype=torch.float16,
                                     device=get_accelerator().device_name())
    activations_ref = activations_ds.clone().detach()

    # reference quantize -> dequantize round trip
    ref_out_tensor, ref_params = run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups)
    ref_dequantized_tensor = run_float_dequantize(q_bits, is_symmetric_quant, ref_out_tensor, ref_params, num_groups)
    # cast to float64 before summing: accumulating the errors in float16 could overflow
    ref_quantization_error = torch.sum(torch.abs((activations_ref - ref_dequantized_tensor).to(torch.float64)))

    # the same round trip through the DeepSpeed (_ds) path
    ds_out_tensor, ds_out_params = run_quantize_ds(activations_ds, num_groups, q_bits, is_symmetric_quant)
    ds_dequantized_tensor = run_dequantize_ds(ds_out_tensor, ds_out_params, num_groups, q_bits, is_symmetric_quant)
    assert torch.all(torch.isfinite(ds_dequantized_tensor))

    ds_quantization_error = torch.sum(torch.abs((activations_ds - ds_dequantized_tensor).to(torch.float64)))

    # the DeepSpeed path may be at most 5% worse than the reference
    assert ds_quantization_error <= ref_quantization_error * 1.05
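The test runs two quantize-then-dequantize round trips on the same float16 input and accepts the DeepSpeed path if its total absolute reconstruction error is at most 5% above the reference. The plain-Python sketch below illustrates that idea; it is a minimal illustration under stated assumptions, not the implementation of run_float_quantize or of the DeepSpeed kernel. The helper names (quantize_group, dequantize_group, round_trip_error, within_tolerance) are hypothetical, and the code assumes signed integer codes in [-2^(b-1), 2^(b-1) - 1] with one scale per group, plus an offset in the asymmetric case.

```python
import math
import random

# Hypothetical helpers sketching group-wise integer quantization.
# They are NOT DeepSpeed's kernels or the test file's run_float_* helpers.


def quantize_group(values, q_bits, symmetric):
    """Map one group of floats to signed integer codes of width q_bits.

    Assumes codes in [-2**(q_bits-1), 2**(q_bits-1) - 1] and returns
    (codes, scale, offset) such that value ~= code * scale + offset.
    """
    q_max = 2 ** (q_bits - 1) - 1  # 7 for 4 bits, 127 for 8 bits
    q_min = -(2 ** (q_bits - 1))   # -8 for 4 bits, -128 for 8 bits
    if symmetric:
        # One scale per group, zero offset: the range is [-absmax, absmax].
        absmax = max(abs(v) for v in values)
        scale = absmax / q_max if absmax > 0 else 1.0
        offset = 0.0
    else:
        # Scale plus offset: the group min maps to q_min, the max to q_max.
        lo, hi = min(values), max(values)
        scale = (hi - lo) / (q_max - q_min) if hi > lo else 1.0
        offset = lo - q_min * scale
    codes = [max(q_min, min(q_max, round((v - offset) / scale))) for v in values]
    return codes, scale, offset


def dequantize_group(codes, scale, offset):
    """Invert quantize_group up to rounding error."""
    return [c * scale + offset for c in codes]


def round_trip_error(groups, q_bits, symmetric):
    """Total |x - dequantize(quantize(x))| over all groups.

    math.fsum returns an accurately rounded double-precision sum, playing the
    role of the test's cast to float64 before summing.
    """
    errors = []
    for group in groups:
        codes, scale, offset = quantize_group(group, q_bits, symmetric)
        recon = dequantize_group(codes, scale, offset)
        errors.extend(abs(x - r) for x, r in zip(group, recon))
    return math.fsum(errors)


def within_tolerance(candidate_error, reference_error, slack=1.05):
    """The test's acceptance rule: at most 5% more error than the reference."""
    return candidate_error <= reference_error * slack


if __name__ == "__main__":
    random.seed(0)
    groups = [[random.gauss(0.0, 1.0) for _ in range(64)] for _ in range(4)]
    for bits in (4, 8):
        for sym in (True, False):
            err = round_trip_error(groups, bits, sym)
            print(f"q_bits={bits} symmetric={sym} error={err:.4f}")
    zeros = [[0.0] * 16 for _ in range(2)]
    print(round_trip_error(zeros, 4, True), round_trip_error(zeros, 4, False))
```

In this sketch an all-zeros group reconstructs exactly in both modes, so the reference error is 0 and the 5% slack leaves no room: any nonzero error from the candidate path fails the check. That mirrors why a directed all_zeros case is a useful edge case alongside random inputs.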

Callers

nothing calls this directly (pytest discovers and runs it as a test case)

Calls (8)

get_accelerator · function · 0.90
run_float_quantize · function · 0.85
run_float_dequantize · function · 0.85
run_quantize_ds · function · 0.85
run_dequantize_ds · function · 0.85
manual_seed · method · 0.45
device_name · method · 0.45
to · method · 0.45

Tested by

not applicable (this function is itself a test)
