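UsageMiddleware records token usage for inference requests via the billing.Recorder. Two paths produce a record: 1. Handler-stamped (preferred): the request handler calls middleware.StampUsage with the canonical token counts before returning. This is the only reliable path for streaming responses. 2. Body-parsed (fallback): if the handler did not stamp usage, the middleware parses the token counts from the captured response body, using its Content-Type.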
(recorder *billing.Recorder, fallbackUser *auth.User)
| 52 | // Every request that fails to produce a record ticks |
| 53 | // localai_usage_unrecorded_total so silent billing misses are observable. |
| 54 | func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc { |
| 55 | return func(next echo.HandlerFunc) echo.HandlerFunc { |
| 56 | return func(c echo.Context) error { |
| 57 | if recorder == nil { |
| 58 | return next(c) |
| 59 | } |
| 60 | |
| 61 | startTime := time.Now() |
| 62 | |
| 63 | // Wrap response writer to capture body for the fallback parser. |
| 64 | // When the handler stamps the context we never read this buffer, |
| 65 | // so the cost is the per-chunk Write going through one extra |
| 66 | // indirection — accepted overhead in exchange for one billing |
| 67 | // path that works for both stamping and body-parse callers. |
| 68 | resBody := new(bytes.Buffer) |
| 69 | origWriter := c.Response().Writer |
| 70 | mw := &bodyWriter{ |
| 71 | ResponseWriter: origWriter, |
| 72 | body: resBody, |
| 73 | } |
| 74 | c.Response().Writer = mw |
| 75 | |
| 76 | handlerErr := next(c) |
| 77 | |
| 78 | c.Response().Writer = origWriter |
| 79 | |
| 80 | endpoint := c.Request().URL.Path |
| 81 | |
| 82 | if c.Response().Status < 200 || c.Response().Status >= 300 { |
| 83 | return handlerErr |
| 84 | } |
| 85 | |
| 86 | user := auth.GetUser(c) |
| 87 | if user == nil { |
| 88 | user = fallbackUser |
| 89 | } |
| 90 | if user == nil || user.ID == "" { |
| 91 | billing.CountUnrecorded(context.Background(), endpoint, "no_user") |
| 92 | return handlerErr |
| 93 | } |
| 94 | |
| 95 | model, prompt, completion, total, ok := tokensFromContext(c) |
| 96 | if !ok { |
| 97 | model, prompt, completion, total, ok = tokensFromBody(resBody.Bytes(), c.Response().Header().Get("Content-Type")) |
| 98 | } |
| 99 | if !ok { |
| 100 | billing.CountUnrecorded(context.Background(), endpoint, "no_usage") |
| 101 | return handlerErr |
| 102 | } |
| 103 | |
| 104 | requested, served := modelsFromContext(c, model) |
| 105 | pre, post := promptTokensFromContext(c, prompt) |
| 106 | |
| 107 | source := auth.GetSource(c) |
| 108 | if source == "" { |
| 109 | // Auth disabled or unrecognised path: classify as web so the row is still |
| 110 | // bucketable rather than silently dropped from per-source aggregates. |
| 111 | source = auth.UsageSourceWeb |
no test coverage detected