Word and position embeddings :param input_ids: int Tensor of shape [batch_size, seq_length]. :param vocab_size: number of words in vocab :param embedding_size: dimensionality of the embedding :param position_offset: aka number of cached tokens. :param initializer_range: float. Range of the weight initialization.
(input_ids,
vocab_size,
embedding_size,
position_offset=0,
initializer_range=0.02,
max_position_embeddings=512,
use_one_hot_embeddings=True)
| 256 | |
| 257 | |
| 258 | def embed(input_ids, |
| 259 | vocab_size, |
| 260 | embedding_size, |
| 261 | position_offset=0, |
| 262 | initializer_range=0.02, |
| 263 | max_position_embeddings=512, |
| 264 | use_one_hot_embeddings=True): |
| 265 | """reur and position embeddings |
| 266 | :param input_ids: int Tensor of shape [batch_size, seq_length]. |
| 267 | :param vocab_size: number of words in vocab |
| 268 | :param embedding_size: dimensionality of the embedding |
| 269 | :param position_offset: aka number of cached tokens. |
| 270 | :param initializer_range: float. Range of the weight initialization. |
| 271 | :param max_position_embeddings: int. Maximum sequence length. |
| 272 | :param use_one_hot_embeddings: probably want this to be true |
| 273 | :return: [batch_size, seq_length, embedding_size] embedded tensor |
| 274 | """ |
| 275 | (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2) |
| 276 | |
| 277 | embedding_table = tf.compat.v1.get_variable( |
| 278 | name='word_embed', |
| 279 | shape=[vocab_size, embedding_size], |
| 280 | initializer=create_initializer(initializer_range), |
| 281 | ) |
| 282 | |
| 283 | assert_op = tf.compat.v1.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1) |
| 284 | with tf.control_dependencies([assert_op]): |
| 285 | if use_one_hot_embeddings: |
| 286 | flat_input_ids = tf.reshape(input_ids, [-1]) |
| 287 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) |
| 288 | output_flat = tf.matmul(one_hot_input_ids, embedding_table) |
| 289 | else: |
| 290 | output_flat = tf.nn.embedding_lookup(embedding_table, input_ids) |
| 291 | |
| 292 | embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size]) |
| 293 | |
| 294 | assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings) |
| 295 | |
| 296 | with tf.control_dependencies([assert_op]): |
| 297 | full_position_embeddings = tf.compat.v1.get_variable( |
| 298 | name='pos_embed', |
| 299 | shape=[max_position_embeddings, embedding_size], |
| 300 | initializer=create_initializer(initializer_range), |
| 301 | ) |
| 302 | # Since the position embedding table is a learned variable, we create it |
| 303 | # using a (long) sequence length `max_position_embeddings`. The actual |
| 304 | # sequence length might be shorter than this, for faster training of |
| 305 | # tasks that do not have long sequences. |
| 306 | # |
| 307 | # So `full_position_embeddings` is effectively an embedding table |
| 308 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current |
| 309 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just |
| 310 | # perform a slice. |
| 311 | if position_offset == 0: |
| 312 | embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None] |
| 313 | else: |
| 314 | # Tensorflow is too stupid to allow slicing |
| 315 | flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset) |
no test coverage detected