| 53 | |
| 54 | # generation |
def generate(prompt):
    """Sample one completion for ``prompt`` and return it as decoded text.

    The prompt is tokenized into PyTorch tensors, moved to the module-level
    ``device``, and passed to ``model.generate`` with sampling enabled.
    Sampling is controlled by the module-level ``MAX_NEW_TOKENS``,
    ``TEMPERATURE``, ``TOP_P`` and ``REPETITION_PENALTY`` settings.

    Args:
        prompt: Text handed to the tokenizer.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Remove token_type_ids when the tokenizer emits them; the default of
    # None makes this a no-op otherwise.
    # NOTE(review): presumably the model's generate() rejects this key - confirm.
    encoded.pop("token_type_ids", None)
    encoded = encoded.to(device)

    # The EOS token is used both as the stop token and as the padding token.
    sampling_options = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "repetition_penalty": REPETITION_PENALTY,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # Gradient tracking is disabled for the generation call.
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling_options)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
| 77 | |
| 78 | print("\n GENERATIONS\n" + "=" * 50) |
| 79 | for p in PROMPTS: |