github.com/mudler/LocalAI

Function UsageMiddleware

core/http/middleware/usage.go:54–146

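UsageMiddleware records token usage for inference requests via the billing.Recorder. Two paths produce a record:

1. Handler-stamped (preferred): the request handler calls middleware.StampUsage with the canonical token counts before returning. This is the only reliable path for streaming responses.
2. Body-parsed (fallback): when the context carries no stamp, the middleware parses the buffered response body (tokensFromBody), using the response's Content-Type header.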

(recorder *billing.Recorder, fallbackUser *auth.User)


// Every request that fails to produce a record ticks
// localai_usage_unrecorded_total so silent billing misses are observable.
func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if recorder == nil {
				return next(c)
			}

			startTime := time.Now()

			// Wrap response writer to capture body for the fallback parser.
			// When the handler stamps the context we never read this buffer,
			// so the cost is the per-chunk Write going through one extra
			// indirection — accepted overhead in exchange for one billing
			// path that works for both stamping and body-parse callers.
			resBody := new(bytes.Buffer)
			origWriter := c.Response().Writer
			mw := &bodyWriter{
				ResponseWriter: origWriter,
				body:           resBody,
			}
			c.Response().Writer = mw

			handlerErr := next(c)

			c.Response().Writer = origWriter

			endpoint := c.Request().URL.Path

			if c.Response().Status < 200 || c.Response().Status >= 300 {
				return handlerErr
			}

			user := auth.GetUser(c)
			if user == nil {
				user = fallbackUser
			}
			if user == nil || user.ID == "" {
				billing.CountUnrecorded(context.Background(), endpoint, "no_user")
				return handlerErr
			}

			model, prompt, completion, total, ok := tokensFromContext(c)
			if !ok {
				model, prompt, completion, total, ok = tokensFromBody(resBody.Bytes(), c.Response().Header().Get("Content-Type"))
			}
			if !ok {
				billing.CountUnrecorded(context.Background(), endpoint, "no_usage")
				return handlerErr
			}

			requested, served := modelsFromContext(c, model)
			pre, post := promptTokensFromContext(c, prompt)

			source := auth.GetSource(c)
			if source == "" {
				// Auth disabled or unrecognised path: classify as web so the row is still
				// bucketable rather than silently dropped from per-source aggregates.
				source = auth.UsageSourceWeb
				// … (excerpt ends here; lines 112–146 are not shown)

Callers 4

RegisterOllamaRoutes (function) · 0.92
RegisterAnthropicRoutes (function) · 0.92
RegisterOpenAIRoutes (function) · 0.92

Calls 15

GetUser (function) · 0.92
CountUnrecorded (function) · 0.92
GetSource (function) · 0.92
GetAPIKey (function) · 0.92
tokensFromContext (function) · 0.85
tokensFromBody (function) · 0.85
modelsFromContext (function) · 0.85
promptTokensFromContext (function) · 0.85
correlationIDFromContext (function) · 0.85
Header (method) · 0.80
Request (method) · 0.65
Get (method) · 0.65

Tested by

no test coverage detected