

class BanditDelegator(TorchQuantizeDelegator):
    """Quantization delegator driven by a multi-armed bandit.

    Since PPQ 0.6.2, multi-armed bandit algorithms are used to train the
    quantization scale and offset. Similar algorithms, such as UCB or
    Markov chain Monte Carlo estimation, may be introduced in the future.

    These algorithms were introduced because the gradients of scale and
    offset proved to be very unreliable. Simple reinforcement learning is
    used instead to estimate P(r | scale=s, context) directly, that is, the
    probability of obtaining a reward when s is selected as the current
    scale under the given context.
    """
    def __init__(self, arms: List[float], config: TensorQuantizationConfig) -> None:
        if len(arms) < 2:
            raise ValueError('Cannot initialize bandit with fewer than 2 arms.')
        self.e = 0.1                           # exploration probability (epsilon)
        self.arms = arms                       # multipliers applied to the reference scale
        self.num_of_arms = len(arms)
        self.rewards = [EMARecorder() for _ in range(self.num_of_arms)]
        self.rewards[0].push(1)                # seed arm 0 with an initial reward of 1
        self.last_selected = 0
        self.reference = config.scale.clone()  # original scale, restored by withdraw()
        self.config = config
        self.decay = 0.99

    def roll(self) -> int:
        """Select an arm with an epsilon-greedy policy and remember the choice."""
        if random.random() < self.e:
            # Explore: pick a uniformly random arm with probability self.e.
            selected = random.randint(0, self.num_of_arms - 1)
        else:
            # Exploit: pick the arm with the highest averaged reward.
            selected = int(np.argmax([ema.pop() for ema in self.rewards]))
        self.last_selected = selected
        return selected

    def mark(self, rewards: float) -> None:
        """Record the reward observed for the most recently selected arm."""
        self.rewards[self.last_selected].push(rewards)

    def finalize(self) -> None:
        """Fix the scale to the arm with the highest averaged reward."""
        best = int(np.argmax([ema.pop() for ema in self.rewards]))
        self.config.scale = self.reference * self.arms[best]

    def withdraw(self) -> None:
        """Restore the original scale."""
        self.config.scale = self.reference

    def __call__(self, tensor: torch.Tensor,
                 config: TensorQuantizationConfig) -> torch.Tensor:
        # Each call rolls a new arm and quantizes with the resulting scale.
        config.scale = self.reference * self.arms[self.roll()]
        return PPQLinearQuantFunction(tensor, config)
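
The selection policy above can be tried in isolation, without PPQ or torch. The following is a minimal sketch of an epsilon-greedy bandit over scale multipliers. `SimpleEMA`, `EpsilonGreedyBandit` and `toy_reward` are hypothetical names introduced for illustration only. The bias-corrected moving average is an assumption about how a reward recorder might behave, not a copy of PPQ's `EMARecorder`, and the reward function is a toy stand-in for whatever quality signal a real training loop would pass to `mark`.

```python
import random
from typing import List


class SimpleEMA:
    """Bias-corrected exponential moving average (hypothetical helper)."""

    def __init__(self, beta: float = 0.98) -> None:
        self.beta = beta
        self.steps = 0
        self.value = 0.0

    def push(self, x: float) -> None:
        self.value = self.beta * self.value + (1.0 - self.beta) * x
        self.steps += 1

    def pop(self) -> float:
        if self.steps == 0:
            return 0.0  # an arm that was never tried reports zero reward
        return self.value / (1.0 - self.beta ** self.steps)


class EpsilonGreedyBandit:
    """Epsilon-greedy selection over a list of scale multipliers (sketch)."""

    def __init__(self, arms: List[float], epsilon: float = 0.1) -> None:
        if len(arms) < 2:
            raise ValueError('Need at least 2 arms.')
        self.arms = arms
        self.epsilon = epsilon
        self.rewards = [SimpleEMA() for _ in arms]

    def best(self) -> int:
        # Index of the arm with the highest averaged reward (first one on ties).
        estimates = [ema.pop() for ema in self.rewards]
        return max(range(len(estimates)), key=estimates.__getitem__)

    def select(self) -> int:
        # Explore with probability epsilon, otherwise exploit the best arm.
        if random.random() < self.epsilon:
            return random.randrange(len(self.arms))
        return self.best()

    def update(self, index: int, reward: float) -> None:
        self.rewards[index].push(reward)


def toy_reward(multiplier: float) -> float:
    """Toy reward that peaks when the multiplier equals 1.0 (illustrative only)."""
    return 1.0 - abs(multiplier - 1.0)


if __name__ == '__main__':
    random.seed(0)
    bandit = EpsilonGreedyBandit(arms=[0.9, 1.0, 1.1])
    for _ in range(500):
        i = bandit.select()
        bandit.update(i, toy_reward(bandit.arms[i]))
    print('selected multiplier:', bandit.arms[bandit.best()])
```

Unlike `BanditDelegator`, the sketch passes the arm index to `update` explicitly instead of remembering the last selection, which keeps the example easy to test. Setting `epsilon=0.0` turns selection into pure exploitation.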