This is a TPU-compatible version of `tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd)`, in which the lower-right triangle contains 1s.
(nd, ns, *, dtype)
| 143 | |
| 144 | |
def get_attention_mask(nd, ns, *, dtype):
    """Return an (nd, ns) mask whose lower-right triangle is filled with 1s.

    TPU-friendly replacement for
    ``tf.matrix_band_part(tf.ones([nd, ns]), -1, ns - nd)``: entry ``[r, c]``
    is 1 when ``c <= r + (ns - nd)`` and 0 otherwise, so the triangle of ones
    is anchored at the bottom-right corner of the matrix.

    Args:
        nd: number of rows of the mask (presumably the query/destination
            length — confirm against callers).
        ns: number of columns of the mask (presumably the key/source
            length — confirm against callers).
        dtype: dtype the boolean mask is cast to before being returned.

    Returns:
        A tensor of shape (nd, ns) with values 1/0 in ``dtype``.

    NOTE(review): the band_part equivalence presumably assumes ns >= nd;
    band_part treats a negative ``num_upper`` as "keep everything", whereas
    this expression does not — confirm callers never pass ns < nd.
    """
    # How far the diagonal of ones is shifted to the right.
    offset = ns - nd
    rows = tf.range(nd)[:, None]  # column vector, shape (nd, 1)
    cols = tf.range(ns)           # row vector, shape (ns,)
    # Broadcasting the comparison yields an (nd, ns) boolean matrix.
    keep = rows + offset >= cols
    return tf.cast(keep, dtype)
| 154 | |
| 155 | |
| 156 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): |