| 29 | |
@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
    """Serve a text-generation request, answering from the cache when possible.

    The cache is keyed on ``(data.prompt, data.max_tokens)``. On a hit, one of
    the stored outputs is picked at random. On a miss, or when no cache is
    configured, the prompt is tokenized, submitted to the engine, decoded, and
    (if a cache exists) added to it.

    Args:
        data: Request payload; ``prompt``, ``max_tokens``, ``top_k``, ``top_p``
            and ``temperature`` are read from it.
        request: Incoming request, used only to build the access-log line.

    Returns:
        A dict of the form ``{"text": <cached or freshly generated text>}``.

    Raises:
        HTTPException: Status 406 when ``QueueFullError`` is raised while the
            request is being submitted to, or awaited from, the engine.
    """
    # ``request.client`` may be None (e.g. when the server cannot determine
    # the peer address), so do not dereference it unconditionally.
    client = request.client
    client_addr = f"{client.host}:{client.port}" if client is not None else "unknown"
    logger.info(f'{client_addr} - "{request.method} {request.url.path}" - {data}')

    # NOTE(review): the key ignores top_k / top_p / temperature, so requests
    # that differ only in sampling parameters share cached outputs -- confirm
    # this is intended.
    key = (data.prompt, data.max_tokens)
    try:
        if cache is None:
            # No cache configured: take the same path as a cache miss.
            raise MissCacheError()
        # Presumably cache.get raises MissCacheError on a miss -- verify
        # against the cache implementation.
        outputs = cache.get(key)
        output = random.choice(outputs)
        logger.info("Cache hit")
    except MissCacheError:
        inputs = tokenizer(data.prompt, truncation=True, max_length=512)
        inputs["max_tokens"] = data.max_tokens
        inputs["top_k"] = data.top_k
        inputs["top_p"] = data.top_p
        inputs["temperature"] = data.temperature
        try:
            # ``data`` stays alive for the whole request, so its id() serves
            # as the handle that pairs submit() with wait().
            uid = id(data)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            # Fall back to the default status phrase when the error carries no
            # message, instead of failing with IndexError on ``e.args[0]``.
            detail = e.args[0] if e.args else None
            raise HTTPException(status_code=406, detail=detail) from e

    return {"text": output}
| 57 | |
| 58 | |
| 59 | @app.on_event("shutdown") |