(blocks, *args)
| 53 | return (x[:, 0] | x[:, 1] << 8).unsqueeze(1) |
| 54 | |
def split_block_dims(blocks, *args):
    """Split ``blocks`` along dimension 1 into the given widths plus a remainder.

    Args:
        blocks: Tensor with at least two dimensions; dimension 1 is the
            one being partitioned.
        *args: Widths of the leading chunks, in order.

    Returns:
        The result of ``torch.split``: one chunk per width in ``args``,
        followed by a final chunk holding whatever is left of dimension 1
        (zero-width when ``args`` already covers the whole dimension).
    """
    total_width = blocks.shape[1]
    # The trailing chunk absorbs everything the explicit widths do not claim.
    leftover = total_width - sum(args)
    section_sizes = [*args, leftover]
    return torch.split(blocks, section_sizes, dim=1)
| 59 | |
| 60 | # Full weights # |
| 61 | def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): |
no outgoing calls
no test coverage detected