MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / sample

Method sample

diffbir/sampler/spaced_sampler.py:187–245  ·  view source on GitHub ↗
(
        self,
        model: ControlLDM,
        device: str,
        steps: int,
        x_size: Tuple[int],
        cond: Dict[str, torch.Tensor],
        uncond: Dict[str, torch.Tensor],
        cfg_scale: float,
        tiled: bool = False,
        tile_size: int = -1,
        tile_stride: int = -1,
        x_T: torch.Tensor | None = None,
        progress: bool = True,
    )

Source from the content-addressed store, hash-verified

185
@torch.no_grad()
def sample(
    self,
    model: ControlLDM,
    device: str,
    steps: int,
    x_size: Tuple[int],
    cond: Dict[str, torch.Tensor],
    uncond: Dict[str, torch.Tensor],
    cfg_scale: float,
    tiled: bool = False,
    tile_size: int = -1,
    tile_stride: int = -1,
    x_T: torch.Tensor | None = None,
    progress: bool = True,
) -> torch.Tensor:
    """Run the full sampling loop and return the final sample.

    The schedule is rebuilt for ``steps`` via ``self.make_schedule`` and the
    sampler is moved to ``device``. Starting from ``x_T`` (or fresh Gaussian
    noise of shape ``x_size``), ``self.p_sample`` is applied once per entry
    of ``self.timesteps``, iterated in reverse order.

    Args:
        model: Model handed to ``self.p_sample``. When ``tiled`` is true its
            ``forward`` attribute is temporarily replaced by a tiled wrapper
            and restored before this method returns or raises.
        device: Device on which the noise and timestep tensors are created.
        steps: Value passed to ``self.make_schedule``.
        x_size: Shape of the sample; ``x_size[0]`` is used as the batch size.
        cond: Conditioning dict passed to ``self.p_sample``.
        uncond: Unconditional counterpart passed to ``self.p_sample``.
        cfg_scale: Base guidance scale; the value actually used at each step
            comes from ``self.get_cfg_scale(cfg_scale, step)``.
        tiled: If true, wrap ``model.forward`` with ``make_tiled_fn``.
        tile_size: Forwarded to ``make_tiled_fn``; only used when ``tiled``.
        tile_stride: Forwarded to ``make_tiled_fn``; only used when ``tiled``.
        x_T: Optional starting tensor. If ``None``, float32 standard-normal
            noise of shape ``x_size`` is drawn on ``device``.
        progress: If false, the tqdm progress bar is disabled.

    Returns:
        The tensor produced by the last ``self.p_sample`` call (or the
        starting tensor unchanged if ``self.timesteps`` is empty).
    """
    self.make_schedule(steps)
    self.to(device)

    original_forward = None
    if tiled:
        original_forward = model.forward

        # Parameter names are kept exactly as in the original lambda in case
        # the tiled wrapper passes the tile bounds by keyword.
        def tiled_forward(x_tile, t, cond, hi, hi_end, wi, wi_end):
            # Only "c_txt" and "c_img" are forwarded; "c_img" is cropped to
            # the current tile window along its last two dimensions.
            return original_forward(
                x_tile,
                t,
                {
                    "c_txt": cond["c_txt"],
                    "c_img": cond["c_img"][..., hi:hi_end, wi:wi_end],
                },
            )

        model.forward = make_tiled_fn(tiled_forward, tile_size, tile_stride)

    try:
        if x_T is None:
            x_T = torch.randn(x_size, device=device, dtype=torch.float32)

        x = x_T
        timesteps = np.flip(self.timesteps)
        total_steps = len(self.timesteps)
        iterator = tqdm(timesteps, total=total_steps, disable=not progress)
        bs = x_size[0]

        for i, step in enumerate(iterator):
            # model_t carries the schedule value taken from self.timesteps;
            # t is the position in the spaced schedule, counting down from
            # total_steps - 1 to 0.
            model_t = torch.full((bs,), step, device=device, dtype=torch.long)
            t = torch.full(
                (bs,), total_steps - i - 1, device=device, dtype=torch.long
            )
            cur_cfg_scale = self.get_cfg_scale(cfg_scale, step)
            x = self.p_sample(
                model,
                x,
                model_t,
                t,
                cond,
                uncond,
                cur_cfg_scale,
            )
    finally:
        # Undo the monkey-patch even if sampling raised, so the caller's
        # model is never left with the tiled wrapper installed.
        if tiled:
            model.forward = original_forward

    return x

Callers 1

mainFunction · 0.95

Calls 5

make_scheduleMethod · 0.95
p_sampleMethod · 0.95
make_tiled_fnFunction · 0.85
get_cfg_scaleMethod · 0.80
forwardFunction · 0.50

Tested by

no test coverage detected