Defines heuristics for logging the different conditioning inputs. These can be lists of strings (text-to-image prompts), tensors, integers (class labels), etc.
(self, batch: Dict, n: int)
| 247 | |
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
    """
    Defines heuristics to log different conditionings.
    These can be lists of strings (text-to-image), tensors, ints, ...

    Each selected embedder's conditioning value is converted to text and
    rendered with ``log_txt_as_img`` on a canvas the size of the input images.

    Args:
        batch: Batch mapping. ``batch[self.input_key]`` supplies the canvas
            size (its last two dimensions are read as height and width);
            ``batch[embedder.input_key]`` supplies each conditioning value.
        n: Number of leading batch entries to log.

    Returns:
        Dict mapping each logged embedder's ``input_key`` to the output of
        ``log_txt_as_img``. Empty when ``self.no_cond_log`` is truthy.

    Raises:
        NotImplementedError: If a conditioning value is a tensor that is not
            1-D or 2-D, a list whose first element is not a ``str``, or any
            other type.
    """
    log = dict()
    # Conditioning logging disabled entirely: nothing to render.
    if self.no_cond_log:
        return log

    # Read the last two dims as (h, w). For 5-D inputs this equals the old
    # ``shape[3:]``; it also supports plain 4-D (b, c, h, w) image batches.
    image_h, image_w = batch[self.input_key].shape[-2:]

    for embedder in self.conditioner.embedders:
        key = embedder.input_key
        # ``log_keys is None`` means "log every embedder".
        if self.log_keys is not None and key not in self.log_keys:
            continue

        x = batch[key][:n]
        if isinstance(x, torch.Tensor):
            if x.dim() == 1:
                # class-conditional, convert integer to string.
                # One ``tolist()`` copy instead of a ``.item()`` call per element.
                x = [str(value) for value in x.tolist()]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
            elif x.dim() == 2:
                # size and crop cond and the like: each row becomes "AxBx...".
                x = ["x".join(str(value) for value in row) for row in x.tolist()]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            else:
                raise NotImplementedError(
                    f"Cannot log conditioning '{key}': tensors with "
                    f"{x.dim()} dimensions are not supported"
                )
        elif isinstance(x, (list, ListConfig)):
            if isinstance(x[0], str):
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            else:
                raise NotImplementedError(
                    f"Cannot log conditioning '{key}': list elements of type "
                    f"{type(x[0]).__name__} are not supported"
                )
        else:
            raise NotImplementedError(
                f"Cannot log conditioning '{key}' of type {type(x).__name__}"
            )
        log[key] = xc
    return log
| 280 | |
| 281 | @torch.no_grad() |
| 282 | def log_video( |
no test coverage detected