| 49 | |
| 50 | |
class EssayInstructor(DataAugmenter):
    """Turn an essay into per-paragraph "write a paragraph" instructions.

    The essay is split into paragraphs on blank lines (``"\\n\\n"``). Each
    paragraph is summarised into a short line of text by a seq2seq model.
    That summary is appended to an instruction template chosen by the
    paragraph's position in the essay (intro, body, or conclusion).
    """

    # Instruction templates, selected by paragraph position.
    INTRO_PROMPT = "Write an intro paragraph to an essay called"
    BODY_PROMPT = "Write a paragraph to an essay about"
    CONCLUSION_PROMPT = "Write a concluding paragraph about"

    def __init__(self, model_name=None):
        """Load the summarisation model and its tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer. Defaults to
                ``"snrspeaks/t5-one-line-summary"`` when ``None``.
        """
        if model_name is None:
            model_name = "snrspeaks/t5-one-line-summary"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _summarize(self, paragraph):
        """Return the model's decoded single-beam-search summary of ``paragraph``."""
        input_ids = self.tokenizer.encode(paragraph, return_tensors="pt", add_special_tokens=True)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=5,
            max_length=35,
            repetition_penalty=4.5,
            length_penalty=1.5,
            early_stopping=True,
            num_return_sequences=1,
        )
        # Only one sequence is requested, so take the first (and only) row.
        return self.tokenizer.decode(
            generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    @classmethod
    def _template_for(cls, index, n_paragraphs):
        """Pick the instruction template for paragraph ``index`` of ``n_paragraphs``.

        The first paragraph is always the intro. The last paragraph is the
        conclusion only when it is not also the first. This keeps one-paragraph
        essays from needing two templates for a single paragraph.
        """
        if index == 0:
            return cls.INTRO_PROMPT
        if index == n_paragraphs - 1:
            return cls.CONCLUSION_PROMPT
        return cls.BODY_PROMPT

    def parse_single(self, essay):
        """Build one instruction prompt per paragraph of ``essay``.

        Args:
            essay: Essay text whose paragraphs are separated by ``"\\n\\n"``.

        Returns:
            A ``(prompts, paragraphs)`` tuple of equal-length lists. Here
            ``paragraphs`` is ``essay.split("\\n\\n")``, and ``prompts[i]`` is
            ``"<template> <summary of paragraphs[i]>"``.

        NOTE(review): blank "paragraphs" (e.g. from a trailing ``"\\n\\n"``)
        are kept and summarised like any other; callers may want to strip
        the essay first.
        """
        essay_paragraphs = essay.split("\n\n")
        n_paragraphs = len(essay_paragraphs)

        # Pairing each paragraph with its template by position means prompts
        # and paragraphs always line up, including for a single-paragraph
        # essay (which previously failed the length assertion).
        prompts = [
            f"{self._template_for(i, n_paragraphs)} {self._summarize(para)}"
            for i, para in enumerate(essay_paragraphs)
        ]

        return prompts, essay_paragraphs
| 87 | |
| 88 | |
| 89 | class EssayReviser(DataAugmenter): |