MCPcopy
hub / github.com/Turing-Project/WriteGPT / model_fn

Function model_fn

LanguageNetwork/GPT2/scripts/modeling.py:566–685  ·  view source on GitHub ↗

The `model_fn` for TPUEstimator.

(features, labels, mode, params)

Source from the content-addressed store, hash-verified

564 """Returns `model_fn` closure for TPUEstimator."""
565
566 def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
567 """The `model_fn` for TPUEstimator."""
568
569 tf.logging.info("*** Features ***")
570 for name in sorted(features.keys()):
571 tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
572
573 input_ids = features["input_ids"]
574
575 is_training = (mode == tf.estimator.ModeKeys.TRAIN)
576
577 model = GroverModel(
578 config=config,
579 is_training=is_training,
580 input_ids=input_ids,
581 pad_token_id=config.pad_token_id,
582 chop_off_last_token=True,
583 )
584
585 total_loss = model.lm_loss()
586 print(model.logits_flat)
587 print(total_loss)
588
589 if is_training:
590 train_op, train_metrics = create_optimizer(
591 total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
592 tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
593 else:
594 train_op = None
595 train_metrics = {}
596 tvars = tf.trainable_variables()
597 params_sum = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
598 tf.logging.info("**** Trainable params_sum ****")
599 tf.logging.info(params_sum)
600 initialized_variable_names = {}
601 scaffold_fn = None
602 if init_checkpoint:
603 (assignment_map, initialized_variable_names
604 ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
605 if use_tpu:
606 def tpu_scaffold():
607 tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
608 return tf.train.Scaffold()
609
610 scaffold_fn = tpu_scaffold
611 else:
612 tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
613
614 tf.logging.info("**** Trainable Variables ****")
615 for var in tvars:
616 init_string = ""
617 if var.name in initialized_variable_names:
618 init_string = ", *INIT_FROM_CKPT*"
619 tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
620 init_string)
621
622 output_spec = None
623 if mode == tf.estimator.ModeKeys.TRAIN:

Callers

nothing calls this directly

Calls 7

lm_loss — Method · 0.95
get_shape_list — Function · 0.90
GroverModel — Class · 0.70
create_optimizer — Function · 0.70
_top_p_sample — Function · 0.70

Tested by

no test coverage detected