The `model_fn` for TPUEstimator.
Signature: `model_fn(features, labels, mode, params)`
| 564 | """Returns `model_fn` closure for TPUEstimator.""" |
| 565 | |
| 566 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument |
| 567 | """The `model_fn` for TPUEstimator.""" |
| 568 | |
| 569 | tf.logging.info("*** Features ***") |
| 570 | for name in sorted(features.keys()): |
| 571 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) |
| 572 | |
| 573 | input_ids = features["input_ids"] |
| 574 | |
| 575 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) |
| 576 | |
| 577 | model = GroverModel( |
| 578 | config=config, |
| 579 | is_training=is_training, |
| 580 | input_ids=input_ids, |
| 581 | pad_token_id=config.pad_token_id, |
| 582 | chop_off_last_token=True, |
| 583 | ) |
| 584 | |
| 585 | total_loss = model.lm_loss() |
| 586 | print(model.logits_flat) |
| 587 | print(total_loss) |
| 588 | |
| 589 | if is_training: |
| 590 | train_op, train_metrics = create_optimizer( |
| 591 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) |
| 592 | tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) |
| 593 | else: |
| 594 | train_op = None |
| 595 | train_metrics = {} |
| 596 | tvars = tf.trainable_variables() |
| 597 | params_sum = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) |
| 598 | tf.logging.info("**** Trainable params_sum ****") |
| 599 | tf.logging.info(params_sum) |
| 600 | initialized_variable_names = {} |
| 601 | scaffold_fn = None |
| 602 | if init_checkpoint: |
| 603 | (assignment_map, initialized_variable_names |
| 604 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint) |
| 605 | if use_tpu: |
| 606 | def tpu_scaffold(): |
| 607 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) |
| 608 | return tf.train.Scaffold() |
| 609 | |
| 610 | scaffold_fn = tpu_scaffold |
| 611 | else: |
| 612 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) |
| 613 | |
| 614 | tf.logging.info("**** Trainable Variables ****") |
| 615 | for var in tvars: |
| 616 | init_string = "" |
| 617 | if var.name in initialized_variable_names: |
| 618 | init_string = ", *INIT_FROM_CKPT*" |
| 619 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, |
| 620 | init_string) |
| 621 | |
| 622 | output_spec = None |
| 623 | if mode == tf.estimator.ModeKeys.TRAIN: |
Nothing calls `model_fn` directly.
No test coverage was detected for `model_fn`.