MCPcopy
hub / github.com/openai/guided-diffusion / LossSecondMomentResampler

Class LossSecondMomentResampler

guided_diffusion/resample.py:124–154  ·  view source on GitHub ↗

Sourced from the content-addressed store; hash-verified

122
123
class LossSecondMomentResampler(LossAwareSampler):
    """
    Sample timesteps with weights proportional to the root-mean-square of
    their recent losses.

    A fixed-size history of the most recent ``history_per_term`` losses is
    kept for every timestep. Until every timestep has a full history, the
    sampler reports all-ones (uniform) weights. After that, each timestep's
    weight is the RMS of its loss history, normalized to sum to one and
    blended with a uniform distribution carrying ``uniform_prob`` of the
    total mass.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        """
        :param diffusion: object exposing ``num_timesteps``, the number of
                          timesteps to track loss history for.
        :param history_per_term: how many of the most recent losses to keep
                                 per timestep.
        :param uniform_prob: probability mass spread uniformly over all
                             timesteps once the sampler is warmed up.
        """
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        # Row t holds up to `history_per_term` recent losses for timestep t.
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # Number of losses recorded so far per timestep (capped at
        # `history_per_term`). `np.int` was only an alias of the builtin
        # `int` and was removed in NumPy 1.24, so use `int` directly.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        """
        Return a float64 array of per-timestep weights, one per timestep.

        Before warm-up, or if every recorded loss is zero, all weights are
        one. Otherwise the weights are the RMS of each timestep's loss
        history, normalized to sum to 1 and mixed with a uniform
        distribution of mass ``uniform_prob``, so the result still sums
        to 1.
        """
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        total = np.sum(weights)
        if total == 0:
            # Every recorded loss is zero. Normalizing would divide by zero
            # and produce NaNs, so return the same uniform weights as during
            # warm-up.
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights /= total
        # Reserve `uniform_prob` of the mass for uniform sampling so that no
        # timestep's probability drops to zero.
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        """
        Append one loss per timestep to that timestep's rolling history.

        :param ts: sequence of timestep indices into the history arrays.
        :param losses: losses aligned element-wise with ``ts``. Pairs are
                       consumed with ``zip``, so surplus items in the longer
                       sequence are ignored.
        """
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # History is full: shift out the oldest loss term and write
                # the newest one into the last slot.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                # History not yet full: fill the next free slot.
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        """Return True once every timestep has a full loss history."""
        return (self._loss_counts == self.history_per_term).all()

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected