| 164 | } |
| 165 | |
| 166 | func TestOnNthCompletion(t *testing.T) { |
| 167 | t.Run("callback is only called on n-th invocation", func(t *testing.T) { |
| 168 | var ( |
| 169 | n = 5 // expect invocation on 5th attempt |
| 170 | errCalled = errors.New("called") //nolint:err113 |
| 171 | callbackInvoked int |
| 172 | callback = func() error { |
| 173 | callbackInvoked++ |
| 174 | return errCalled |
| 175 | } |
| 176 | ) |
| 177 | |
| 178 | onNthCompletion := parallelwork.OnNthCompletion(n, callback) |
| 179 | |
| 180 | // before n-th invocation |
| 181 | for range n - 1 { |
| 182 | err := onNthCompletion() |
| 183 | require.NoError(t, err) |
| 184 | require.Equal(t, 0, callbackInvoked) |
| 185 | } |
| 186 | |
| 187 | // on n-th invocation |
| 188 | err := onNthCompletion() |
| 189 | require.Error(t, err) |
| 190 | require.ErrorIs(t, err, errCalled) |
| 191 | require.Equal(t, 1, callbackInvoked) |
| 192 | |
| 193 | // call once again (after n-th invocation) |
| 194 | err = onNthCompletion() |
| 195 | require.NoError(t, err) |
| 196 | require.Equal(t, 1, callbackInvoked) |
| 197 | }) |
| 198 | |
| 199 | t.Run("concurrency-safe", func(t *testing.T) { |
| 200 | var ( |
| 201 | n = 5 // expect invocation on 5th attempt |
| 202 | results = make(chan error, n+1) // we will have n+1, i.e. 6 attempts in total |
| 203 | errCalled = errors.New("called") //nolint:err113 |
| 204 | callbackInvoked atomic.Int32 |
| 205 | wg sync.WaitGroup |
| 206 | callback = func() error { |
| 207 | callbackInvoked.Add(1) |
| 208 | return errCalled |
| 209 | } |
| 210 | ) |
| 211 | |
| 212 | onNthCompletion := parallelwork.OnNthCompletion(n, callback) |
| 213 | |
| 214 | wg.Add(n + 1) |
| 215 | |
| 216 | for range n + 1 { |
| 217 | go func() { |
| 218 | results <- onNthCompletion() |
| 219 | |
| 220 | wg.Done() |
| 221 | }() |
| 222 | } |
| 223 | |