Function used to parse a serialized TFRecord example.
(record)
| 743 | |
| 744 | #### Function used to parse tfrecord |
| 745 | def parser(record): |
| 746 | """function used to parse tfrecord.""" |
| 747 | |
| 748 | record_spec = { |
| 749 | "input": tf.FixedLenFeature([seq_len], tf.int64), |
| 750 | "target": tf.FixedLenFeature([seq_len], tf.int64), |
| 751 | "seg_id": tf.FixedLenFeature([seq_len], tf.int64), |
| 752 | "label": tf.FixedLenFeature([1], tf.int64), |
| 753 | "is_masked": tf.FixedLenFeature([seq_len], tf.int64), |
| 754 | } |
| 755 | |
| 756 | # retrieve serialized example |
| 757 | example = tf.parse_single_example( |
| 758 | serialized=record, |
| 759 | features=record_spec) |
| 760 | |
| 761 | inputs = example.pop("input") |
| 762 | target = example.pop("target") |
| 763 | is_masked = tf.cast(example.pop("is_masked"), tf.bool) |
| 764 | |
| 765 | non_reuse_len = seq_len - reuse_len |
| 766 | assert perm_size <= reuse_len and perm_size <= non_reuse_len |
| 767 | |
| 768 | perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm( |
| 769 | inputs[:reuse_len], |
| 770 | target[:reuse_len], |
| 771 | is_masked[:reuse_len], |
| 772 | perm_size, |
| 773 | reuse_len) |
| 774 | |
| 775 | perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm( |
| 776 | inputs[reuse_len:], |
| 777 | target[reuse_len:], |
| 778 | is_masked[reuse_len:], |
| 779 | perm_size, |
| 780 | non_reuse_len) |
| 781 | |
| 782 | perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])], |
| 783 | axis=1) |
| 784 | perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], |
| 785 | axis=1) |
| 786 | perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0) |
| 787 | target = tf.concat([target_0, target_1], axis=0) |
| 788 | target_mask = tf.concat([target_mask_0, target_mask_1], axis=0) |
| 789 | input_k = tf.concat([input_k_0, input_k_1], axis=0) |
| 790 | input_q = tf.concat([input_q_0, input_q_1], axis=0) |
| 791 | |
| 792 | if num_predict is not None: |
| 793 | indices = tf.range(seq_len, dtype=tf.int64) |
| 794 | bool_target_mask = tf.cast(target_mask, tf.bool) |
| 795 | indices = tf.boolean_mask(indices, bool_target_mask) |
| 796 | |
| 797 | ##### extra padding due to CLS/SEP introduced after prepro |
| 798 | actual_num_predict = tf.shape(indices)[0] |
| 799 | pad_len = num_predict - actual_num_predict |
| 800 | |
| 801 | ##### target_mapping |
| 802 | target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32) |
nothing calls this directly
no test coverage detected