(request: Request)
app = FastAPI()


@app.post('/')
async def visual_glm(request: Request):
    """Answer one chat turn for the posted text and image.

    The JSON request body must contain ``text`` and ``image``. ``history``
    is optional and defaults to an empty list. Every key in the body is
    also merged over the default generation parameters, so a client can
    override ``max_length``, ``top_p``, ``top_k`` and ``temperature``.

    Returns:
        A dict with the keys ``result`` (the model answer), ``history``
        (the history returned by ``chat``), ``status`` (always 200) and
        ``time`` (local time formatted as ``%Y-%m-%d %H:%M:%S``).

    Raises:
        KeyError: if ``text`` or ``image`` is missing from the body.
    """
    # request.json() already returns parsed Python objects, so no
    # dumps()/loads() round trip is needed before reading the fields.
    request_data = await request.json()
    print("Start to process request")

    input_text = request_data['text']
    # presumably an encoded image that generate_input() decodes -- confirm
    input_image_encoded = request_data['image']
    # A first-turn request may omit the history entirely.
    # NOTE(review): assumes chat()/generate_input() accept an empty list.
    history = request_data.get('history', [])

    # Defaults; any same-named key in the request body wins.
    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2,
    }
    # NOTE(review): this merges the whole body, so 'text', 'image' and
    # 'history' also end up inside input_para.
    input_para.update(request_data)

    is_zh = is_chinese(input_text)
    input_data = generate_input(input_text, input_image_encoded, history, input_para)
    input_image = input_data['input_image']
    gen_kwargs = input_data['gen_kwargs']

    # NOTE(review): chat() is called synchronously inside this coroutine,
    # so the event loop serves nothing else until it returns.
    with torch.no_grad():
        answer, history, _ = chat(
            None,
            model,
            tokenizer,
            input_text,
            history=history,
            image=input_image,
            max_length=gen_kwargs['max_length'],
            top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'],
            temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )

    # Named 'timestamp' so the stdlib module name 'time' is not shadowed.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return {
        "result": answer,
        "history": history,
        "status": 200,
        "time": timestamp,
    }
| 48 | |
| 49 | |
| 50 | if __name__ == '__main__': |
nothing calls this directly
no test coverage detected