MCPcopy
hub / github.com/Janspiry/Palette-Image-to-Image-Diffusion-Models / restoration

Method restoration

models/network.py:88–103  ·  view source on GitHub ↗
(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8)

Source from the content-addressed store, hash-verified

86
@torch.no_grad()
def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8):
    """Run the full reverse-diffusion sampling loop conditioned on ``y_cond``.

    Calls ``self.p_sample`` once per timestep, walking ``i`` from
    ``self.num_timesteps - 1`` down to ``0``. When ``mask`` is given, each
    step's output is blended as ``y_0 * (1 - mask) + mask * y_t``. With a
    binary mask, positions where ``mask`` is 1 come from the sampler and
    positions where it is 0 stay pinned to ``y_0``.

    Args:
        y_cond: Conditioning tensor passed to every ``p_sample`` call. Its
            first dimension sets the length of the per-step timestep tensor.
        y_t: Optional starting sample. If it is not supplied, the ``default``
            helper falls back to ``torch.randn_like(y_cond)``. (Presumably
            "not supplied" means ``None``; confirm against ``default``.)
        y_0: Tensor kept outside the mask. Required whenever ``mask`` is given.
        mask: Optional blending mask applied after every step.
        sample_num: Controls snapshot density. A snapshot is recorded whenever
            ``i % (self.num_timesteps // sample_num) == 0``. Must satisfy
            ``1 <= sample_num < self.num_timesteps``.

    Returns:
        A tuple ``(y_t, ret_arr)``. ``y_t`` is the final sample. ``ret_arr``
        is the starting sample followed by every recorded snapshot,
        concatenated along dim 0. The last snapshot is always the final
        sample, because ``i == 0`` is always recorded.

    Raises:
        ValueError: If ``sample_num`` is outside the range above, or if
            ``mask`` is given without ``y_0``.
    """
    b, *_ = y_cond.shape

    # Explicit checks instead of `assert`, which disappears under `python -O`.
    # The lower bound also guarantees sample_inter >= 1, so `i % sample_inter`
    # below can never divide by zero.
    if sample_num < 1:
        raise ValueError('sample_num must be at least 1')
    if not self.num_timesteps > sample_num:
        raise ValueError('num_timesteps must be greater than sample_num')
    # Fail before any sampling work rather than with a TypeError on `None * ...`.
    if mask is not None and y_0 is None:
        raise ValueError('y_0 is required when mask is given')

    sample_inter = self.num_timesteps // sample_num

    y_t = default(y_t, lambda: torch.randn_like(y_cond))
    # Collect snapshots and concatenate once at the end. Concatenating inside
    # the loop would re-copy the growing buffer on every snapshot.
    snapshots = [y_t]
    for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
        t = torch.full((b,), i, device=y_cond.device, dtype=torch.long)
        y_t = self.p_sample(y_t, t, y_cond=y_cond)
        if mask is not None:
            # Re-impose y_0 wherever mask is 0 after every denoising step.
            y_t = y_0 * (1. - mask) + mask * y_t
        if i % sample_inter == 0:
            snapshots.append(y_t)
    ret_arr = torch.cat(snapshots, dim=0)
    return y_t, ret_arr
104
105 def forward(self, y_0, y_cond=None, mask=None, noise=None):
106 # sampling from p(gammas)

Callers 2

val_step (method) · 0.80
test (method) · 0.80

Calls 2

p_sample (method) · 0.95
default (function) · 0.70

Tested by

no test coverage detected