Turn the given PDF into a quantized CDF that splits [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional to the PDF. Args: pdf (torch.Tensor): probability distribution, shape should be `[N]`. total_range_bits (int): see `ArithmeticCoder`, the typical range we expect during the coding process is `[0, 2 ** total_range_bits - 1]`.
(pdf: torch.Tensor, total_range_bits: int,
roundoff: float = 1e-8, min_range: int = 2,
check: bool = True)
| 16 | |
| 17 | |
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
                               roundoff: float = 1e-8, min_range: int = 2,
                               check: bool = True) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    `[0, 2 ** total_range_bits - 1]` into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): quantization step for the pdf. Values are floored to a
            multiple of it to remove differences coming from e.g. evaluating the
            Language Model on different architectures. A falsy value (e.g. 0)
            disables this step.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Returns:
        torch.Tensor: `long` tensor holding the cumulative sum of the per-symbol
        range widths (each width is at least `min_range`). When `check` is True,
        the last entry is verified to be at most `2 ** total_range_bits`.

    Raises:
        ValueError: if `min_range < 2`, or (when `check` is True) if some range
            ends up narrower than `min_range`.
        AssertionError: if `min_range * N` exceeds `2 ** total_range_bits`, or
            (when `check` is True) if the CDF overflows the total range.
    """
    # Validate before min_range is used. Otherwise a too-small min_range with
    # many symbols trips the alpha assertion below with a misleading
    # "reduce min_range" message instead of this error.
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        # Floor to a multiple of `roundoff` so tiny numerical differences in the
        # pdf do not change the quantized result.
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    # Fraction of the total range reserved so that every symbol gets min_range.
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= total_range, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
| 54 | |
| 55 | |
| 56 | class ArithmeticCoder: |
no outgoing calls
no test coverage detected
searching dependent graphs…