:return: stuff
(self)
| 524 | return tf.reshape(logprobs_flat, [self.batch_size, self.seq_length, -1]) |
| 525 | |
def lm_loss(self):
    """Compute the padding-masked mean token cross-entropy of the LM head.

    Positions whose target id equals ``self.pad_token_id`` get weight 0 and
    are excluded from both the summed loss and the token count.

    :return: scalar tensor equal to the sum of per-position negative
        log-likelihoods over non-padding positions, divided by
        (number of non-padding positions + 1e-5). For finite logits this
        evaluates to 0 when every position is padding.
    """
    target_ids_flat = tf.reshape(self.target_ids, [-1])

    # 1.0 for real tokens, 0.0 for padding; cast to the logits dtype so the
    # weighted sum below does not mix dtypes.
    label_weights = tf.cast(tf.not_equal(target_ids_flat, self.pad_token_id),
                            dtype=self.logits_flat.dtype)

    # [batch_size * seq_length, vocab_size]
    # NOTE(review): one_hot + log_softmax is used rather than
    # tf.nn.sparse_softmax_cross_entropy_with_logits. tf.one_hot produces an
    # all-zero row for out-of-range ids (e.g. a negative pad id), while the
    # sparse op raises on CPU / returns NaN on GPU for such labels.
    # Presumably that is the reason for this form — confirm before switching.
    one_hot_labels = tf.one_hot(target_ids_flat,
                                depth=self.config.vocab_size,
                                dtype=self.logits_flat.dtype)

    # [batch_size * seq_length, vocab_size]
    logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)

    # Negative log-probability of the target token at each flat position.
    per_example_loss = -tf.reduce_sum(logprobs_flat * one_hot_labels, axis=[-1])

    numerator = tf.reduce_sum(label_weights * per_example_loss)
    # Epsilon guards against division by zero when every position is padding.
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator
    return loss
| 551 | |
| 552 | def pooled_output(self, clf_token): |
| 553 | """ |