github.com/Turing-Project/WriteGPT

Function _top_p_sample

LanguageNetwork/GPT2/scripts/modeling.py:327–377

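Does top-p (nucleus) sampling. If ignore_ids is given, those logits are shifted by -1e10 so the corresponding tokens are never predicted. Parameters: logits, a [batch_size, vocab_size] tensor; ignore_ids, a [vocab_size] one-hot representation of the indices to ignore and never predict (padding, for example); num_samples, the number of samples to draw per row; p, the top-p threshold, either a float or a [batch_size] vector.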
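The TensorFlow code does not literally zero the ignored logits. It subtracts 1e10 from them before the softmax, which makes their probabilities underflow to zero. Below is a minimal stdlib-only sketch of that masking for one row. It assumes plain Python lists stand in for the [vocab_size] tensors, and the helper names softmax and mask_ignored are hypothetical, not from the repository:

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mask_ignored(logits, ignore_ids):
    # ignore_ids is a one-hot style list (1 = never predict this token).
    # Mirrors `logits - tf.cast(ignore_ids[None], tf.float32) * 1e10`.
    return [x - flag * 1e10 for x, flag in zip(logits, ignore_ids)]


logits = [2.0, 1.0, 0.5, 3.0]
ignore_ids = [0, 0, 0, 1]  # e.g. token 3 is padding
probs = softmax(mask_ignored(logits, ignore_ids))
print(probs)
```

Because math.exp of a value near -1e10 underflows to 0.0, the ignored token gets probability exactly 0.0 here. The remaining probabilities still sum to 1, even though token 3 had the largest raw logit.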

(logits, ignore_ids=None, num_samples=1, p=0.9)
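Inside the function, p is compared against the inclusive cumulative sum of the descending-sorted probabilities. The first token is always kept, so the nucleus is never empty. The sketch below reproduces that exclusion rule on plain lists and assumes each row is already sorted in descending order. The names top_p_exclude_mask and batch_exclude_mask are hypothetical, and the probabilities are powers of two so the cumulative sums are exact:

```python
from itertools import accumulate


def top_p_exclude_mask(sorted_probs, p):
    """Return True for tokens outside the nucleus (one row, sorted descending)."""
    # Inclusive running total, like tf.math.cumsum(..., exclusive=False).
    cumulative = accumulate(sorted_probs)
    # Keep a token while its cumulative probability is strictly below p,
    # and always keep position 0 (the most likely token).
    return [not (c < p or i < 1) for i, c in enumerate(cumulative)]


def batch_exclude_mask(sorted_rows, p):
    """Apply the rule per row; p is a float or one threshold per row."""
    thresholds = [p] * len(sorted_rows) if isinstance(p, float) else list(p)
    return [top_p_exclude_mask(row, t) for row, t in zip(sorted_rows, thresholds)]


rows = [
    [0.5, 0.25, 0.125, 0.125],
    [0.75, 0.125, 0.0625, 0.0625],
]
# One shared threshold:
print(batch_exclude_mask(rows, 0.9))
# -> [[False, False, False, True], [False, False, True, True]]
# A [batch_size] vector of thresholds:
print(batch_exclude_mask(rows, [0.8, 0.5]))
# -> [[False, False, True, True], [False, True, True, True]]
```

With p=0.9 the first row keeps only 0.875 of the mass. The token whose cumulative sum would reach or cross p is excluded, because the comparison is a strict < on the inclusive cumulative sum. The second call shows the forced first token: with p=0.5 the top token (0.75) is kept even though it alone exceeds p.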


```python
import tensorflow as tf

# get_shape_list is defined elsewhere in modeling.py.


def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. If ignore_ids is given, those logits are shifted by -1e10 so they are never sampled.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                       like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param p: top-p threshold to use, either a float or a [batch_size] vector
    :return: dict with 'probs' ([batch_size, vocab_size]) and 'sample' ([batch_size, num_samples])
    """
    # TODO: figure out how to do this on TPUs. It's very slow right now, probably due to argsort.
    with tf.compat.v1.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        # Push ignored tokens to a huge negative logit so their probability is ~0.
        masked_logits = logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(masked_logits, axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(logits=masked_logits, num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.gather(probs, indices, batch_dims=1), axis=-1, exclusive=False)

        # Find the top-p index to cut off. Careful: we don't want to cut off everything,
        # so the first (most likely) token is always kept.
        # Result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        # Use masked_logits (not the raw logits) so ignored tokens stay excluded even when p is close to 1.
        logits_to_use = tf.gather(masked_logits, indices, batch_dims=1) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.gather(indices, sample_perm, batch_dims=1)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # exclude_mask_unperm = tf.gather(exclude_mask, unperm_indices, batch_dims=1)
        # logits_to_use = masked_logits - tf.cast(exclude_mask_unperm, tf.float32) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

        return {
            'probs': probs,
            'sample': sample,
        }
```
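The code keeps OPTION B commented out, but both options should define the same distribution over token ids. A renormalizes in sorted space and maps each sampled position back through indices. B first inverts the permutation and then masks the logits in their original order. Below is a stdlib-only sketch comparing them on a single row. It assumes ignore_ids is None, uses hypothetical helper names (sorted_masked, option_a, option_b, sample_option_a), and uses random.Random.choices as a stand-in for tf.random.categorical:

```python
import math
import random
from itertools import accumulate


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sorted_masked(logits, p):
    """Sort by probability (descending) and push excluded logits down by 1e10."""
    probs = softmax(logits)
    indices = sorted(range(len(probs)), key=lambda i: -probs[i])  # argsort DESCENDING
    cumulative = accumulate(probs[i] for i in indices)
    keep = [c < p or j < 1 for j, c in enumerate(cumulative)]  # first token always kept
    sorted_logits = [logits[i] - (0.0 if k else 1e10) for i, k in zip(indices, keep)]
    return indices, keep, sorted_logits


def option_a(logits, p):
    """OPTION A distribution: softmax in sorted space, then scatter back to token ids."""
    indices, _, sorted_logits = sorted_masked(logits, p)
    sorted_dist = softmax(sorted_logits)
    dist = [0.0] * len(logits)
    for j, token_id in enumerate(indices):
        dist[token_id] = sorted_dist[j]
    return dist


def option_b(logits, p):
    """OPTION B distribution: unsort the keep-mask, then mask the original logits."""
    indices, keep, _ = sorted_masked(logits, p)
    unperm = sorted(range(len(indices)), key=lambda j: indices[j])  # argsort ASCENDING
    keep_unperm = [keep[unperm[i]] for i in range(len(logits))]
    return softmax([x - (0.0 if k else 1e10) for x, k in zip(logits, keep_unperm)])


def sample_option_a(logits, p, rng, num_samples=1):
    """Sample positions in sorted space, then map them back through indices."""
    indices, _, sorted_logits = sorted_masked(logits, p)
    sample_perm = rng.choices(range(len(indices)), weights=softmax(sorted_logits), k=num_samples)
    return [indices[j] for j in sample_perm]


logits = [math.log(w) for w in (0.5, 0.125, 0.25, 0.125)]
print(option_a(logits, 0.8))
print(option_b(logits, 0.8))
print(sample_option_a(logits, 0.8, random.Random(0), num_samples=10))
```

With p=0.8 both distributions put about 2/3 on token 0, about 1/3 on token 2, and exactly zero on tokens 1 and 3, so every sample is 0 or 2. Option A also avoids the second argsort that B needs to invert the permutation.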

Callers (2): model_fn (function, 0.70), sample_step (function, 0.70)

Calls (1): get_shape_list (function, 0.90)

Tested by: no test coverage detected