(input, chatbot, max_length, top_p, temperature, history)
| 113 | |
| 114 | |
def predict(input, chatbot, max_length, top_p, temperature, history):
    """Generate one reply for the chat UI and record it in the conversation.

    The prompt is ``meta_instruction`` followed by every past turn in
    ``history`` (each rendered as ``'<|Human|>: ' + query + '<eoh>' + reply``)
    and finally the new user turn. The module-level ``tokenizer`` and
    ``model`` are used to sample a continuation on the GPU.

    Args:
        input: Raw user message; it is passed through ``parse_text`` before
            being used.
        chatbot: List of ``(query, reply)`` display pairs. It is mutated in
            place: a pair is appended and then overwritten with the reply.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        history: Sequence of ``(query, raw_reply)`` pairs from earlier turns.
            It is not mutated; a new list is returned instead.

    Returns:
        A ``(chatbot, history)`` tuple: the same ``chatbot`` list with the new
        reply filled in, and a new history list with the latest turn appended.
    """
    query = parse_text(input)
    # Show the user's turn right away with an empty reply placeholder.
    chatbot.append((query, ""))

    # Render the whole conversation (past turns + the new one) as one prompt.
    segments = [
        '<|Human|>: ' + past_query + '<eoh>' + past_reply
        for past_query, past_reply in history
    ]
    segments.append('<|Human|>: ' + query + '<eoh>')
    prompt = meta_instruction + "".join(segments)

    encoded = tokenizer(prompt, return_tensors="pt")
    sampling_options = dict(
        attention_mask=encoded.attention_mask.cuda(),
        max_length=max_length,
        do_sample=True,
        top_k=40,
        top_p=top_p,
        temperature=temperature,
        num_return_sequences=1,
        # NOTE(review): presumably the model's end-of-turn token id --
        # confirm against the tokenizer vocabulary.
        eos_token_id=106068,
        pad_token_id=tokenizer.pad_token_id,
    )
    with torch.no_grad():
        generated = model.generate(encoded.input_ids.cuda(), **sampling_options)
        # Drop the prompt tokens so only the newly generated part is decoded.
        prompt_token_count = encoded.input_ids.shape[1]
        response = tokenizer.decode(
            generated[0][prompt_token_count:], skip_special_tokens=True)

    # The UI gets the reply without the speaker tag; history keeps the raw
    # text because it is fed back into the next prompt verbatim.
    displayed_reply = parse_text(response.replace("<|MOSS|>: ", ""))
    chatbot[-1] = (query, displayed_reply)
    history = history + [(query, response)]
    print(f"chatbot is {chatbot}")
    print(f"history is {history}")

    return chatbot, history
| 144 | |
| 145 | |
| 146 | def reset_user_input(): |
nothing calls this directly
no test coverage detected