(t *testing.T)
| 774 | } |
| 775 | |
| 776 | func TestClassifyModelError(t *testing.T) { |
| 777 | t.Parallel() |
| 778 | |
| 779 | tests := []struct { |
| 780 | name string |
| 781 | err error |
| 782 | wantRetryable bool |
| 783 | wantRateLimited bool |
| 784 | wantRetryAfter time.Duration |
| 785 | }{ |
| 786 | {name: "nil", err: nil, wantRetryable: false, wantRateLimited: false}, |
| 787 | {name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false}, |
| 788 | {name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false}, |
| 789 | {name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false}, |
| 790 | // 429 without StatusError (fallback message-pattern path) |
| 791 | {name: "429 message fallback, no RetryAfter", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, |
| 792 | // 429 via StatusError (primary path) — no Retry-After |
| 793 | {name: "429 StatusError no retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 0, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, |
| 794 | // 429 via StatusError with Retry-After from response header |
| 795 | {name: "429 StatusError with retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 20 * time.Second, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 20 * time.Second}, |
| 796 | // Retryable status codes via StatusError |
| 797 | {name: "500 StatusError", err: &StatusError{StatusCode: 500, Err: errors.New("internal server error")}, wantRetryable: true, wantRateLimited: false}, |
| 798 | {name: "529 StatusError", err: &StatusError{StatusCode: 529, Err: errors.New("overloaded")}, wantRetryable: true, wantRateLimited: false}, |
| 799 | {name: "408 StatusError", err: &StatusError{StatusCode: 408, Err: errors.New("timeout")}, wantRetryable: true, wantRateLimited: false}, |
| 800 | // Retryable fallback path (message-based) |
| 801 | {name: "500 message fallback", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false}, |
| 802 | {name: "502 message fallback", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false}, |
| 803 | // Non-retryable via StatusError |
| 804 | {name: "401 StatusError", err: &StatusError{StatusCode: 401, Err: errors.New("unauthorized")}, wantRetryable: false, wantRateLimited: false}, |
| 805 | {name: "403 StatusError", err: &StatusError{StatusCode: 403, Err: errors.New("forbidden")}, wantRetryable: false, wantRateLimited: false}, |
| 806 | // Non-retryable fallback |
| 807 | {name: "401 message fallback", err: errors.New("401 unauthorized"), wantRetryable: false, wantRateLimited: false}, |
| 808 | // 400 with Vertex AI "function response parts" message is treated as transient (issue #2683) |
| 809 | {name: "vertex transient 400 StatusError", err: &StatusError{StatusCode: 400, Err: errors.New("Error 400, Message: Please ensure that the number of function response parts is equal to the number of function call parts of the function call turn., Status: INVALID_ARGUMENT, Details: []")}, wantRetryable: true, wantRateLimited: false}, |
| 810 | {name: "vertex transient 400 wrapped in stream error", err: fmt.Errorf("error receiving from stream: %w", &StatusError{StatusCode: 400, Err: errors.New("number of function response parts")}), wantRetryable: true, wantRateLimited: false}, |
| 811 | {name: "vertex transient 400 message fallback (no StatusError)", err: errors.New("400 Bad Request: Please ensure that the number of function response parts is equal to the number of function call parts"), wantRetryable: true, wantRateLimited: false}, |
| 812 | // Network errors |
| 813 | {name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false}, |
| 814 | } |
| 815 | |
| 816 | for _, tt := range tests { |
| 817 | t.Run(tt.name, func(t *testing.T) { |
| 818 | t.Parallel() |
| 819 | retryable, rateLimited, retryAfterOut := ClassifyModelError(tt.err) |
| 820 | assert.Equal(t, tt.wantRetryable, retryable, "retryable mismatch") |
| 821 | assert.Equal(t, tt.wantRateLimited, rateLimited, "rateLimited mismatch") |
| 822 | assert.Equal(t, tt.wantRetryAfter, retryAfterOut, "retryAfter mismatch") |
| 823 | }) |
| 824 | } |
| 825 | |
| 826 | t.Run("wrapped StatusError is found by errors.As", func(t *testing.T) { |
| 827 | t.Parallel() |
| 828 | statusErr := &StatusError{StatusCode: 429, RetryAfter: 15 * time.Second, Err: errors.New("rate limited")} |
| 829 | wrapped := fmt.Errorf("model failed: %w", statusErr) |
| 830 | retryable, rateLimited, retryAfterOut := ClassifyModelError(wrapped) |
| 831 | assert.False(t, retryable) |
| 832 | assert.True(t, rateLimited) |
| 833 | assert.Equal(t, 15*time.Second, retryAfterOut) |
nothing calls this directly
no test coverage detected