(context_tokens, device, args)
| 84 | |
| 85 | |
def get_batch(context_tokens, device, args):
    """Reshape context tokens into a batch and build its attention inputs.

    Args:
        context_tokens: token tensor (presumably a ``torch.Tensor``). It is
            reshaped with ``view(args.batch_size, -1)``, so its element
            count must be divisible by ``args.batch_size``.
        device: target passed to ``Tensor.to``.
        args: namespace providing ``batch_size``, ``eod_token``,
            ``reset_position_ids`` and ``reset_attention_mask``.

    Returns:
        Tuple ``(tokens, attention_mask, position_ids)``. ``tokens`` has
        ``args.batch_size`` rows and has been moved to ``device``.
    """
    tokens = context_tokens.view(args.batch_size, -1).contiguous().to(device)

    # Get the masks and position ids. The loss mask is not used by this
    # function, so it is discarded rather than bound to an unused name.
    attention_mask, _, position_ids = get_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask)

    return tokens, attention_mask, position_ids
| 99 | |
| 100 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): |
| 101 | # This function has been mostly taken from huggingface conversational ai code at |
no test coverage detected