:param x_flat: Tensor input, should be [batch_size*seq_length, dim] :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length] :param size_per_head: dim = size_per_head * num_attention_heads :param num_attention_heads: dim = size_per_head * num_attention_heads
(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *,
cache=None,
initializer_range=0.02, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, do_cache=False)
| 147 | |
| 148 | |
| 149 | def attention_layer(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *, |
| 150 | cache=None, |
| 151 | initializer_range=0.02, hidden_dropout_prob=0.1, |
| 152 | attention_probs_dropout_prob=0.1, do_cache=False): |
| 153 | """ |
| 154 | :param x_flat: Tensor input, should be [batch_size*seq_length, dim] |
| 155 | :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length] |
| 156 | :param size_per_head: dim = size_per_head * num_attention_heads |
| 157 | :param num_attention_heads: dim = size_per_head * num_attention_heads |
| 158 | :param cache: Optionally some past (cached) things of size |
| 159 | [batch, 2, heads, sequence, features], where 2 is [k, v] |
| 160 | :param do_cache: True if we should return cache |
| 161 | :return: A new tensor of shape [batch_size, seq_length, dim] |
| 162 | as well as a new cache "cached_keys_and_values" that will be of size |
| 163 | [batch_size, 2, num_attention_heads, seq_length, size_per_head] |
| 164 | """ |
| 165 | batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2) |
| 166 | |
| 167 | if dim != size_per_head * num_attention_heads: |
| 168 | raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format( |
| 169 | (batch_size_seq_length, dim), size_per_head, num_attention_heads |
| 170 | )) |
| 171 | |
| 172 | query = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, |
| 173 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, |
| 174 | name='query_layer', |
| 175 | initializer_range=initializer_range) |
| 176 | key = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, |
| 177 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, |
| 178 | name='key_layer', |
| 179 | initializer_range=initializer_range) |
| 180 | |
| 181 | value = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, |
| 182 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, |
| 183 | name='value_layer', |
| 184 | initializer_range=initializer_range) |
| 185 | |
| 186 | # Add to cache |
| 187 | cached_keys_and_values = tf.stack([key, value], axis=1) if do_cache else None |
| 188 | |
| 189 | # Things that were relevant from the cache |
| 190 | if cache is not None: |
| 191 | pk, pv = tf.unstack(cache, axis=1) |
| 192 | key = tf.concat([pk, key], axis=-2) |
| 193 | value = tf.concat([pv, value], axis=-2) |
| 194 | |
| 195 | # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with |
| 196 | # [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] -> |
| 197 | # [batch_size, num_attention_heads, seq_length, seq_length+cached_length] |
| 198 | attention_scores = tf.matmul(query, key, transpose_b=True) |
| 199 | attention_scores = tf.multiply(attention_scores, |
| 200 | 1.0 / math.sqrt(float(size_per_head))) |
| 201 | attention_scores = mask_attention_for_ltr(attention_scores, attention_mask) |
| 202 | attention_probs = tf.nn.softmax(attention_scores) |
| 203 | |
| 204 | # This is actually dropping out entire tokens to attend to, which might |
| 205 | # seem a bit unusual, but is taken from the original Transformer paper. |
| 206 | # NOPENOPENOPENOPE |
no test coverage detected