(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
**kwargs,
)
| 211 | |
@torch.no_grad()
def sample(
    self,
    cond: Dict,
    uc: Union[Dict, None] = None,
    batch_size: int = 16,
    shape: Union[None, Tuple, List] = None,
    prefix=None,
    concat_images=None,
    **kwargs,
):
    """Draw initial noise and run ``self.sampler`` on it.

    Args:
        cond: Conditioning dict, passed through to ``self.sampler``.
        uc: Optional second conditioning dict, passed to ``self.sampler``
            as ``uc``.
        batch_size: Size of the leading dimension of the noise tensor.
        shape: Remaining dimensions of the noise tensor (everything after
            the batch dimension). Although the default is ``None`` for
            signature compatibility, a value is required.
        prefix: Optional tensor. When given, it replaces the first
            ``prefix.shape[1]`` entries of the noise along dim 1.
        concat_images: Forwarded to ``self.denoiser`` as the
            ``concat_images`` keyword on every denoiser call.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The output of ``self.sampler`` converted to ``self.dtype``.

    Raises:
        TypeError: If ``shape`` is ``None``.
    """
    # Fail early with an actionable message instead of the opaque
    # "argument after * must be an iterable" error from unpacking None.
    if shape is None:
        raise TypeError(
            "sample() requires `shape` (the noise dimensions after the batch axis); got None"
        )

    # Noise is created first and moved to the target dtype/device afterwards;
    # this order is kept so the drawn values match the previous behaviour.
    randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
    if hasattr(self, "seeded_noise"):
        randn = self.seeded_noise(randn)

    if prefix is not None:
        # Keep the caller-provided prefix and only use fresh noise after it.
        randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)

    # Broadcast the noise from the first rank of each model-parallel group
    # so every rank in the group works on the same tensor.
    mp_size = mpu.get_model_parallel_world_size()
    if mp_size > 1:
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())

    def denoiser(input, sigma, c, **additional_model_inputs):
        # Bind the model and concat_images so the sampler only has to supply
        # the per-step arguments.
        return self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
        )

    samples = self.sampler(denoiser, randn, cond, uc=uc, scale=None, scale_emb=None)
    return samples.to(self.dtype)
| 247 | |
| 248 | @torch.no_grad() |
| 249 | def log_conditionings(self, batch: Dict, n: int) -> Dict: |
no outgoing calls
no test coverage detected