hub / github.com/Turing-Project/WriteGPT / _top_k_sample

Function _top_k_sample

LanguageNetwork/GPT2/scripts/modeling.py:380–411

Does top-k sampling. If ignore_ids is given, 1e10 is subtracted from the flagged logits before the softmax, so those tokens get effectively zero probability and rank last. Parameters: logits is a [batch_size, vocab_size] tensor; ignore_ids is a [vocab_size] one-hot representation of the indices to ignore and never predict, such as padding.

(logits, ignore_ids=None, num_samples=1, k=10)
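
The effect of ignore_ids can be illustrated outside TensorFlow. The following is a minimal pure-Python sketch, not the repository's code: masked_softmax is a hypothetical helper that subtracts a large constant (1e10, the value used in the source) from every logit flagged in a 0/1 ignore_ids list before applying softmax, so the flagged tokens end up with effectively zero probability.

```python
import math


def masked_softmax(logits, ignore_ids=None, penalty=1e10):
    """Softmax over one row of logits, suppressing flagged positions.

    `ignore_ids` is a 0/1 list the same length as `logits`; it is a
    hypothetical stand-in for the [vocab_size] one-hot tensor in the source.
    """
    if ignore_ids is not None:
        # Push flagged logits far below every other entry.
        logits = [x - flag * penalty for x, flag in zip(logits, ignore_ids)]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# The last token has the highest raw logit but is flagged as ignored.
probs = masked_softmax([2.0, 1.0, 0.5, 3.0], ignore_ids=[0, 0, 0, 1])
```

In this sketch the flagged token's exponent underflows to 0.0, so the remaining probability mass is shared among the unflagged tokens only.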

Source

def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is given, 1e10 is subtracted from those logits before the
    softmax, so the corresponding tokens rank last and fall outside the top k.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param k: number of top tokens to keep, either an int or a [batch_size] vector
    :return: dict with 'probs' ([batch_size, vocab_size]) and 'sample' ([batch_size, num_samples])
    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    # NOTE: the scope is named 'top_p_sample' in the source even though this is top-k sampling.
    with tf.compat.v1.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)
        # [batch_size, vocab_perm]: token ids sorted by descending probability
        indices = tf.argsort(probs, direction='DESCENDING')

        # mask every sorted position at or beyond k. careful we don't want to cut off everything!
        # result broadcasts to [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
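
The "sample in the sorted space, then unsort" step above can be sketched for a single row without TensorFlow. This is an illustrative pure-Python version under stated assumptions, not the repository's implementation: top_k_sample and its rng parameter are hypothetical names, and it truncates the sorted list to k entries instead of subtracting 1e10 from the excluded positions, which has the same effect on the sampling distribution.

```python
import math
import random


def top_k_sample(logits, k=10, rng=random):
    """Draw one token id from the k highest-scoring entries of `logits`."""
    # Token ids sorted by descending logit (the "sorted space").
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = order[:max(1, min(k, len(order)))]  # always keep at least one token

    # Unnormalized softmax weights restricted to the kept logits.
    peak = logits[kept[0]]
    weights = [math.exp(logits[i] - peak) for i in kept]

    # Sample a position in sorted space, then map it back to a token id.
    position = rng.choices(range(len(kept)), weights=weights, k=1)[0]
    return kept[position]


rng = random.Random(0)  # seeded only to make the demo repeatable
draws = [top_k_sample([0.1, 2.0, -1.0, 1.5, 0.3], k=2, rng=rng) for _ in range(200)]
```

With k=2 only the two highest-scoring token ids (1 and 3 here) can ever be drawn; the mapping through kept plays the role of the final batch_gather on indices in the source.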

Callers 1

sample_step · Function · 0.70

Calls 1

get_shape_list · Function · 0.90

Tested by

no test coverage detected