| 70 | |
@app.post("/")
async def create_item(request: Request):
    """Generate a chat reply for the posted prompt and record it in the session.

    The JSON request body is read for these keys:
      - ``prompt``: the user's message (required, must be a string).
      - ``uid``: optional session id. When it is absent or not present in
        ``history_mp``, a fresh UUID4 session with empty history is created.
      - ``max_length`` (default 2048), ``top_p`` (default 0.8) and
        ``temperature`` (default 0.7): forwarded to ``model.generate``.

    Returns:
        A dict with ``response``, the session's ``history`` as
        (query, response) pairs, ``status`` (always 200), ``time`` formatted
        as ``%Y-%m-%d %H:%M:%S`` (taken before generation starts), and the
        session ``uid``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or ``prompt``
            is missing or not a string.
    """
    # Imported locally so this block does not rely on the module's import list.
    from fastapi import HTTPException

    payload = await request.json()
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    # Validate before touching history_mp so a bad request does not allocate
    # an empty session.
    query = payload.get('prompt')
    if not isinstance(query, str):
        raise HTTPException(status_code=400, detail="'prompt' must be a string")

    uid = payload.get('uid')
    if uid is None or uid not in history_mp:
        uid = str(uuid.uuid4())
        history_mp[uid] = []

    # Prompt layout: meta instruction, then every previous turn, then the new
    # query. Each human turn is wrapped as '<|Human|>: ...<eoh>'.
    parts = [meta_instruction]
    for old_query, old_response in history_mp[uid]:
        parts.append('<|Human|>: ' + old_query + '<eoh>' + old_response)
    parts.append('<|Human|>: ' + query + '<eoh>')
    prompt = "".join(parts)

    # NOTE(review): with Hugging Face-style `generate`, max_length presumably
    # bounds prompt + new tokens, so long histories leave less room for the
    # reply — confirm this is intended.
    max_length = payload.get('max_length', 2048)
    top_p = payload.get('top_p', 0.8)
    temperature = payload.get('temperature', 0.7)

    # Named `timestamp` rather than `time` to avoid shadowing the stdlib module.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            max_length=max_length,
            do_sample=True,
            top_k=40,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.02,
            num_return_sequences=1,
            # NOTE(review): hard-coded end-of-sequence id; presumably the
            # model's end-of-message token — confirm against the tokenizer.
            eos_token_id=106068,
            pad_token_id=tokenizer.pad_token_id)

    # Skip the prompt tokens so only the newly generated text is decoded.
    prompt_len = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Rebind to a new list (instead of appending in place), as the original did.
    history_mp[uid] = history_mp[uid] + [(query, response)]

    answer = {
        "response": response,
        "history": history_mp[uid],
        "status": 200,
        "time": timestamp,
        "uid": uid
    }
    # Log format is kept exactly as before, including its leading '", ' fragment.
    log = "[" + timestamp + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    return answer
| 117 | |
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn, bound to 0.0.0.0:19324
    # with a single worker.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=19324,
        workers=1,
    )