MCPcopy
hub / github.com/zai-org/GLM-130B / QuantizedRowParallelLinear

Class QuantizedRowParallelLinear

quantization/layers.py:51–87  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

49
50
class QuantizedRowParallelLinear(RowParallelLinear):
    """Row-parallel linear layer whose weight is stored as low-bit integers.

    The floating-point ``weight`` created by the ``RowParallelLinear``
    constructor is discarded and replaced by two frozen parameters:

    * ``weight`` -- ``int8`` storage of the quantized weight (for 4-bit, the
      output of ``compress_int4_weight``), and
    * ``weight_scale`` -- one scale per output row (dim 0), used together
      with ``weight`` by ``W8A16Linear.apply`` in ``forward``.
    """

    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        """Build the parent layer, then allocate or quantize the weight.

        Args:
            weight_bit_width: Bits per quantized value. Values are mapped
                symmetrically onto ``[-(2**(b-1) - 1), 2**(b-1) - 1]``.
            weight: Optional floating-point weight to quantize row-wise (the
                max-abs over the last dimension sets each row's scale). When
                ``None``, uninitialized ``weight``/``weight_scale`` tensors
                are allocated instead (presumably filled later by loading a
                checkpoint -- confirm with the caller).
            *args, **kwargs: Forwarded to ``RowParallelLinear``. ``kwargs``
                must contain ``"device"``, and also ``"params_dtype"`` when
                ``weight`` is ``None``.
        """
        super().__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        shape = self.weight.shape
        # Drop the parent's floating-point weight; it is replaced below.
        del self.weight

        if weight is None:
            # Packed storage is shape[1] * bit_width // 8 int8 columns
            # (half the input width for 4-bit).
            int_weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            int_weight, scale = self._quantize_weight(weight, weight_bit_width)

        # requires_grad=False: the quantized tensors are frozen.
        self.weight = Parameter(int_weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(scale.to(kwargs["device"]), requires_grad=False)

    @staticmethod
    def _quantize_weight(weight, weight_bit_width: int):
        """Symmetrically quantize ``weight`` per row; return ``(int_weight, scale)``.

        ``scale`` is ``max_abs(row) / (2**(b-1) - 1)`` cast to float16, and
        ``int_weight`` is ``round(weight / scale)`` as ``int8``, packed with
        ``compress_int4_weight`` when ``weight_bit_width == 4``.
        """
        max_int = (2 ** (weight_bit_width - 1)) - 1
        # NOTE(review): the scale is forced to float16 here, while the
        # allocation path in __init__ uses kwargs["params_dtype"]. Kept as-is
        # because W8A16Linear may require a half-precision scale -- confirm.
        scale = (weight.abs().max(dim=-1).values / max_int).half()
        # An all-zero row gives scale 0, and 0/0 -> NaN; a tiny row whose
        # scale underflows to 0 in float16 gives x/0 -> inf. Casting NaN/inf
        # to int8 is undefined, so divide such rows by 1 instead. They then
        # quantize to 0 (their values are ~0), and the stored scale is left
        # untouched, so nonzero rows are quantized exactly as before.
        safe_scale = scale.masked_fill(scale == 0, 1.0)
        int_weight = torch.round(weight / safe_scale[:, None]).to(torch.int8)
        if weight_bit_width == 4:
            int_weight = compress_int4_weight(int_weight)
        return int_weight, scale

    def forward(self, input_):
        """Apply the quantized row-parallel linear layer.

        Args:
            input_: Layer input. Used as-is when ``self.input_is_parallel``;
                otherwise it is passed through
                ``scatter_to_model_parallel_region`` first.

        Returns:
            The output of ``reduce_from_model_parallel_region`` applied to
            the quantized matmul result, plus ``self.bias`` when present.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matmul against the int8 weight and per-row scale (presumably
        # dequantized inside W8A16Linear).
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        # All-reduce across all the model-parallel partitions.
        output = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output + self.bias
        return output

Callers 1

quantizeFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected