| 122 | |
| 123 | |
class LossSecondMomentResampler(LossAwareSampler):
    """Weight timesteps by the root-mean-square of their recent losses.

    For every timestep the sampler stores the last ``history_per_term`` loss
    values seen for it. Until every timestep has a full history,
    :meth:`weights` returns an all-ones (unnormalized) array. Afterwards the
    weight of timestep ``t`` is the RMS of its stored losses, normalized to sum
    to one and then mixed with a uniform distribution carrying
    ``uniform_prob`` of the mass, so no timestep's weight is exactly zero
    while ``uniform_prob > 0``.

    :param diffusion: object exposing ``num_timesteps``, the number of
        timesteps to track.
    :param history_per_term: number of most recent losses kept per timestep.
    :param uniform_prob: probability mass spread uniformly over all timesteps
        once warmed up.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        # Row t holds the stored losses for timestep t, oldest first.
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # Number of valid entries in each row of _loss_history; saturates at
        # history_per_term. `np.int` was an alias of the builtin `int` and was
        # removed in NumPy 1.24, so the builtin is used directly (same dtype).
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        """Return a float64 array with one sampling weight per timestep.

        Before warm-up this is all ones (not normalized). After warm-up it
        sums to one: RMS-of-history weights, normalized, scaled by
        ``1 - uniform_prob`` and shifted by ``uniform_prob / num_timesteps``.
        """
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        total = np.sum(weights)
        if total == 0:
            # Every recorded loss is zero; dividing by the sum would yield
            # NaNs, so treat all timesteps as equally important instead.
            weights = np.full_like(weights, 1.0 / len(weights))
        else:
            weights /= total
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        """Append one observed loss per timestep to the rolling history.

        :param ts: iterable of integer timestep indices.
        :param losses: iterable of loss values, aligned element-wise with ``ts``.
        """
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # History is full: shift out the oldest loss, append the newest.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        """Return True once every timestep has a full loss history."""
        return (self._loss_counts == self.history_per_term).all()
no outgoing calls
no test coverage detected