Does top-p sampling. If ignore_ids is given, those logits are masked out so their probability becomes effectively zero. :param logits: [batch_size, vocab_size] tensor :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict, like padding maybe :param p: top-p threshold to use, either a float or a [batch_size] vector
(logits, ignore_ids=None, num_samples=1, p=0.9)
| 325 | |
| 326 | |
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p (nucleus) sampling. If ignore_ids is given, those logits are pushed down by 1e10 so
    their probability is effectively zero and they are (practically) never sampled.

    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot (0/1) representation of the indices we'd like to ignore
                       and never predict, like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param p: top-p threshold to use, either a Python number or a [batch_size] vector.
              A Python number above 0.999999 disables top-p filtering entirely.
    :return: dict with
             'probs':  [batch_size, vocab_size] softmax of the (ignore-masked) logits
             'sample': [batch_size, num_samples] int32 sampled vocabulary indices
    """
    # TODO: figure out how to do this efficiently on TPUs; it is very slow right now,
    # probably because of the argsort.
    with tf.compat.v1.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        # Apply the ignore mask once and reuse it for the probabilities AND every sampling path.
        # (Previously the top-p path sampled from the unmasked logits, so ignored ids could be
        # drawn whenever they survived the cumulative-probability cutoff, e.g. for p == 1.0 rows.)
        if ignore_ids is not None:
            logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10

        probs = tf.nn.softmax(logits, axis=-1)

        p_is_scalar = isinstance(p, (int, float))
        if p_is_scalar and p > 0.999999:
            # Don't do top-p sampling in this case; sample from the full (masked) distribution.
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(logits=logits, num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_size]: vocabulary ids ordered from most to least probable.
        indices = tf.argsort(probs, direction='DESCENDING')
        sorted_probs = tf.gather(probs, indices, batch_dims=1)
        cumulative_probabilities = tf.math.cumsum(sorted_probs, axis=-1, exclusive=False)

        # Keep sorted positions whose inclusive cumulative probability is still below p, and always
        # keep position 0 so we never cut off everything.
        # NOTE(review): because the cumsum is inclusive and the comparison strict, the token that
        # pushes the mass past p is dropped, so the kept set has mass < p (plus the top token).
        # This differs from the "smallest set with mass >= p" definition; kept as-is to preserve
        # the existing sampling distribution.
        p_expanded = p if p_is_scalar else p[:, None]
        keep_mask = tf.logical_or(cumulative_probabilities < p_expanded,
                                  tf.range(vocab_size)[None] < 1)
        exclude_mask = tf.logical_not(keep_mask)

        # Sample in the sorted space (excluded positions pushed down by 1e10), then map the
        # sampled positions back to vocabulary ids.
        sorted_logits = tf.gather(logits, indices, batch_dims=1)
        logits_to_use = sorted_logits - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)
        sample = tf.gather(indices, sample_perm, batch_dims=1)

        return {
            'probs': probs,
            'sample': sample,
        }
| 378 | |
| 379 | |
| 380 | def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10): |
no test coverage detected