MCPcopy
hub / github.com/LAION-AI/Open-Assistant / EssayInstructor

Class EssayInstructor

scripts/data_augment/data_augment.py:51–86  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

49
50
class EssayInstructor(DataAugmenter):
    """Build (instruction, paragraph) pairs from an essay.

    The essay is split on double newlines into paragraphs. Each paragraph is
    summarised by a seq2seq model, and the summary is appended to an
    instruction template chosen by the paragraph's position in the essay.
    The positions are intro (first), body (middle) and conclusion (last).
    """

    # Checkpoint loaded when no ``model_name`` is given to ``__init__``.
    _DEFAULT_MODEL_NAME = "snrspeaks/t5-one-line-summary"

    # Instruction templates; the paragraph summary is appended after a space.
    _INTRO_PROMPT = "Write an intro paragraph to an essay called"
    _BODY_PROMPT = "Write a paragraph to an essay about"
    _CONCLUSION_PROMPT = "Write a concluding paragraph about"

    def __init__(self, model_name=None):
        """Load the summarisation model and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``. When ``None``, ``_DEFAULT_MODEL_NAME``
                is used.
        """
        if model_name is None:
            model_name = self._DEFAULT_MODEL_NAME
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def parse_single(self, essay):
        """Turn one essay into instruction prompts, one per paragraph.

        Args:
            essay: Essay text; paragraphs are separated by a blank line
                (two consecutive newline characters).

        Returns:
            A ``(prompts, essay_paragraphs)`` tuple of equal-length lists.
            ``prompts[i]`` is the instruction whose target is
            ``essay_paragraphs[i]``.
        """
        # str.split always returns at least one element, so last_index >= 0.
        essay_paragraphs = essay.split("\n\n")
        last_index = len(essay_paragraphs) - 1

        prompts = [
            self._template_for(index, last_index) + " " + self._summarize(paragraph)
            for index, paragraph in enumerate(essay_paragraphs)
        ]
        return prompts, essay_paragraphs

    @classmethod
    def _template_for(cls, index, last_index):
        """Return the instruction template for the paragraph at ``index``."""
        # The first paragraph is checked first, so a single-paragraph essay
        # gets the intro template. The previous list-concatenation approach
        # produced two templates for one paragraph and failed its length assert.
        if index == 0:
            return cls._INTRO_PROMPT
        if index == last_index:
            return cls._CONCLUSION_PROMPT
        return cls._BODY_PROMPT

    def _summarize(self, paragraph):
        """Return a short model-generated summary of ``paragraph``.

        Uses beam search with 5 beams and caps the output at 35 tokens. The
        repetition and length penalties are kept exactly as originally tuned.
        """
        input_ids = self.tokenizer.encode(paragraph, return_tensors="pt", add_special_tokens=True)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=5,
            max_length=35,
            repetition_penalty=4.5,
            length_penalty=1.5,
            early_stopping=True,
            num_return_sequences=1,
        )
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
87
88
89class EssayReviser(DataAugmenter):

Callers 1

get_augmenter (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected