| 239 | } |
| 240 | |
| 241 | func (s *Server) csrfProtect(h http.Handler) http.Handler { |
| 242 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 243 | // CSRF is not required for GET, HEAD, or OPTIONS requests. |
| 244 | if slices.Contains([]string{"GET", "HEAD", "OPTIONS"}, r.Method) { |
| 245 | h.ServeHTTP(w, r) |
| 246 | return |
| 247 | } |
| 248 | |
| 249 | // first attempt to use Sec-Fetch-Site header (sent by all modern |
| 250 | // browsers to "potentially trustworthy" origins i.e. localhost or those |
| 251 | // served over HTTPS) |
| 252 | secFetchSite := r.Header.Get("Sec-Fetch-Site") |
| 253 | if secFetchSite == "same-origin" { |
| 254 | h.ServeHTTP(w, r) |
| 255 | return |
| 256 | } else if secFetchSite != "" { |
| 257 | http.Error(w, fmt.Sprintf("CSRF request denied with Sec-Fetch-Site %q", secFetchSite), http.StatusForbidden) |
| 258 | return |
| 259 | } |
| 260 | |
| 261 | // if Sec-Fetch-Site is not available we presume we are operating over HTTP. |
| 262 | // We fall back to comparing the Origin & Host headers. |
| 263 | |
| 264 | // use the Host header to determine the expected origin |
| 265 | // (use the override if set to allow for reverse proxying) |
| 266 | host := r.Host |
| 267 | if host == "" { |
| 268 | http.Error(w, "CSRF request denied with no Host header", http.StatusForbidden) |
| 269 | return |
| 270 | } |
| 271 | if s.originOverride != "" { |
| 272 | host = s.originOverride |
| 273 | } |
| 274 | |
| 275 | originHeader := r.Header.Get("Origin") |
| 276 | if originHeader == "" { |
| 277 | http.Error(w, "CSRF request denied with no Origin header", http.StatusForbidden) |
| 278 | return |
| 279 | } |
| 280 | parsedOrigin, err := url.Parse(originHeader) |
| 281 | if err != nil { |
| 282 | http.Error(w, fmt.Sprintf("CSRF request denied with invalid Origin %q", r.Header.Get("Origin")), http.StatusForbidden) |
| 283 | return |
| 284 | } |
| 285 | origin := parsedOrigin.Host |
| 286 | if origin == "" { |
| 287 | http.Error(w, "CSRF request denied with no host in the Origin header", http.StatusForbidden) |
| 288 | return |
| 289 | } |
| 290 | |
| 291 | if origin != host { |
| 292 | http.Error(w, fmt.Sprintf("CSRF request denied with mismatched Origin %q and Host %q", origin, host), http.StatusForbidden) |
| 293 | return |
| 294 | } |
| 295 | |
| 296 | h.ServeHTTP(w, r) |
| 297 | |
| 298 | }) |