hub / github.com/Lightricks/ComfyUI-LTXVideo / _pad_inputs_for_attention_alignment

Function _pad_inputs_for_attention_alignment

gemma_encoder.py:519–547

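Pads the sequence length up to the next multiple of alignment (default 8) for Flash Attention compatibility. Per the docstring, Flash Attention within SDPA requires aligned sequence lengths. The docstring says "8 bytes", but the code aligns the token count (dimension 1 of input_ids) to a multiple of alignment. When padding is needed, input_ids is extended with pad_token_id, and attention_mask and token_type_ids (if present and not None) are extended with 0, to prevent 'p.attn_bias_ptr is not correctly aligned' errors.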

(model_inputs, pad_token_id, alignment: int = 8)
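
The rounding step described above can be illustrated in isolation. This is a minimal sketch of the same integer arithmetic the function uses; the helper name round_up_to_multiple is hypothetical and does not appear in gemma_encoder.py.

```python
def round_up_to_multiple(seq_len: int, alignment: int = 8) -> int:
    """Return the smallest multiple of `alignment` that is >= seq_len.

    Hypothetical helper: mirrors the padded_len expression in
    _pad_inputs_for_attention_alignment, nothing more.
    """
    return ((seq_len + alignment - 1) // alignment) * alignment


for seq_len in (13, 16, 1):
    padded_len = round_up_to_multiple(seq_len)
    padding_length = padded_len - seq_len
    # 13 -> 16 (pad 3), 16 -> 16 (pad 0), 1 -> 8 (pad 7)
    print(seq_len, padded_len, padding_length)
```

A sequence whose length is already a multiple of alignment yields a padding length of 0, which is why the function only modifies its inputs when padding_length > 0.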

Source from the content-addressed store, hash-verified

def _pad_inputs_for_attention_alignment(model_inputs, pad_token_id, alignment: int = 8):
    """Pad sequence length to multiple of alignment for Flash Attention compatibility.

    Flash Attention within SDPA requires sequence lengths aligned to 8 bytes.
    This pads input_ids, attention_mask, and token_type_ids (if present) to prevent
    'p.attn_bias_ptr is not correctly aligned' errors.
    """
    seq_len = model_inputs.input_ids.shape[1]
    padded_len = ((seq_len + alignment - 1) // alignment) * alignment
    padding_length = padded_len - seq_len

    if padding_length > 0:
        model_inputs["input_ids"] = _cat_with_padding(
            model_inputs.input_ids, padding_length, pad_token_id
        )

        model_inputs["attention_mask"] = _cat_with_padding(
            model_inputs.attention_mask, padding_length, 0
        )

        if (
            "token_type_ids" in model_inputs
            and model_inputs["token_type_ids"] is not None
        ):
            model_inputs["token_type_ids"] = _cat_with_padding(
                model_inputs["token_type_ids"], padding_length, 0
            )

    return model_inputs
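
To show the effect of the function on a batch, here is a minimal sketch that uses plain nested lists instead of tensors so it runs without PyTorch. The names cat_with_padding_sketch and pad_inputs_sketch are hypothetical. The body of the real _cat_with_padding is not shown on this page, so the choice to append padding on the right is an assumption made only for illustration.

```python
def cat_with_padding_sketch(rows, padding_length, value):
    """Append `padding_length` copies of `value` to every row.

    Assumption: right-side padding. The real _cat_with_padding may differ.
    """
    return [row + [value] * padding_length for row in rows]


def pad_inputs_sketch(model_inputs, pad_token_id, alignment=8):
    """List-based stand-in for _pad_inputs_for_attention_alignment."""
    seq_len = len(model_inputs["input_ids"][0])
    padded_len = ((seq_len + alignment - 1) // alignment) * alignment
    padding_length = padded_len - seq_len

    if padding_length > 0:
        model_inputs["input_ids"] = cat_with_padding_sketch(
            model_inputs["input_ids"], padding_length, pad_token_id
        )
        # Padded positions get mask value 0 so attention ignores them.
        model_inputs["attention_mask"] = cat_with_padding_sketch(
            model_inputs["attention_mask"], padding_length, 0
        )
        if model_inputs.get("token_type_ids") is not None:
            model_inputs["token_type_ids"] = cat_with_padding_sketch(
                model_inputs["token_type_ids"], padding_length, 0
            )
    return model_inputs


batch = {
    "input_ids": [[5, 6, 7, 8, 9]],
    "attention_mask": [[1, 1, 1, 1, 1]],
}
padded = pad_inputs_sketch(batch, pad_token_id=0)
print(padded["input_ids"])       # [[5, 6, 7, 8, 9, 0, 0, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1, 1, 0, 0, 0]]
```

The attention mask is padded with 0 rather than 1 so that the extra positions change the tensor shape without contributing to attention.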

Callers 1

_enhance · Function · 0.85

Calls 1

_cat_with_padding · Function · 0.85

Tested by

no test coverage detected