MCPcopy Index your code
hub / github.com/lm-sys/FastChat / compress

Function compress

fastchat/model/compression.py:226–276  ·  view source on GitHub ↗

Simulate group-wise quantization.

(tensor, config)

Source from the content-addressed store, hash-verified

224
225
def compress(tensor, config):
    """Simulate group-wise quantization.

    The tensor is zero-padded along ``config.group_dim`` to a multiple of
    ``config.group_size``. That dimension is then viewed as
    ``(num_groups, group_size)``, and each group is quantized independently
    to ``config.num_bits`` bits.

    Args:
        tensor: Tensor to quantize (assumed floating point -- TODO confirm
            against callers).
        config: Object exposing ``enabled``, ``group_size``, ``num_bits``
            (must be <= 8), ``group_dim`` and ``symmetric``. ``group_dim``
            must be a non-negative dimension index; the shape slicing below
            does not handle negative indices.

    Returns:
        ``tensor`` itself when ``config.enabled`` is false. Otherwise:

        * symmetric: ``(data, scale, original_shape)`` with ``data`` of
          dtype ``torch.int8``;
        * asymmetric: ``(data, mn, scale, original_shape)`` with ``data`` of
          dtype ``torch.uint8``.

        ``data`` has the grouped shape
        ``original_shape[:group_dim] + (num_groups, group_size) + rest``.
        ``scale`` and ``mn`` match it except for size 1 on the
        ``group_size`` axis (index ``group_dim + 1``).
    """
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size,
        config.num_bits,
        config.group_dim,
        config.symmetric,
    )
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (
        original_shape[:group_dim]
        + (num_groups, group_size)
        + original_shape[group_dim + 1 :]
    )

    # Zero-pad group_dim up to a multiple of group_size so it splits evenly.
    # Note: the padded zeros take part in the last group's statistics below.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = (
            original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
        )
        tensor = torch.cat(
            [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim,
        )
    data = tensor.view(new_shape)

    # Per-group statistics are reduced over the group_size axis, which
    # directly follows the num_groups axis in new_shape.
    reduce_dim = group_dim + 1

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        absmax = torch.max(data.abs(), dim=reduce_dim, keepdim=True)[0]
        # An all-zero group would give B / 0 = inf and then 0 * inf = NaN,
        # and casting NaN to an integer dtype is undefined. Any finite scale
        # maps such a group to 0, so divide by 1 there instead.
        absmax = torch.where(absmax == 0, torch.ones_like(absmax), absmax)
        scale = B / absmax
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2**num_bits - 1
        mn = torch.min(data, dim=reduce_dim, keepdim=True)[0]
        mx = torch.max(data, dim=reduce_dim, keepdim=True)[0]

        # A constant group (mx == mn) would divide by zero and produce
        # (data - mn) * inf = 0 * inf = NaN. Using a range of 1 keeps the
        # scale finite and maps every element of such a group to 0.
        value_range = mx - mn
        value_range = torch.where(
            value_range == 0, torch.ones_like(value_range), value_range
        )
        scale = B / value_range
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape
277
278
279def decompress(packed_data, config):

Callers 2

__init__ (Method) · 0.85
load_compress_model (Function) · 0.85

Calls 1

to (Method) · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…