(self, x, lora_runtime_params=None, all_reduce_params=None)
| 1290 | prefer_managed=self.prefer_managed_weight) |
| 1291 | |
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
    """Apply the layer to ``x`` with an int8 quantize/dequantize round trip.

    The input is multiplied by ``self.prequant_scaling_factor``, cast to the
    dtype of ``self.activation_scaling_factor``, quantized to ``'int8'`` and
    immediately dequantized with that same scale, then cast to
    ``self.dtype``. The weight is dequantized with
    ``self.weights_scaling_factor`` and cast to ``self.dtype`` as well. Both
    results are handed to ``self.multiply_collect`` with ``gemm_plugin=None``.

    Args:
        x: Input tensor. When the default network is strongly typed, its
            dtype must match ``self.dtype``.
        lora_runtime_params: Optional LoRA parameters, forwarded unchanged to
            ``multiply_collect``. When not ``None``, the unscaled input ``x``
            is forwarded too, as ``lora_hidden_state``.
        all_reduce_params: Optional parameters forwarded unchanged to
            ``multiply_collect``.

    Returns:
        Whatever ``self.multiply_collect`` returns.

    Raises:
        AssertionError: If the default network is strongly typed and
            ``x.dtype`` differs from ``self.dtype``.
    """
    # Capture the input before any scaling; it is only needed for LoRA.
    if lora_runtime_params is not None:
        lora_hidden_state = x
    else:
        lora_hidden_state = None

    if default_net().strongly_typed:
        assert is_same_dtype(
            x.dtype,
            self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"

    # Activation path: pre-quant scaling -> int8 quantize -> dequantize.
    scaled_input = mul(x, self.prequant_scaling_factor.value)

    act_scale = self.activation_scaling_factor.value
    act_scale_dtype = act_scale.dtype
    scaled_input = cast(scaled_input, act_scale_dtype)

    int8_activation = quantize(scaled_input, act_scale, 'int8')
    # NOTE(review): the third positional argument (-1 here, 0 for the weight)
    # is presumably the quantization axis -- confirm against dequantize().
    activation = dequantize(int8_activation, act_scale, -1, act_scale_dtype)
    activation = cast(activation, self.dtype)

    # Weight path: dequantize the stored weight, then match the layer dtype.
    # The weight is read before its scale to keep the original access order.
    stored_weight = self.weight.value
    weight_scale = self.weights_scaling_factor.value
    weight = dequantize(stored_weight, weight_scale, 0, weight_scale.dtype)
    weight = cast(weight, self.dtype)

    return self.multiply_collect(
        activation,
        weight,
        gemm_plugin=None,
        all_reduce_params=all_reduce_params,
        lora_runtime_params=lora_runtime_params,
        lora_hidden_state=lora_hidden_state)
| 1322 | |
| 1323 | |
| 1324 | class Int8SmoothQuantLinear(Linear): |
nothing calls this directly
no test coverage detected