Repository: github.com/OpenPPL/ppq

Class BanditDelegator

Source: ppq/quantization/algorithm/exprimental.py, lines 12–55

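Quantization delegator driven by a multi-armed bandit. Since ppq 0.6.2, a multi-armed bandit algorithm is used to train scale and offset (the code below only adjusts `scale`). Similar algorithms, such as UCB or Markov chain Monte Carlo estimation, may be added in the future. They were introduced because the gradients of scale and offset proved very unreliable, so simple reinforcement learning is used instead to directly estimate P(r | scale=s, context), the probability of obtaining a reward when scale s is chosen in the given context.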
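The idea of estimating a per-scale reward and keeping the best candidate can be written as an epsilon-greedy rule. This is a sketch under assumptions: it takes `EMARecorder` to keep an exponential moving average with decay β (its exact update rule is not shown on this page), and the symbols K (number of arms), m_k (the scale multipliers in `arms`), s_ref (the reference scale), ε (exploration rate) and r_t (reward observed at step t) are notation introduced here. The estimate \hat{r}_k plays the role of P(r | scale = s_ref · m_k).

$$
a_t =
\begin{cases}
k \sim \mathrm{Uniform}\{0,\dots,K-1\} & \text{with probability } \varepsilon,\\
\arg\max_k \hat{r}_k & \text{otherwise,}
\end{cases}
\qquad
\hat{r}_{a_t} \leftarrow \beta\,\hat{r}_{a_t} + (1-\beta)\,r_t,
\qquad
s^\ast = s_{\mathrm{ref}} \cdot m_{\arg\max_k \hat{r}_k}.
$$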


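```python
import random
from typing import List

import numpy as np
import torch

# TensorQuantizationConfig, TorchQuantizeDelegator, PPQLinearQuantFunction and
# EMARecorder come from ppq; the file's import lines are not shown here.


class BanditDelegator(TorchQuantizeDelegator):
    """Quantization with a multi-armed bandit.

    Multi-armed bandits are used since PPQ 0.6.2 for training quantization
    scale and offset. Similar algorithms, such as UCB or Markov chain Monte
    Carlo estimation, may be added in the future.

    They were introduced because the gradients of scale and offset proved very
    unreliable. Simple reinforcement learning is used instead to directly
    estimate P(r | scale=s, context), i.e. the probability of a reward when
    scale s is chosen in the given context.
    """
    def __init__(self, arms: List[float], config: TensorQuantizationConfig) -> None:
        if len(arms) < 2:
            raise ValueError('Can not initialize bandit with less than 2 arms.')
        self.e = 0.1                           # exploration rate (epsilon)
        self.arms = arms                       # candidate multipliers for the reference scale
        self.num_of_arms = len(arms)
        self.rewards = [EMARecorder() for _ in range(self.num_of_arms)]
        self.rewards[0].push(1)                # seed arm 0 so it wins the first greedy pick
        self.last_selected = 0
        self.reference = config.scale.clone()  # scale before bandit training
        self.config = config
        self.decay = 0.99                      # stored, but not used by this class

    def roll(self) -> int:
        # Epsilon-greedy: explore a random arm with probability self.e,
        # otherwise exploit the arm with the highest recorded reward.
        if random.random() < self.e:
            selected = random.randint(0, len(self.arms) - 1)
        else:
            selected = int(np.argmax([ema.pop() for ema in self.rewards]))
        self.last_selected = selected
        return selected

    def mark(self, rewards: float) -> None:
        # Record the reward observed for the most recently rolled arm.
        self.rewards[self.last_selected].push(rewards)

    def finalize(self) -> None:
        # Commit the scale of the arm with the highest recorded reward.
        best = int(np.argmax([ema.pop() for ema in self.rewards]))
        self.config.scale = self.reference * self.arms[best]

    def withdraw(self) -> None:
        # Discard bandit training and restore the reference scale.
        self.config.scale = self.reference

    def __call__(self, tensor: torch.Tensor,
                 config: TensorQuantizationConfig) -> torch.Tensor:
        # Quantize with a freshly rolled candidate scale.
        config.scale = self.reference * self.arms[self.roll()]
        return PPQLinearQuantFunction(tensor, config)
```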
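The class is meant to be driven through a roll, quantize, mark, then finalize or withdraw cycle. The toy below is a self-contained sketch of that cycle in plain Python, not PPQ code. `fake_quant` and `bandit_calibrate` are hypothetical names; `fake_quant` assumes symmetric per-tensor int8 quantization with zero offset as a stand-in for `PPQLinearQuantFunction`; and the reward (1 when a candidate scale gives lower mean squared error than the reference scale on a batch, else 0) is an assumption, since the real caller, `calib_block`, may define rewards differently.

```python
import random
from typing import List, Sequence

# Toy sketch, not PPQ code. `fake_quant` is a hypothetical stand-in for
# PPQLinearQuantFunction: symmetric per-tensor int8, offset fixed at 0.


def fake_quant(xs: Sequence[float], scale: float,
               qmin: int = -128, qmax: int = 127) -> List[float]:
    # Quantize (round, clamp to the integer grid), then dequantize.
    return [max(qmin, min(qmax, round(x / scale))) * scale for x in xs]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def bandit_calibrate(batches: List[List[float]], reference: float,
                     arms: List[float], epsilon: float = 0.1,
                     decay: float = 0.9, seed: int = 0) -> float:
    """Return reference * (multiplier of the arm with the best estimate)."""
    if len(arms) < 2:
        raise ValueError('Can not initialize bandit with less than 2 arms.')
    rng = random.Random(seed)
    values = [0.0] * len(arms)
    values[0] = 1.0  # like rewards[0].push(1): arm 0 starts as the favourite
    for xs in batches:
        # roll(): explore with probability epsilon, else exploit the best arm.
        if rng.random() < epsilon:
            arm = rng.randrange(len(arms))
        else:
            arm = max(range(len(arms)), key=lambda i: values[i])
        # __call__: quantize the batch with the candidate scale.
        err = mse(xs, fake_quant(xs, reference * arms[arm]))
        base = mse(xs, fake_quant(xs, reference))
        # mark(): assumed reward, 1 if the candidate beats the reference scale.
        reward = 1.0 if err < base else 0.0
        values[arm] = decay * values[arm] + (1 - decay) * reward
    # finalize(): keep the best arm (withdraw() would return `reference`).
    best = max(range(len(arms)), key=lambda i: values[i])
    return reference * arms[best]


data_rng = random.Random(42)
batches = [[data_rng.uniform(-1.0, 1.0) for _ in range(64)] for _ in range(1000)]
scale = bandit_calibrate(batches, reference=0.05, arms=[1.0, 0.5, 2.0])
```

On this data nothing is clipped, and the 0.05 grid is a subset of the 0.025 grid, so halving the scale never increases the rounding error while doubling it never decreases it; the bandit therefore settles on the 0.5 multiplier.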

Callers 1

calib_block (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected