toolInvokeHandler handles the API request to invoke a specific Tool.
(s *Server, w http.ResponseWriter, r *http.Request)
| 127 | |
| 128 | // toolInvokeHandler handles the API request to invoke a specific Tool. |
| 129 | func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { |
| 130 | ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke") |
| 131 | r = r.WithContext(ctx) |
| 132 | ctx = util.WithLogger(r.Context(), s.logger) |
| 133 | |
| 134 | toolName := chi.URLParam(r, "toolName") |
| 135 | s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) |
| 136 | span.SetAttributes(attribute.String("tool_name", toolName)) |
| 137 | var err error |
| 138 | defer func() { |
| 139 | if err != nil { |
| 140 | span.SetStatus(codes.Error, err.Error()) |
| 141 | } |
| 142 | span.End() |
| 143 | }() |
| 144 | |
| 145 | tool, ok := s.ResourceMgr.GetTool(toolName) |
| 146 | if !ok { |
| 147 | err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) |
| 148 | s.logger.DebugContext(ctx, err.Error()) |
| 149 | _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) |
| 150 | return |
| 151 | } |
| 152 | |
| 153 | // Extract OAuth access token from the "Authorization" header (currently for |
| 154 | // BigQuery end-user credentials usage only) |
| 155 | accessToken := tools.AccessToken(r.Header.Get("Authorization")) |
| 156 | |
| 157 | // Check if this specific tool requires the standard authorization header |
| 158 | clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) |
| 159 | if err != nil { |
| 160 | errMsg := fmt.Errorf("error during invocation: %w", err) |
| 161 | s.logger.DebugContext(ctx, errMsg.Error()) |
| 162 | _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) |
| 163 | return |
| 164 | } |
| 165 | if clientAuth { |
| 166 | if accessToken == "" { |
| 167 | err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") |
| 168 | s.logger.DebugContext(ctx, err.Error()) |
| 169 | _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) |
| 170 | return |
| 171 | } |
| 172 | } |
| 173 | |
| 174 | // Tool authentication |
| 175 | // claimsFromAuth maps the name of the authservice to the claims retrieved from it. |
| 176 | claimsFromAuth := make(map[string]map[string]any) |
| 177 | for _, aS := range s.ResourceMgr.GetAuthServiceMap() { |
| 178 | var claims map[string]any |
| 179 | var err error |
| 180 | |
| 181 | cfg := aS.ToConfig() |
| 182 | if genCfg, ok := cfg.(generic.Config); ok && genCfg.McpEnabled { |
| 183 | claims = util.AuthTokenClaimsFromContext(ctx) |
| 184 | } else { |
| 185 | claims, err = aS.GetClaimsFromHeader(ctx, r.Header) |
| 186 | if err != nil { |
no test coverage detected