(self, conditional, batch_size)
| 200 | return [promt.format(w) for promt, w in zip(prompts, words)] |
| 201 | |
def get_cond_vec(self, conditional, batch_size):
    """Build the conditioning tensor for a batch from ``conditional``.

    ``conditional`` may take one of these forms:

    * ``str`` -- encoded once with ``self.compute_conditional`` and tiled
      via ``cond.repeat(batch_size, 1)`` so every batch item shares it.
    * ``list``/``tuple`` of ``str`` -- one prompt per batch item (length must
      equal ``batch_size``), encoded in a single ``self.compute_conditional``
      call.
    * 2-D ``torch.Tensor`` -- taken as precomputed conditioning vectors and
      returned unchanged.
    * any other ``torch.Tensor`` -- taken as a conditioning image, passed
      through ``self.visual_forward`` with gradients disabled; the first
      element of the returned 3-tuple is used.

    Args:
        conditional: the conditioning input (see above).
        batch_size: number of items in the batch being conditioned.

    Returns:
        The conditioning tensor produced by the matching branch.

    Raises:
        ValueError: if ``conditional`` matches none of the accepted forms
            (including ``None`` and empty sequences), or if a prompt
            list/tuple does not have exactly ``batch_size`` entries.
    """
    # single prompt: encode once, then tile across the batch
    if isinstance(conditional, str):
        cond = self.compute_conditional(conditional)
        cond = cond.repeat(batch_size, 1)

    # one prompt per batch item; the non-empty check avoids an IndexError-style
    # crash and routes empty sequences to the ValueError below instead
    elif (isinstance(conditional, (list, tuple)) and len(conditional) > 0
          and all(isinstance(c, str) for c in conditional)):
        # explicit check instead of `assert`, which is stripped under `python -O`
        if len(conditional) != batch_size:
            raise ValueError(
                f'expected {batch_size} conditional prompts, got {len(conditional)}')
        cond = self.compute_conditional(conditional)

    # isinstance (not type ==) so Tensor subclasses such as nn.Parameter are accepted
    elif isinstance(conditional, torch.Tensor) and conditional.ndim == 2:
        # already a matrix of conditioning vectors: use directly
        cond = conditional

    elif isinstance(conditional, torch.Tensor):
        # presumably an image batch (non-2-D tensor) -- TODO confirm expected layout;
        # encoded without building an autograd graph
        with torch.no_grad():
            cond, _, _ = self.visual_forward(conditional)
    else:
        raise ValueError('invalid conditional')
    return cond
| 224 | |
| 225 | def compute_conditional(self, conditional): |
| 226 | import clip |
no test coverage detected