MCPcopy
hub / github.com/Turing-Project/WriteGPT / lm_loss

Method lm_loss

LanguageNetwork/GPT2/scripts/modeling.py:526–550  ·  view source on GitHub ↗

:return: the mean cross-entropy loss over non-padding target tokens (a scalar tensor)

(self)

Source from the content-addressed store, hash-verified

524 return tf.reshape(logprobs_flat, [self.batch_size, self.seq_length, -1])
525
def lm_loss(self):
    """Compute the mean token-level language-modeling cross-entropy loss.

    Positions whose target id equals ``self.pad_token_id`` are masked out, so
    the result is the average negative log-likelihood over non-padding
    targets only.

    Reads ``self.target_ids``, ``self.pad_token_id``, ``self.logits_flat``
    and ``self.config.vocab_size``.

    :return: scalar tensor with the dtype of ``self.logits_flat``. It is the
        sum of the masked per-position losses divided by the number of
        non-padding targets. A small epsilon is added to the denominator so
        that an all-padding batch yields 0 instead of NaN.
    """
    # Flatten targets so each entry lines up with a row of logits_flat.
    # Assumes logits_flat rows follow the same row-major order as
    # target_ids -- TODO confirm.
    target_ids_flat = tf.reshape(self.target_ids, [-1])

    # 1.0 where the target is a real token, 0.0 where it is padding.
    label_weights = tf.cast(
        tf.not_equal(target_ids_flat, self.pad_token_id),
        dtype=self.logits_flat.dtype,
    )

    # [batch_size * seq_length, vocab_size]
    one_hot_labels = tf.one_hot(
        target_ids_flat,
        depth=self.config.vocab_size,
        dtype=self.logits_flat.dtype,
    )

    # [batch_size * seq_length, vocab_size]
    logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)

    # Negative log-probability of the target token at each position.
    # The dense one-hot product is used here instead of
    # tf.nn.sparse_softmax_cross_entropy_with_logits for a specific reason.
    # tf.one_hot maps an out-of-range id (e.g. a pad id outside
    # [0, vocab_size)) to an all-zero row, which gives a finite 0 loss that
    # the mask below discards. The sparse op instead raises (CPU) or yields
    # NaN (GPU) for such labels. The cost is a
    # [batch_size * seq_length, vocab_size] temporary tensor.
    per_example_loss = -tf.reduce_sum(logprobs_flat * one_hot_labels, axis=-1)

    # Masked mean. The epsilon prevents division by zero when every target
    # is padding; in that case the numerator is 0, so the loss is 0.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    return numerator / denominator
551
552 def pooled_output(self, clf_token):
553 """

Callers 1

model_fn · Function · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected