Function used to parse tfrecord.
(record)
| 428 | """Creates pretrain dataset.""" |
| 429 | |
| 430 | def parser(record): |
| 431 | """Function used to parse tfrecord.""" |
| 432 | |
| 433 | record_spec = { |
| 434 | "input": tf.io.FixedLenFeature([seq_len], tf.int64), |
| 435 | "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64), |
| 436 | "label": tf.io.FixedLenFeature([1], tf.int64), |
| 437 | } |
| 438 | |
| 439 | if online_masking_config.sample_strategy in ["whole_word", "word_span"]: |
| 440 | logging.info("Add `boundary` spec for %s", |
| 441 | online_masking_config.sample_strategy) |
| 442 | record_spec["boundary"] = tf.io.VarLenFeature(tf.int64) |
| 443 | |
| 444 | # retrieve serialized example |
| 445 | example = tf.io.parse_single_example( |
| 446 | serialized=record, features=record_spec) |
| 447 | |
| 448 | inputs = example.pop("input") |
| 449 | if online_masking_config.sample_strategy in ["whole_word", "word_span"]: |
| 450 | boundary = tf.sparse.to_dense(example.pop("boundary")) |
| 451 | else: |
| 452 | boundary = None |
| 453 | is_masked, _ = _online_sample_masks( |
| 454 | inputs, seq_len, num_predict, online_masking_config, boundary=boundary) |
| 455 | |
| 456 | if reuse_len > 0: |
| 457 | ##### Use memory |
| 458 | # permute the reuse and non-reuse parts separately |
| 459 | non_reuse_len = seq_len - reuse_len |
| 460 | assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0 |
| 461 | |
| 462 | # Creates permutation mask and target mask for the first reuse_len tokens. |
| 463 | # The tokens in this part are reused from the last sequence. |
| 464 | perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm( |
| 465 | inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len, |
| 466 | leak_ratio) |
| 467 | |
| 468 | # Creates permutation mask and target mask for the rest of tokens in |
| 469 | # current example, which are the concatenation of two new segments. |
| 470 | perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm( |
| 471 | inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len, |
| 472 | leak_ratio) |
| 473 | |
| 474 | perm_mask_0 = tf.concat( |
| 475 | [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1) |
| 476 | perm_mask_1 = tf.concat( |
| 477 | [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1) |
| 478 | perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0) |
| 479 | target_mask = tf.concat([target_mask_0, target_mask_1], axis=0) |
| 480 | input_k = tf.concat([input_k_0, input_k_1], axis=0) |
| 481 | input_q = tf.concat([input_q_0, input_q_1], axis=0) |
| 482 | else: |
| 483 | ##### Do not use memory |
| 484 | assert seq_len % perm_size == 0 |
| 485 | # permute the entire sequence together |
| 486 | perm_mask, target_mask, input_k, input_q = _local_perm( |
| 487 | inputs, is_masked, perm_size, seq_len, leak_ratio) |