MCPcopy
hub / github.com/vladmandic/sdnext / reconstruct_multicond_batch

Function reconstruct_multicond_batch

modules/prompt_parser.py:257–281  ·  view source on GitHub ↗
(c: MulticondLearnedConditioning, current_step)

Source from the content-addressed store, hash-verified

255
256
def reconstruct_multicond_batch(c: "MulticondLearnedConditioning", current_step):
    """Pick the active scheduled cond of every composable prompt and stack them.

    For each composable prompt, the first schedule entry whose ``end_at_step``
    is ``>= current_step`` is selected. If no entry qualifies, entry 0 is used.

    Args:
        c: multi-conditioning object whose ``c.batch`` is a list (one item per
            batch element) of lists of composable prompts. Each composable prompt
            exposes ``weight`` and ``schedules``; each schedule entry exposes
            ``end_at_step`` and ``cond``.
        current_step: sampling step used to choose the active schedule entry.

    Returns:
        ``(conds_list, stacked)``:
        - ``conds_list[b]`` is a list of ``(index_into_stacked, weight)`` pairs
          for batch element ``b``.
        - ``stacked`` is ``torch.stack`` of the selected conds. Each cond is
          first padded along dim 0 to the longest one by repeating its last
          entry along dim 0. The result is cast to the device/dtype of the first
          prompt's first schedule entry.
        - When the batch contains no prompts at all, ``stacked`` is an empty
          tensor.

    NOTE(review): conds are assumed to be tensors shaped (tokens, ...);
    confirm against the schedule builder.
    """
    tensors = []
    conds_list = []
    for composable_prompts in c.batch:
        conds_for_batch = []
        for composable_prompt in composable_prompts:
            schedules = composable_prompt.schedules
            target_index = next(
                (i for i, entry in enumerate(schedules) if current_step <= entry.end_at_step),
                0,  # step is past every schedule end: fall back to the first entry
            )
            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(schedules[target_index].cond)
        conds_list.append(conds_for_batch)

    if not tensors:
        # No prompts at all, so there is no reference tensor for device/dtype.
        return conds_list, torch.zeros([0])

    # The device/dtype reference is the first prompt's first schedule entry.
    # This is c.batch[0][0] whenever that exists.
    first_prompt = next(p for prompts in c.batch for p in prompts)
    param = first_prompt.schedules[0].cond

    # Prompts of very different lengths (above the token limit) produce conds
    # with different dim-0 sizes, which torch.stack rejects. Pad the shorter
    # ones by repeating their last entry along dim 0.
    token_count = max(t.shape[0] for t in tensors)
    padded = []
    for t in tensors:
        missing = token_count - t.shape[0]
        if missing > 0:
            # Repeat only along dim 0; trailing dims keep their size, so this
            # works for conds of any rank.
            filler = t[-1:].repeat(missing, *([1] * (t.dim() - 1)))
            t = torch.cat([t, filler], dim=0)
        padded.append(t)
    return conds_list, torch.stack(padded).to(device=param.device, dtype=param.dtype)
282
283
284def parse_prompt_attention(text):

Callers

nothing calls this directly

Calls 1

toMethod · 0.45

Tested by

no test coverage detected