(self, x, lora_runtime_params=None)
| 1354 | prefer_managed=self.prefer_managed_weight) |
| 1355 | |
def forward(self, x, lora_runtime_params=None):
    """Run the layer on ``x`` via an int8 quantize/dequantize round trip.

    The input is multiplied by the pre-quantization scaling factor,
    quantized to ``'int8'`` with the activation scaling factor and
    immediately dequantized again. The stored weight is dequantized with
    the weights scaling factor. Both results are cast to ``self.dtype``
    and handed to ``self.multiply_collect``.

    Args:
        x: Input tensor. When the current network is strongly typed its
            dtype must match ``self.dtype``.
        lora_runtime_params: Optional LoRA runtime parameters. When not
            ``None``, the original (unscaled) ``x`` is forwarded as
            ``lora_hidden_state``; otherwise ``None`` is forwarded.

    Returns:
        Whatever ``self.multiply_collect`` returns for the dequantized
        activation and weight.
    """
    # Capture the untouched input before any scaling; it is only passed
    # along when LoRA runtime parameters were supplied.
    lora_hidden_state = None
    if lora_runtime_params is not None:
        lora_hidden_state = x

    if default_net().strongly_typed:
        assert is_same_dtype(x.dtype, self.dtype), (
            f"Got input type {x.dtype}, expecting {self.dtype}")

    # Scale the activation, then move it to the dtype of the activation
    # scaling factor before quantizing.
    scaled = mul(x, self.prequant_scaling_factor.value)
    scaled = cast(scaled, self.activation_scaling_factor.value.dtype)

    act_int8 = quantize(scaled, self.activation_scaling_factor.value, 'int8')

    # NOTE(review): the -1 here and the 0 in the weight call below are
    # presumably the axis argument of dequantize -- confirm against its
    # signature.
    act_deq = dequantize(
        act_int8,
        self.activation_scaling_factor.value,
        -1,
        self.activation_scaling_factor.value.dtype,
    )
    act_deq = cast(act_deq, self.dtype)

    weight_deq = dequantize(
        self.weight.value,
        self.weights_scaling_factor.value,
        0,
        self.weights_scaling_factor.value.dtype,
    )
    weight_deq = cast(weight_deq, self.dtype)

    # gemm_plugin is passed explicitly as None (presumably to bypass any
    # GEMM plugin path -- confirm in multiply_collect).
    return self.multiply_collect(
        act_deq,
        weight_deq,
        gemm_plugin=None,
        lora_runtime_params=lora_runtime_params,
        lora_hidden_state=lora_hidden_state,
    )
| 1384 | |
| 1385 | |
| 1386 | class FP8Linear(Linear): |
nothing calls this directly
no test coverage detected