MCPcopy
hub / github.com/tensorflow/models / parser

Function parser

official/legacy/xlnet/data_utils.py:430–542  ·  view source on GitHub ↗

Function used to parse tfrecord.

(record)

Source from the content-addressed store, hash-verified

428 """Creates pretrain dataset."""
429
430 def parser(record):
431 """Function used to parse tfrecord."""
432
433 record_spec = {
434 "input": tf.io.FixedLenFeature([seq_len], tf.int64),
435 "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
436 "label": tf.io.FixedLenFeature([1], tf.int64),
437 }
438
439 if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
440 logging.info("Add `boundary` spec for %s",
441 online_masking_config.sample_strategy)
442 record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
443
444 # retrieve serialized example
445 example = tf.io.parse_single_example(
446 serialized=record, features=record_spec)
447
448 inputs = example.pop("input")
449 if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
450 boundary = tf.sparse.to_dense(example.pop("boundary"))
451 else:
452 boundary = None
453 is_masked, _ = _online_sample_masks(
454 inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
455
456 if reuse_len > 0:
457 ##### Use memory
458 # permutate the reuse and non-reuse parts separately
459 non_reuse_len = seq_len - reuse_len
460 assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
461
462 # Creates permutation mask and target mask for the first reuse_len tokens.
463 # The tokens in this part are reused from the last sequence.
464 perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
465 inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
466 leak_ratio)
467
468 # Creates permutation mask and target mask for the rest of the tokens in
469 # the current example, which are a concatenation of two new segments.
470 perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
471 inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
472 leak_ratio)
473
474 perm_mask_0 = tf.concat(
475 [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
476 perm_mask_1 = tf.concat(
477 [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
478 perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
479 target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
480 input_k = tf.concat([input_k_0, input_k_1], axis=0)
481 input_q = tf.concat([input_q_0, input_q_1], axis=0)
482 else:
483 ##### Do not use memory
484 assert seq_len % perm_size == 0
485 # permutate the entire sequence together
486 perm_mask, target_mask, input_k, input_q = _local_perm(
487 inputs, is_masked, perm_size, seq_len, leak_ratio)

Calls 5

_online_sample_masks · Function · 0.85
info · Method · 0.80
pop · Method · 0.80
concat · Method · 0.80
_local_perm · Function · 0.70