MCPcopy
hub / github.com/zai-org/GLM-130B / QuantizedColumnParallelLinear

Class QuantizedColumnParallelLinear

quantization/layers.py:14–48  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
class QuantizedColumnParallelLinear(ColumnParallelLinear):
    """Column-parallel linear layer whose weight is stored as int8 plus a per-row scale.

    The full-precision ``weight`` parameter created by the base class is
    discarded and replaced by two non-trainable parameters:

    * ``weight``       -- ``torch.int8`` tensor with ``shape[0]`` rows and
      ``shape[1] * weight_bit_width // 8`` columns.
    * ``weight_scale`` -- one scale value per row of the weight.
    """

    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        """Build the layer and install the quantized weight.

        Args:
            weight_bit_width: Number of bits per quantized weight element.
                When it equals 4 and ``weight`` is given, the int8 tensor is
                additionally passed through ``compress_int4_weight``.
            weight: Optional full-precision weight to quantize. When ``None``,
                uninitialized storage of the quantized shape is allocated
                instead (presumably to be filled from a checkpoint later --
                TODO confirm against the loading code).
            *args, **kwargs: Forwarded unchanged to the base class
                constructor. ``device`` and ``params_dtype`` are read from
                ``kwargs`` when present; otherwise the device/dtype of the
                weight allocated by the base class is used.
        """
        super().__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        # Grab everything we need from the base-class weight before dropping it.
        # Falling back to its device/dtype avoids a KeyError when the caller
        # passes `device` / `params_dtype` positionally or relies on defaults.
        base_weight = self.weight
        shape = base_weight.shape
        device = kwargs.get("device", base_weight.device)
        params_dtype = kwargs.get("params_dtype", base_weight.dtype)
        # Release both references so the full-precision storage can be freed now.
        del base_weight
        del self.weight

        if weight is None:
            # No source weight: allocate uninitialized quantized storage.
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device
            )
            self.weight_scale = torch.empty(shape[0], dtype=params_dtype, device=device)
        else:
            # Symmetric per-row quantization: the largest magnitude in each row
            # maps to the largest positive value of a signed `weight_bit_width` integer.
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        # Register both tensors as frozen parameters on the target device.
        self.weight = Parameter(self.weight.to(device), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)

    def forward(self, input_):
        """Apply the quantized linear transform to ``input_``.

        Adds ``self.bias`` when it is not ``None`` and gathers the result
        across model-parallel partitions when ``self.gather_output`` is set;
        otherwise the partition-local output is returned.
        """
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output_parallel = output_parallel + self.bias
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output
49
50
51class QuantizedRowParallelLinear(RowParallelLinear):

Callers 1

quantizeFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected