| 255 | |
| 256 | |
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    """Select the conditioning active at ``current_step`` for every sub-prompt and stack them.

    Args:
        c: object whose ``batch`` attribute is a sequence (one entry per batch
            item) of sequences of composable prompts. Each composable prompt
            exposes ``weight`` and ``schedules``; each schedule entry exposes
            ``end_at_step`` and ``cond`` (a tensor).
        current_step: sampling step used to pick a schedule entry.

    Returns:
        A ``(conds_list, stacked)`` tuple. ``conds_list`` has one list per batch
        item containing ``(index_into_stacked, weight)`` pairs. ``stacked`` is
        the ``torch.stack`` of the selected tensors, moved to the device and
        dtype of the first sub-prompt's first scheduled ``cond``. When the batch
        contains no sub-prompts at all, ``stacked`` is an empty tensor on the
        default device/dtype, because there is no reference tensor to copy
        them from.
    """
    tensors = []
    conds_list = []
    # Reference tensor for the output device/dtype. It is taken from the first
    # sub-prompt encountered rather than from c.batch[0][0] directly, so that an
    # empty batch (or an empty first entry) does not raise IndexError.
    param = None

    for composable_prompts in c.batch:
        conds_for_batch = []

        for composable_prompt in composable_prompts:
            schedules = composable_prompt.schedules
            if param is None:
                param = schedules[0].cond

            # First schedule entry that is still active at current_step. Falls
            # back to index 0 when current_step is past every end_at_step.
            target_index = next(
                (
                    index
                    for index, entry in enumerate(schedules)
                    if current_step <= entry.end_at_step
                ),
                0,
            )

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    if not tensors:
        return conds_list, torch.zeros([0])

    # If prompts have wildly different lengths above the limit we'll get tensors
    # of different shapes and won't be able to torch.stack them. Pad the shorter
    # ones by repeating their last row until all share the longest length.
    # NOTE(review): repeat([n, 1]) assumes each cond is 2-D — confirm with caller.
    token_count = max(x.shape[0] for x in tensors)
    for i, tensor in enumerate(tensors):
        missing = token_count - tensor.shape[0]
        if missing:
            last_vector_repeated = tensor[-1:].repeat([missing, 1])
            tensors[i] = torch.vstack([tensor, last_vector_repeated])

    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
| 282 | |
| 283 | |
| 284 | def parse_prompt_attention(text): |