(
self,
model: ControlLDM,
device: str,
steps: int,
x_size: Tuple[int],
cond: Dict[str, torch.Tensor],
uncond: Dict[str, torch.Tensor],
cfg_scale: float,
tiled: bool = False,
tile_size: int = -1,
tile_stride: int = -1,
x_T: torch.Tensor | None = None,
progress: bool = True,
)
@torch.no_grad()
def sample(
    self,
    model: ControlLDM,
    device: str,
    steps: int,
    x_size: Tuple[int],
    cond: Dict[str, torch.Tensor],
    uncond: Dict[str, torch.Tensor],
    cfg_scale: float,
    tiled: bool = False,
    tile_size: int = -1,
    tile_stride: int = -1,
    x_T: torch.Tensor | None = None,
    progress: bool = True,
) -> torch.Tensor:
    """Run the reverse sampling loop and return the final sample.

    Builds the schedule for ``steps`` via ``self.make_schedule``, moves the
    sampler to ``device``, then walks ``self.timesteps`` in reverse order,
    calling ``self.p_sample`` once per step.

    Args:
        model: Model handed to ``self.p_sample`` at every step. When
            ``tiled`` is true its ``forward`` attribute is temporarily
            replaced by a tiled wrapper built with ``make_tiled_fn``; the
            original is restored before returning, including when an
            exception propagates.
        device: Device for the sampler (``self.to``), the initial noise and
            the per-step timestep tensors.
        steps: Number of sampling steps passed to ``self.make_schedule``.
        x_size: Shape of the initial noise; ``x_size[0]`` is used as the
            batch size of the timestep tensors.
        cond: Conditional inputs forwarded to ``p_sample``. In tiled mode the
            wrapped forward reads ``cond["c_txt"]`` unchanged and slices
            ``cond["c_img"]`` over its last two axes per tile.
        uncond: Unconditional inputs forwarded to ``p_sample``.
        cfg_scale: Base guidance scale; the per-step value comes from
            ``self.get_cfg_scale(cfg_scale, step)``.
        tiled: Evaluate the model tile-by-tile.
        tile_size: Tile size passed to ``make_tiled_fn`` (used when tiled).
        tile_stride: Tile stride passed to ``make_tiled_fn`` (used when tiled).
        x_T: Optional starting tensor. When ``None``, float32 standard-normal
            noise of shape ``x_size`` is drawn on ``device``.
        progress: Show a tqdm progress bar over the timesteps.

    Returns:
        ``x`` after the last ``p_sample`` step (``x_T`` itself if the
        schedule is empty).
    """
    self.make_schedule(steps)
    self.to(device)

    if tiled:
        original_forward = model.forward
        model.forward = make_tiled_fn(
            lambda x_tile, t, cond, hi, hi_end, wi, wi_end: (
                original_forward(
                    x_tile,
                    t,
                    {
                        "c_txt": cond["c_txt"],
                        "c_img": cond["c_img"][..., hi:hi_end, wi:wi_end],
                    },
                )
            ),
            tile_size,
            tile_stride,
        )

    # The patched forward must never leak past this call: without the
    # finally, an exception mid-loop would leave `model` permanently tiled
    # and a later tiled call would wrap the already-tiled forward again.
    try:
        if x_T is None:
            x_T = torch.randn(x_size, device=device, dtype=torch.float32)

        x = x_T
        timesteps = np.flip(self.timesteps)
        total_steps = len(self.timesteps)
        iterator = tqdm(timesteps, total=total_steps, disable=not progress)
        bs = x_size[0]

        for i, step in enumerate(iterator):
            # `model_t` carries the schedule's timestep value for the model;
            # `t` is the position in the (reversed) schedule, counting down.
            model_t = torch.full((bs,), step, device=device, dtype=torch.long)
            t = torch.full((bs,), total_steps - i - 1, device=device, dtype=torch.long)
            cur_cfg_scale = self.get_cfg_scale(cfg_scale, step)
            x = self.p_sample(
                model,
                x,
                model_t,
                t,
                cond,
                uncond,
                cur_cfg_scale,
            )
    finally:
        if tiled:
            model.forward = original_forward

    return x
No test coverage detected.