MCPcopy
hub / github.com/tensorflow/models / parser

Function parser

official/legacy/xlnet/preprocess_pretrain_data.py:745–834  ·  view source on GitHub ↗

Function used to parse a single serialized TFRecord example.

(record)

Source from the content-addressed store, hash-verified

743
744 #### Function used to parse tfrecord
745 def parser(record):
746 """function used to parse tfrecord."""
747
748 record_spec = {
749 "input": tf.FixedLenFeature([seq_len], tf.int64),
750 "target": tf.FixedLenFeature([seq_len], tf.int64),
751 "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
752 "label": tf.FixedLenFeature([1], tf.int64),
753 "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
754 }
755
756 # retrieve serialized example
757 example = tf.parse_single_example(
758 serialized=record,
759 features=record_spec)
760
761 inputs = example.pop("input")
762 target = example.pop("target")
763 is_masked = tf.cast(example.pop("is_masked"), tf.bool)
764
765 non_reuse_len = seq_len - reuse_len
766 assert perm_size <= reuse_len and perm_size <= non_reuse_len
767
768 perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
769 inputs[:reuse_len],
770 target[:reuse_len],
771 is_masked[:reuse_len],
772 perm_size,
773 reuse_len)
774
775 perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
776 inputs[reuse_len:],
777 target[reuse_len:],
778 is_masked[reuse_len:],
779 perm_size,
780 non_reuse_len)
781
782 perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
783 axis=1)
784 perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
785 axis=1)
786 perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
787 target = tf.concat([target_0, target_1], axis=0)
788 target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
789 input_k = tf.concat([input_k_0, input_k_1], axis=0)
790 input_q = tf.concat([input_q_0, input_q_1], axis=0)
791
792 if num_predict is not None:
793 indices = tf.range(seq_len, dtype=tf.int64)
794 bool_target_mask = tf.cast(target_mask, tf.bool)
795 indices = tf.boolean_mask(indices, bool_target_mask)
796
797 ##### extra padding due to CLS/SEP introduced after prepro
798 actual_num_predict = tf.shape(indices)[0]
799 pad_len = num_predict - actual_num_predict
800
801 ##### target_mapping
802 target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)

Callers

nothing calls this directly

Calls 5

_convert_example — Function · 0.85
pop — Method · 0.80
concat — Method · 0.80
info — Method · 0.80
_local_perm — Function · 0.70

Tested by

no test coverage detected