MCPcopy Index your code
hub / github.com/openai/guided-diffusion / SpacedDiffusion

Class SpacedDiffusion

guided_diffusion/respace.py:63–113  ·  view source on GitHub ↗

A diffusion process which can skip steps in a base diffusion process. :param use_timesteps: a collection (sequence or set) of timesteps from the original diffusion process to retain. :param kwargs: the kwargs to create the base diffusion process.

Source from the content-addressed store, hash-verified

61
62
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process that keeps only a subset of a base process's steps.

    A base ``GaussianDiffusion`` is built from the supplied kwargs, and a new
    beta schedule is derived from its ``alphas_cumprod``: for every retained
    timestep the new beta is ``1 - alphas_cumprod[t] / alphas_cumprod[prev]``,
    where ``prev`` is the previously retained timestep (the ratio starts from
    1.0 for the first retained step).

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process. Must
                   contain ``"betas"``; that entry is replaced by the
                   re-spaced schedule before the parent constructor runs.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.original_num_steps = len(kwargs["betas"])
        # timestep_map[k] is the original index of the k-th retained step.
        self.timestep_map = []

        # Throwaway instance: only its cumulative alpha products are read.
        reference = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa

        respaced_betas = []
        prev_cumprod = 1.0
        for step, cumprod in enumerate(reference.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            respaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)

        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        """Delegate to the parent implementation with ``model`` wrapped."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        """Delegate to the parent implementation with ``model`` wrapped."""
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        """Delegate to the parent implementation with ``cond_fn`` wrapped."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        """Delegate to the parent implementation with ``cond_fn`` wrapped."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """
        Return ``model`` as a ``_WrappedModel``.

        An object that is already a ``_WrappedModel`` is returned unchanged,
        so repeated wrapping never nests.
        """
        if not isinstance(model, _WrappedModel):
            # NOTE(review): rescale_timesteps is not assigned in this class;
            # presumably the parent constructor sets it - confirm.
            model = _WrappedModel(
                model,
                self.timestep_map,
                self.rescale_timesteps,
                self.original_num_steps,
            )
        return model

    def _scale_timesteps(self, t):
        """Return ``t`` unchanged; scaling is done by the wrapped model."""
        return t
114
115
116class _WrappedModel:

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected