| 182 | |
| 183 | |
| 184 | class QuantableVariable(Variable): |
def __init__(self, convert_from: Variable) -> None:
    """Create a quantable variable that mirrors ``convert_from``.

    The name, source op, value, parameter flag, shape and dtype are passed
    through to the parent constructor unchanged; ``dest_ops`` is passed as
    the result of ``convert_from.dest_ops.copy()``.

    Args:
        convert_from: the variable whose attributes are taken over.
    """
    super().__init__(
        name=convert_from.name,
        dest_ops=convert_from.dest_ops.copy(),
        source_op=convert_from.source_op,
        value=convert_from.value,
        is_parameter=convert_from.is_parameter,
        shape=convert_from.shape,
        dtype=convert_from.dtype,
    )
    # Keep a separate converted copy of the value (requested on 'cpu')
    # whenever the origin variable carries one; otherwise store None.
    origin_value = convert_from.value
    self._fp32_value = (
        None
        if origin_value is None
        else convert_any_to_torch_tensor(origin_value, device='cpu')
    )
| 197 | |
@property
def stored_value(self) -> Any:
    """Return whatever is currently held in ``self._fp32_value``."""
    return self._fp32_value

@stored_value.setter
def stored_value(self, value: Any):
    """Replace the content of ``self._fp32_value`` with ``value`` as given."""
    self._fp32_value = value
| 205 | |
@property
def dest_op_configs(self) -> List[TensorQuantizationConfig]:
    """Collect one input quantization config per entry of ``self.dest_ops``.

    For a destination op that is a ``QuantableOperation`` the entry is
    ``op.config.input_quantization_config[k]``, where ``k`` is the element
    of ``self.dest_idx`` found at the same position as the op. Every other
    destination op contributes ``None``.

    Returns:
        A list with the same length and order as ``self.dest_ops``.
    """
    # dest_idx is read once, before dest_ops is walked.
    input_slots = self.dest_idx
    return [
        consumer.config.input_quantization_config[input_slots[position]]
        if isinstance(consumer, QuantableOperation)
        else None
        for position, consumer in enumerate(self.dest_ops)
    ]
| 214 | |
@property
def dest_op_platforms(self) -> List[TargetPlatform]:
    """List the platform of every entry in ``self.dest_ops``.

    A ``None`` entry is reported as ``TargetPlatform.FP32``; any other
    entry contributes its own ``platform`` attribute.

    Returns:
        A list with the same length and order as ``self.dest_ops``.
    """
    return [
        TargetPlatform.FP32 if consumer is None else consumer.platform
        for consumer in self.dest_ops
    ]
| 222 | return _dest_op_platforms |
| 223 | |
@property
def source_op_config(self) -> TensorQuantizationConfig:
    """Return the output quantization config of the source op for this variable.

    Returns:
        ``source_op.config.output_quantization_config[self.src_idx]`` when
        the source op is a ``QuantableOperation``; ``None`` when there is no
        source op or it is of any other type.
    """
    producer = self.source_op
    # isinstance() is False for None, so a missing source op also falls
    # through to the None result below.
    if isinstance(producer, QuantableOperation):
        return producer.config.output_quantization_config[self.src_idx]
    return None
| 230 | return None |
| 231 | |
@property
def source_op_platform(self) -> TargetPlatform:
    """Return the platform of the source op.

    Returns:
        ``TargetPlatform.FP32`` when ``self.source_op`` is ``None``,
        otherwise the ``platform`` attribute of the source op.
    """
    producer = self.source_op
    return TargetPlatform.FP32 if producer is None else producer.platform
| 237 | |
| 238 | def copy(self, copy_value: bool = False): |
| 239 | clone = QuantableVariable(super().copy(copy_value)) |
| 240 | if copy_value and self._fp32_value is not None: |
| 241 | clone._fp32_value = self._fp32_value.clone() |
no outgoing calls
no test coverage detected