Apply a delay pattern mask to the decoder input IDs: keep the predicted token wherever the mask is -1, and otherwise replace it with the value stored in the mask.
(input_ids, decoder_pad_token_mask)
| 203 | per_codebook_losses: Optional[List[torch.FloatTensor]] = None |
| 204 | |
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
    """Overwrite decoder input ids with the delay-pattern mask wherever the mask is set.

    The mask is first truncated along its last dimension to the length of
    ``input_ids``'s last dimension. Positions where the truncated mask equals
    -1 keep the corresponding value from ``input_ids``; every other position
    takes the mask's own value instead.

    Args:
        input_ids: Tensor of decoder input ids; its last dimension is treated
            as the sequence length.
        decoder_pad_token_mask: Tensor of mask values, where -1 marks positions
            whose prediction should be preserved. Only its first
            ``input_ids.shape[-1]`` entries along the last dimension are used.

    Returns:
        The tensor produced by ``torch.where`` combining ``input_ids`` and the
        truncated mask as described above.
    """
    mask = decoder_pad_token_mask[..., : input_ids.shape[-1]]
    # -1 is the sentinel meaning "leave the model's prediction untouched".
    keep_prediction = mask == -1
    return torch.where(keep_prediction, input_ids, mask)
| 212 | |
| 213 | |
| 214 | def build_delay_pattern_mask( |
no outgoing calls
no test coverage detected