Pad sequence length to a multiple of alignment for Flash Attention compatibility. Flash Attention within SDPA requires sequence lengths that are a multiple of 8 (the default alignment). This pads input_ids, attention_mask, and token_type_ids (if present) to prevent 'p.attn_bias_ptr is not correctly aligned' errors.
(model_inputs, pad_token_id, alignment: int = 8)
| 517 | |
| 518 | |
| 519 | def _pad_inputs_for_attention_alignment(model_inputs, pad_token_id, alignment: int = 8): |
| 520 | """Pad sequence length to multiple of alignment for Flash Attention compatibility. |
| 521 | |
| 522 | Flash Attention within SDPA requires sequence lengths aligned to 8 bytes. |
| 523 | This pads input_ids, attention_mask, and token_type_ids (if present) to prevent |
| 524 | 'p.attn_bias_ptr is not correctly aligned' errors. |
| 525 | """ |
| 526 | seq_len = model_inputs.input_ids.shape[1] |
| 527 | padded_len = ((seq_len + alignment - 1) // alignment) * alignment |
| 528 | padding_length = padded_len - seq_len |
| 529 | |
| 530 | if padding_length > 0: |
| 531 | model_inputs["input_ids"] = _cat_with_padding( |
| 532 | model_inputs.input_ids, padding_length, pad_token_id |
| 533 | ) |
| 534 | |
| 535 | model_inputs["attention_mask"] = _cat_with_padding( |
| 536 | model_inputs.attention_mask, padding_length, 0 |
| 537 | ) |
| 538 | |
| 539 | if ( |
| 540 | "token_type_ids" in model_inputs |
| 541 | and model_inputs["token_type_ids"] is not None |
| 542 | ): |
| 543 | model_inputs["token_type_ids"] = _cat_with_padding( |
| 544 | model_inputs["token_type_ids"], padding_length, 0 |
| 545 | ) |
| 546 | |
| 547 | return model_inputs |
| 548 | |
| 549 | |
| 550 | def _locate_model_within_model(super_model, model_name): |
no test coverage detected