vector/sinks/util/
retries.rs

1use std::{
2    borrow::Cow,
3    cmp,
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::FutureExt;
11use tokio::time::{Sleep, sleep};
12use tower::{retry::Policy, timeout::error::Elapsed};
13use vector_lib::configurable::configurable_component;
14
15use crate::Error;
16
/// The verdict a [`RetryLogic`] reaches after inspecting a response.
///
/// The `Request` parameter only matters for [`RetryAction::RetryPartial`], whose closure receives
/// a clone of the original request and returns the request to send on the next attempt.
pub enum RetryAction<Request = ()> {
    /// Indicate that this request should be retried with a reason
    Retry(Cow<'static, str>),
    /// Indicate that a portion of this request should be retried with a generic function
    RetryPartial(Box<dyn Fn(Request) -> Request + Send + Sync>),
    /// Indicate that this request should not be retried with a reason
    DontRetry(Cow<'static, str>),
    /// Indicate that this request should not be retried but the request was successful
    Successful,
}
27
/// Sink-specific rules that decide whether the outcome of a request warrants a retry.
///
/// [`FibonacciRetryPolicy`] consults an implementation after each service call: errors go through
/// [`RetryLogic::is_retriable_error`], timeouts through [`RetryLogic::is_retriable_timeout`], and
/// successful responses through [`RetryLogic::should_retry_response`].
pub trait RetryLogic: Clone + Send + Sync + 'static {
    /// The error type this logic can classify. The policy downcasts the service error to this
    /// type before calling [`RetryLogic::is_retriable_error`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// The request type passed through a [`RetryAction::RetryPartial`] closure.
    type Request;
    /// The response type inspected by [`RetryLogic::should_retry_response`].
    type Response;

    /// When the Service call returns an `Err` response, this function allows
    /// implementors to specify what kinds of errors can be retried.
    fn is_retriable_error(&self, error: &Self::Error) -> bool;

    /// When the Service call times out, this function allows implementors to
    /// specify if the timeout should be retried.
    ///
    /// Timeouts are retried by default.
    fn is_retriable_timeout(&self) -> bool {
        true
    }

    /// When the Service call returns an `Ok` response, this function allows
    /// implementors to specify additional logic to determine if the success response
    /// is actually an error. This is particularly useful when the downstream service
    /// of a sink returns a transport protocol layer success but error data in the
    /// response body. For example, an HTTP 200 status, but the body of the response
    /// contains a list of errors encountered while processing.
    fn should_retry_response(&self, _response: &Self::Response) -> RetryAction<Self::Request> {
        // Treat the default as the request is successful
        RetryAction::Successful
    }

    /// Optional hook run when an error is determined to be retriable.
    ///
    /// The policy calls it right after [`RetryLogic::is_retriable_error`] returns `true`, before
    /// the retry is scheduled. The default implementation does nothing.
    fn on_retriable_error(&self, _error: &Self::Error) {}
}
57
/// The jitter mode to use for retry backoff behavior.
// NOTE(review): `#[configurable_component]` presumably surfaces the `///` text on this type in
// the generated configuration docs — confirm before rewording any of it.
#[configurable_component]
#[derive(Clone, Copy, Debug, Default)]
pub enum JitterMode {
    /// No jitter.
    // `FibonacciRetryPolicy::backoff` returns the plain (un-jittered) delay in this mode.
    None,

    /// Full jitter.
    ///
    /// The random delay is anywhere from 0 up to the maximum current delay calculated by the backoff
    /// strategy.
    ///
    /// Incorporating full jitter into your backoff strategy can greatly reduce the likelihood
    /// of creating accidental denial of service (DoS) conditions against your own systems when
    /// many clients are recovering from a failure state.
    // `FibonacciRetryPolicy::backoff` returns the randomized delay in this mode.
    #[default]
    Full,
}
76
/// A retry policy whose backoff delays grow along the Fibonacci sequence, capped at a maximum,
/// with optional jitter.
///
/// The decision of *whether* to retry is delegated to the wrapped retry logic `L`.
#[derive(Debug, Clone)]
pub struct FibonacciRetryPolicy<L> {
    /// Number of retries that may still be scheduled; decremented on every advance.
    remaining_attempts: usize,
    /// The delay that preceded `current_duration` in the sequence (starts at zero).
    previous_duration: Duration,
    /// The current un-jittered delay.
    current_duration: Duration,
    /// Selects whether the plain or the jittered delay is reported as the backoff.
    jitter_mode: JitterMode,
    /// A randomized delay derived from `current_duration`; reported when jitter is enabled.
    current_jitter_duration: Duration,
    /// Upper bound applied to each newly computed delay.
    max_duration: Duration,
    /// Sink-specific logic deciding which responses, errors and timeouts are retried.
    logic: L,
}
87
/// The future a retry policy hands back when a request should be retried; it completes once the
/// backoff delay has elapsed.
pub struct RetryPolicyFuture {
    /// The pending backoff timer.
    delay: Pin<Box<Sleep>>,
}
91
92impl<L: RetryLogic> FibonacciRetryPolicy<L> {
93    pub fn new(
94        remaining_attempts: usize,
95        initial_backoff: Duration,
96        max_duration: Duration,
97        logic: L,
98        jitter_mode: JitterMode,
99    ) -> Self {
100        FibonacciRetryPolicy {
101            remaining_attempts,
102            previous_duration: Duration::from_secs(0),
103            current_duration: initial_backoff,
104            jitter_mode,
105            current_jitter_duration: Self::add_full_jitter(initial_backoff),
106            max_duration,
107            logic,
108        }
109    }
110
111    fn add_full_jitter(d: Duration) -> Duration {
112        let jitter = (rand::random::<u64>() % (d.as_millis() as u64)) + 1;
113        Duration::from_millis(jitter)
114    }
115
116    const fn backoff(&self) -> Duration {
117        match self.jitter_mode {
118            JitterMode::None => self.current_duration,
119            JitterMode::Full => self.current_jitter_duration,
120        }
121    }
122
123    fn advance(&mut self) {
124        let sum = self
125            .previous_duration
126            .checked_add(self.current_duration)
127            .unwrap_or(Duration::MAX);
128        let next_duration = cmp::min(sum, self.max_duration);
129        self.remaining_attempts = self.remaining_attempts.saturating_sub(1);
130        self.previous_duration = self.current_duration;
131        self.current_duration = next_duration;
132        self.current_jitter_duration = Self::add_full_jitter(next_duration);
133    }
134
135    fn build_retry(&mut self) -> RetryPolicyFuture {
136        self.advance();
137        let delay = Box::pin(sleep(self.backoff()));
138
139        debug!(message = "Retrying request.", delay_ms = %self.backoff().as_millis());
140        RetryPolicyFuture { delay }
141    }
142}
143
144impl<Req, Res, L> Policy<Req, Res, Error> for FibonacciRetryPolicy<L>
145where
146    Req: Clone + Send + 'static,
147    L: RetryLogic<Request = Req, Response = Res>,
148{
149    type Future = RetryPolicyFuture;
150
151    // NOTE: in the error cases- `Error` and `EventsDropped` internal events are emitted by the
152    // driver, so only need to log here.
153    fn retry(&mut self, req: &mut Req, result: &mut Result<Res, Error>) -> Option<Self::Future> {
154        match result {
155            Ok(response) => match self.logic.should_retry_response(response) {
156                RetryAction::Retry(reason) => {
157                    if self.remaining_attempts == 0 {
158                        error!(
159                            message = "OK/retry response but retries exhausted; dropping the request.",
160                            reason = ?reason,
161                        );
162                        return None;
163                    }
164
165                    warn!(message = "Retrying after response.", reason = %reason);
166                    Some(self.build_retry())
167                }
168                RetryAction::RetryPartial(modify_request) => {
169                    if self.remaining_attempts == 0 {
170                        error!(
171                            message =
172                                "OK/retry response but retries exhausted; dropping the request.",
173                        );
174                        return None;
175                    }
176                    *req = modify_request(req.clone());
177                    warn!("OK/retrying partial after response.");
178                    Some(self.build_retry())
179                }
180                RetryAction::DontRetry(reason) => {
181                    error!(message = "Not retriable; dropping the request.", ?reason);
182                    None
183                }
184
185                RetryAction::Successful => None,
186            },
187            Err(error) => {
188                if self.remaining_attempts == 0 {
189                    error!(message = "Retries exhausted; dropping the request.", %error);
190                    return None;
191                }
192
193                if let Some(expected) = error.downcast_ref::<L::Error>() {
194                    if self.logic.is_retriable_error(expected) {
195                        self.logic.on_retriable_error(expected);
196                        warn!(message = "Retrying after error.", error = %expected);
197                        Some(self.build_retry())
198                    } else {
199                        error!(
200                            message = "Non-retriable error; dropping the request.",
201                            %error,
202                        );
203                        None
204                    }
205                } else if error.downcast_ref::<Elapsed>().is_some() {
206                    if self.logic.is_retriable_timeout() {
207                        warn!(
208                            "Request timed out. If this happens often while the events are actually reaching their destination, try decreasing `batch.max_bytes` and/or using `compression` if applicable. Alternatively `request.timeout_secs` can be increased."
209                        );
210                        Some(self.build_retry())
211                    } else {
212                        error!(
213                            message =
214                                "Request timed out and is not retriable; dropping the request."
215                        );
216                        None
217                    }
218                } else {
219                    error!(
220                        message = "Unexpected error type; dropping the request.",
221                        %error
222                    );
223                    None
224                }
225            }
226        }
227    }
228
229    fn clone_request(&mut self, request: &Req) -> Option<Req> {
230        Some(request.clone())
231    }
232}
233
// Safety: `RetryPolicyFuture` only holds a `Pin<Box<Sleep>>`, which is `Unpin` regardless of
// what it points to, and we use no unsafe pin projections; therefore this is safe.
impl Unpin for RetryPolicyFuture {}
237
238impl Future for RetryPolicyFuture {
239    type Output = ();
240
241    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242        std::task::ready!(self.delay.poll_unpin(cx));
243        Poll::Ready(())
244    }
245}
246
247impl<Request> RetryAction<Request> {
248    pub const fn is_retryable(&self) -> bool {
249        matches!(self, RetryAction::Retry(_) | RetryAction::RetryPartial(_))
250    }
251
252    pub const fn is_not_retryable(&self) -> bool {
253        matches!(self, RetryAction::DontRetry(_))
254    }
255
256    pub const fn is_successful(&self) -> bool {
257        matches!(self, RetryAction::Successful)
258    }
259}
260
// Tests for `FibonacciRetryPolicy`: the async ones drive it through a tower `RetryLayer` around
// a mock service, the synchronous ones check the backoff sequence directly.
#[cfg(test)]
mod tests {
    use std::{fmt, time::Duration};

    use tokio::time;
    use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task};
    use tower::retry::RetryLayer;
    use tower_test::{assert_request_eq, mock};

    use super::*;
    use crate::test_util::trace_init;

    /// A retriable service error is retried after the backoff, and the retry's response is
    /// what the caller finally receives.
    #[tokio::test]
    async fn service_error_retry() {
        trace_init();

        // Pause the clock so the backoff can be skipped with `time::advance`.
        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let fut = svc.call("hello");
        let mut fut = task::spawn(fut);

        // `Error(true)` is classified as retriable by `SvcRetryLogic`.
        assert_request_eq!(handle, "hello").send_error(Error(true));

        // The call is now waiting on the backoff timer.
        assert_pending!(fut.poll());

        // Jump past the backoff; the call stays pending until the retried request is answered.
        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    /// A non-retriable service error is returned to the caller immediately.
    #[tokio::test]
    async fn service_error_no_retry() {
        trace_init();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        // `Error(false)` is classified as non-retriable by `SvcRetryLogic`.
        assert_request_eq!(handle, "hello").send_error(Error(false));
        assert_ready_err!(fut.poll());
    }

    /// A timeout (`Elapsed`) is retried when the retry logic keeps the default
    /// `is_retriable_timeout`.
    #[tokio::test]
    async fn timeout_error() {
        trace_init();

        // Pause the clock so the backoff can be skipped with `time::advance`.
        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        // The call is now waiting on the backoff timer.
        assert_pending!(fut.poll());

        // Jump past the backoff; the call stays pending until the retried request is answered.
        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    /// A timeout is returned to the caller immediately when the retry logic opts out of
    /// retrying timeouts.
    #[tokio::test]
    async fn timeout_error_no_retry() {
        trace_init();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            NoTimeoutRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        assert_ready_err!(fut.poll());
    }

    /// Without jitter the backoff follows 1, 1, 2, 3, 5, 8 seconds and then stays at the
    /// 10 second maximum.
    #[test]
    fn backoff_grows_to_max() {
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(2), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(3), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(5), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(8), policy.backoff());

        // 5 + 8 = 13 seconds would exceed the maximum, so the delay is capped at 10 seconds.
        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());
    }

    /// With full jitter every backoff is non-zero and never exceeds the un-jittered Fibonacci
    /// delay for that step, nor the configured maximum afterwards.
    #[test]
    fn backoff_grows_to_max_with_jitter() {
        let max_duration = Duration::from_secs(10);
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            max_duration,
            SvcRetryLogic,
            JitterMode::Full,
        );

        // Un-jittered delays (in seconds) that bound the jittered value at each step.
        let expected_fib = [1, 1, 2, 3, 5, 8];

        for (i, &exp_fib_secs) in expected_fib.iter().enumerate() {
            let backoff = policy.backoff();
            let upper_bound = Duration::from_secs(exp_fib_secs);

            // Check if the backoff is within the expected range, considering the jitter
            assert!(
                !backoff.is_zero() && backoff <= upper_bound,
                "Attempt {}: Expected backoff to be within 0 and {:?}, got {:?}",
                i + 1,
                upper_bound,
                backoff
            );

            policy.advance();
        }

        // Once the max backoff is reached, it should not exceed the max backoff.
        for _ in 0..4 {
            let backoff = policy.backoff();
            assert!(
                !backoff.is_zero() && backoff <= max_duration,
                "Expected backoff to not exceed {max_duration:?}, got {backoff:?}"
            );

            policy.advance();
        }
    }

    /// Retry logic that retries an error when its flag is `true` and keeps the default
    /// (retriable) timeout behavior.
    #[derive(Debug, Clone)]
    struct SvcRetryLogic;

    impl RetryLogic for SvcRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }
    }

    /// Same as `SvcRetryLogic`, except that timeouts are never retried.
    #[derive(Debug, Clone)]
    struct NoTimeoutRetryLogic;

    impl RetryLogic for NoTimeoutRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }

        fn is_retriable_timeout(&self) -> bool {
            false
        }
    }

    /// Test error whose flag says whether it should be treated as retriable.
    ///
    /// It shares its name with the `Error` imported by the parent module; as a local item it
    /// takes precedence over the `super::*` glob import inside this module.
    #[derive(Debug)]
    struct Error(bool);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "error")
        }
    }

    impl std::error::Error for Error {}
}