(self)
| 153 | |
@torch.no_grad()
def run(self) -> None:
    """Run the pipeline on every input from ``self.load_lq()`` and save the results.

    For each input yielded by ``self.load_lq()``:

    1. ``self.captioner`` is applied to the input as loaded, inside a
       ``VRAMPeakMonitor("applying captioner")`` context. Its caption is
       joined with ``self.args.pos_prompt`` using ", ", skipping empty parts,
       to form the positive prompt. The negative prompt is
       ``self.args.neg_prompt``.
    2. The input is then passed through ``self.after_load_lq``.
    3. ``self.args.n_samples`` outputs are generated with
       ``self.pipeline.run``, in chunks of at most ``self.args.batch_size``.
       Each call runs under ``torch.autocast`` on ``self.args.device``, using
       the dtype selected by ``self.args.precision``.
    4. All collected outputs and both prompts are passed to ``self.save``.

    Raises:
        KeyError: if ``self.args.precision`` is not ``"fp32"``, ``"fp16"``
            or ``"bf16"``.
    """
    self.setup()
    precision_to_dtype = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    autocast_dtype = precision_to_dtype[self.args.precision]

    for lq in self.load_lq():
        # The caption is computed on the input as loaded, before
        # after_load_lq transforms it.
        with VRAMPeakMonitor("applying captioner"):
            caption = self.captioner(lq)
        prompt_parts = [part for part in (caption, self.args.pos_prompt) if part]
        pos_prompt = ", ".join(prompt_parts)
        neg_prompt = self.args.neg_prompt
        lq = self.after_load_lq(lq)

        total = self.args.n_samples
        chunk = self.args.batch_size
        # Ceiling division, so a trailing partial chunk is still processed.
        n_chunks = (total + chunk - 1) // chunk
        samples = []
        for chunk_idx in range(n_chunks):
            offset = chunk_idx * chunk
            n_inputs = min(offset + chunk, total) - offset
            with torch.autocast(self.args.device, autocast_dtype):
                # Stack n_inputs copies of lq along a new leading axis.
                # NOTE(review): presumably lq is a single (H, W, C) image;
                # confirm against load_lq/after_load_lq.
                batch_lq = np.tile(lq[None], (n_inputs, 1, 1, 1))
                batch_samples = self.pipeline.run(
                    batch_lq,
                    self.args.steps,
                    self.args.strength,
                    self.args.cleaner_tiled,
                    self.args.cleaner_tile_size,
                    self.args.cleaner_tile_stride,
                    self.args.vae_encoder_tiled,
                    self.args.vae_encoder_tile_size,
                    self.args.vae_decoder_tiled,
                    self.args.vae_decoder_tile_size,
                    self.args.cldm_tiled,
                    self.args.cldm_tile_size,
                    self.args.cldm_tile_stride,
                    pos_prompt,
                    neg_prompt,
                    self.args.cfg_scale,
                    self.args.start_point_type,
                    self.args.sampler,
                    self.args.noise_aug,
                    self.args.rescale_cfg,
                    self.args.s_churn,
                    self.args.s_tmin,
                    self.args.s_tmax,
                    self.args.s_noise,
                    self.args.eta,
                    self.args.order,
                )
            # extend() iterates batch_samples directly; no list() copy is needed.
            samples.extend(batch_samples)
        self.save(samples, pos_prompt, neg_prompt)
| 211 | |
| 212 | def save(self, samples: List[np.ndarray], pos_prompt: str, neg_prompt: str) -> None: |
nothing calls this directly
no test coverage detected