Does top-k sampling. If ignore_ids is given, the logits of those ids are pushed to a very large negative value so they get (effectively) zero probability and are never sampled. :param logits: [batch_size, vocab_size] tensor :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict, like padding.
(logits, ignore_ids=None, num_samples=1, k=10)
| 378 | |
| 379 | |
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling: for each row, draws samples only from the k highest-probability entries.

    If ignore_ids is given, the logits of those ids are pushed down by 1e10 before both the
    softmax/sort AND the sampling step, so they are (effectively) never predicted.

    :param logits: [batch_size, vocab_size] tensor of unnormalized log-probabilities
    :param ignore_ids: optional [vocab_size] one-hot representation of the indices we'd like to
        ignore and never predict, like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param k: how many top entries to keep; either a Python int or a [batch_size] vector
        (NOTE(review): a scalar k tensor is not supported by the `k[:, None]` branch)
    :return: dict with
        'probs': [batch_size, vocab_size] softmax probabilities (after ignore_ids masking)
        'sample': [batch_size, num_samples] sampled indices into the vocabulary
    """
    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    with tf.compat.v1.variable_scope('top_k_sample'):
        # Only vocab_size is needed; get_shape_list also enforces rank 2.
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        if ignore_ids is not None:
            # Mask the logits themselves, not only the probs. The sampling step below reads
            # these logits, so masking probs alone would still let ignored ids be drawn.
            logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10

        probs = tf.nn.softmax(logits, axis=-1)
        # [batch_size, vocab_perm]: vocabulary ids ordered from most to least probable
        indices = tf.argsort(probs, direction='DESCENDING')

        # Mark every sorted position at or beyond k for exclusion.
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        # Map positions in sorted order back to vocabulary ids.
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
| 412 | |
| 413 | |
| 414 | class GroverModel(object): |
no test coverage detected