(t *testing.T)
| 258 | } |
| 259 | |
| 260 | func TestCreateOAuthMiddleware(t *testing.T) { |
| 261 | t.Parallel() |
| 262 | |
| 263 | const nextText = "handler-ran" |
| 264 | newNext := func(called *bool) mcp.MethodHandler { |
| 265 | return func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { |
| 266 | *called = true |
| 267 | return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: nextText}}}, nil |
| 268 | } |
| 269 | } |
| 270 | |
| 271 | t.Run("non tool call passes through without authenticating", func(t *testing.T) { |
| 272 | t.Parallel() |
| 273 | fake := &fakeAuthenticator{hasToken: false} |
| 274 | var called bool |
| 275 | mw := createOAuthMiddleware(fake, discardLogger()) |
| 276 | _, err := mw(newNext(&called))(context.Background(), "initialize", &mcp.InitializeRequest{}) |
| 277 | require.NoError(t, err) |
| 278 | assert.True(t, called, "next should run") |
| 279 | assert.Zero(t, fake.authCalls, "authentication must not run for non tool calls") |
| 280 | }) |
| 281 | |
| 282 | t.Run("existing token short circuits authentication", func(t *testing.T) { |
| 283 | t.Parallel() |
| 284 | fake := &fakeAuthenticator{hasToken: true} |
| 285 | var called bool |
| 286 | mw := createOAuthMiddleware(fake, discardLogger()) |
| 287 | _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) |
| 288 | require.NoError(t, err) |
| 289 | assert.True(t, called, "next should run") |
| 290 | assert.Zero(t, fake.authCalls, "authentication must be skipped when a token already exists") |
| 291 | }) |
| 292 | |
| 293 | t.Run("successful authentication proceeds to handler", func(t *testing.T) { |
| 294 | t.Parallel() |
| 295 | fake := &fakeAuthenticator{hasToken: false, outcome: nil, err: nil} |
| 296 | var called bool |
| 297 | mw := createOAuthMiddleware(fake, discardLogger()) |
| 298 | res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) |
| 299 | require.NoError(t, err) |
| 300 | assert.Equal(t, 1, fake.authCalls) |
| 301 | assert.True(t, called, "next should run once authorized") |
| 302 | callRes, ok := res.(*mcp.CallToolResult) |
| 303 | require.True(t, ok) |
| 304 | require.Len(t, callRes.Content, 1) |
| 305 | assert.Equal(t, nextText, callRes.Content[0].(*mcp.TextContent).Text) |
| 306 | }) |
| 307 | |
| 308 | t.Run("pending user action is surfaced as a tool result", func(t *testing.T) { |
| 309 | t.Parallel() |
| 310 | const message = "Open https://example.com/auth to authorize, then retry." |
| 311 | fake := &fakeAuthenticator{hasToken: false, outcome: &oauth.Outcome{UserAction: &oauth.UserAction{Message: message}}} |
| 312 | var called bool |
| 313 | mw := createOAuthMiddleware(fake, discardLogger()) |
| 314 | res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) |
| 315 | require.NoError(t, err) |
| 316 | assert.False(t, called, "next must not run while the user still needs to authorize") |
| 317 | callRes, ok := res.(*mcp.CallToolResult) |
nothing calls this directly
no test coverage detected