

def mask_attention_for_ltr(attention_scores, attention_mask):
    """
    Mask attention so that we're only predicting going forward.

    :param attention_scores: [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    :param attention_mask: [query_length, key_length]
    :return: masked attention scores
    """
    # attention_mask is 1.0 for positions we want to attend to and 0.0 for
    # masked positions. Multiplying by the mask leaves the attended scores
    # unchanged and zeroes the masked ones; subtracting 1e10 * (1 - mask) then
    # pushes the masked scores to -1e10, so they receive effectively zero
    # weight in a subsequent softmax.
    mask = attention_mask[None, None]
    return attention_scores * mask - tf.cast(1e10, attention_scores.dtype) * (1 - mask)


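The effect of the masking arithmetic above can be checked on a small example. The following is a minimal sketch, not part of the original module: it uses NumPy as a stand-in for TensorFlow, and `causal_mask`, `mask_scores`, and `softmax` are hypothetical helper names introduced only for illustration. It assumes the mask is lower-triangular (each query position sees itself and earlier key positions), which is the usual left-to-right case.

```python
import numpy as np


def causal_mask(query_length, key_length):
    """Hypothetical helper: 1.0 where a query may attend to a key, else 0.0."""
    i = np.arange(query_length)[:, None]
    j = np.arange(key_length)[None, :]
    # Align the triangle to the bottom-right so the last query sees every key.
    return (i >= j - key_length + query_length).astype(np.float32)


def mask_scores(attention_scores, attention_mask):
    """NumPy mirror of the masking arithmetic (assumed equivalent to the TF op)."""
    mask = attention_mask[None, None]  # broadcast over batch and heads
    return attention_scores * mask - np.float32(1e10) * (1 - mask)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


# Uniform scores: [batch=1, heads=1, dst_sequence=3, src_sequence=3].
scores = np.zeros((1, 1, 3, 3), dtype=np.float32)
probs = softmax(mask_scores(scores, causal_mask(3, 3)))

# Each row spreads its weight over the visible (non-future) positions only.
print(probs[0, 0])
```

With uniform scores, the first query position puts all of its weight on the first key, the second splits it evenly over the first two keys, and the third over all three, so no position attends to the future.
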
def create_initializer(initializer_range=0.02):