vector/sources/util/grpc/
decompression.rs

1use std::{
2    cmp,
3    future::Future,
4    io::{self, Write},
5    mem,
6    pin::Pin,
7    sync::LazyLock,
8    task::{Context, Poll},
9};
10
11use bytes::{Buf, BufMut, BytesMut};
12use flate2::write::GzDecoder;
13use futures_util::FutureExt;
14use http::{HeaderValue, Request, Response};
15use hyper::{
16    Body,
17    body::{HttpBody, Sender},
18};
19use tokio::{pin, select};
20use tonic::{Status, body::BoxBody, metadata::AsciiMetadataValue};
21use tower::{Layer, Service};
22use vector_lib::internal_event::{
23    ByteSize, BytesReceived, InternalEventHandle as _, Protocol, Registered,
24};
25
26use crate::internal_events::{GrpcError, GrpcInvalidCompressionSchemeError};
27
// Every gRPC message is preceded by a five byte header:
// - a compressed flag (u8: 0 when the payload is uncompressed, 1 when it is compressed)
// - a length prefix giving the number of payload bytes that follow (u32, big-endian)
const GRPC_MESSAGE_HEADER_LEN: usize = mem::size_of::<u32>() + mem::size_of::<u8>();
// Request header naming the encoding the client applied to its messages.
const GRPC_ENCODING_HEADER: &str = "grpc-encoding";
// Header through which this layer tells clients which encodings it accepts.
const GRPC_ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
34
// The encodings this layer is willing to accept, as advertised through `grpc-accept-encoding`.
// `to_scheme` translates every variant into the `CompressionScheme` used for decompression
// (`identity` translates to `None`). That translation is an exhaustive `match`, so a new variant
// cannot be advertised without also deciding how it is decompressed.
#[derive(Clone, Copy)]
enum AdvertisedEncoding {
    Gzip,
    Zstd,
    Identity,
}

impl AdvertisedEncoding {
    // Every advertised encoding, in the order the names appear in the header value.
    const ALL: &'static [Self] = &[Self::Gzip, Self::Zstd, Self::Identity];

    // Wire name of the encoding as used in `grpc-encoding` / `grpc-accept-encoding`.
    const fn as_str(self) -> &'static str {
        match self {
            Self::Identity => "identity",
            Self::Zstd => "zstd",
            Self::Gzip => "gzip",
        }
    }

    // Looks an encoding up by its exact wire name; anything not in `ALL` yields `None`.
    fn parse(s: &str) -> Option<Self> {
        for candidate in Self::ALL {
            if candidate.as_str() == s {
                return Some(*candidate);
            }
        }
        None
    }

    // `identity` is the gRPC no-op encoding: the request body is sent as-is, so it maps to
    // `None` because there is nothing to decompress.
    const fn to_scheme(self) -> Option<CompressionScheme> {
        match self {
            Self::Identity => None,
            Self::Gzip => Some(CompressionScheme::Gzip),
            Self::Zstd => Some(CompressionScheme::Zstd),
        }
    }
}

// Comma-separated value sent in `grpc-accept-encoding`. It is assembled from
// `AdvertisedEncoding::ALL` rather than written out by hand, so the advertised list always
// matches the encodings this layer actually handles.
static GRPC_ACCEPT_ENCODING_VALUE: LazyLock<String> = LazyLock::new(|| {
    let mut value = String::new();
    for encoding in AdvertisedEncoding::ALL {
        if !value.is_empty() {
            value.push(',');
        }
        value.push_str(encoding.as_str());
    }
    value
});

// The encodings that require actual decompression work on the request body.
enum CompressionScheme {
    Gzip,
    Zstd,
}
88
89impl CompressionScheme {
90    fn from_encoding_header(req: &Request<Body>) -> Result<Option<Self>, Status> {
91        req.headers()
92            .get(GRPC_ENCODING_HEADER)
93            .map(|s| {
94                s.to_str().map(|s| s.to_string()).map_err(|_| {
95                    Status::unimplemented(format!(
96                        "`{GRPC_ENCODING_HEADER}` contains non-visible characters and is not a valid encoding"
97                    ))
98                })
99            })
100            .transpose()
101            .and_then(|value| match value {
102                None => Ok(None),
103                Some(scheme) => match AdvertisedEncoding::parse(&scheme) {
104                    Some(encoding) => Ok(encoding.to_scheme()),
105                    None => Err(Status::unimplemented(format!(
106                        "compression scheme `{scheme}` is not supported"
107                    ))),
108                },
109            })
110            .map_err(|mut status| {
111                status.metadata_mut().insert(
112                    GRPC_ACCEPT_ENCODING_HEADER,
113                    AsciiMetadataValue::try_from(GRPC_ACCEPT_ENCODING_VALUE.as_str())
114                        .expect("advertised encoding value must be valid ASCII"),
115                );
116                status
117            })
118    }
119}
120
// Where the body driver currently is within the stream of length-prefixed gRPC messages.
#[derive(Default)]
enum State {
    // Not enough bytes buffered yet to read the next message header.
    #[default]
    WaitingForHeader,
    // An uncompressed message is being buffered; `overall_len` is the header plus the payload.
    Forward { overall_len: usize },
    // A compressed message is being fed to the decompressor; `remaining` counts the payload
    // bytes that have not been handed over yet.
    Decompress { remaining: usize },
}
132
133enum Decompressor {
134    Gzip(Box<GzDecoder<Vec<u8>>>),
135    Zstd {
136        compressed: Vec<u8>,
137        output_buf: Vec<u8>,
138    },
139}
140
141impl Decompressor {
142    fn new(scheme: &CompressionScheme) -> Result<Self, io::Error> {
143        // Create the backing buffer for the decompressor and set the compression flag to false (0)
144        // and pre-allocate the space for the length prefix, which we'll fill out once we've
145        // finalized the decompressor.
146        let buf = vec![0; GRPC_MESSAGE_HEADER_LEN];
147        match scheme {
148            CompressionScheme::Gzip => Ok(Decompressor::Gzip(Box::new(GzDecoder::new(buf)))),
149            CompressionScheme::Zstd => Ok(Decompressor::Zstd {
150                compressed: Vec::new(),
151                output_buf: buf,
152            }),
153        }
154    }
155
156    fn write_all(&mut self, data: &[u8]) -> io::Result<()> {
157        match self {
158            Decompressor::Gzip(d) => d.write_all(data),
159            Decompressor::Zstd { compressed, .. } => {
160                compressed.extend_from_slice(data);
161                Ok(())
162            }
163        }
164    }
165
166    fn finish(self) -> io::Result<Vec<u8>> {
167        match self {
168            Decompressor::Gzip(d) => (*d).finish(),
169            // Decode directly into output_buf to avoid a temporary intermediate Vec that
170            // decode_all would produce; peak memory is compressed + decompressed rather than
171            // compressed + 2 × decompressed.
172            Decompressor::Zstd {
173                compressed,
174                mut output_buf,
175            } => {
176                zstd::stream::copy_decode(io::Cursor::new(&compressed), &mut output_buf)?;
177                Ok(output_buf)
178            }
179        }
180    }
181}
182
183async fn drive_body_decompression(
184    mut source: Body,
185    mut destination: Sender,
186    scheme: Option<CompressionScheme>,
187) -> Result<usize, Status> {
188    let mut state = State::default();
189    let mut buf = BytesMut::new();
190    let mut decompressor: Option<Decompressor> = None;
191    let mut bytes_received = 0;
192
193    // Drain all message chunks from the body first.
194    while let Some(result) = source.data().await {
195        let chunk = result.map_err(|_| Status::internal("failed to read from underlying body"))?;
196        buf.put(chunk);
197
198        let maybe_message = loop {
199            match state {
200                State::WaitingForHeader => {
201                    // If we don't have enough data yet to even read the gRPC message header, we can't do anything yet.
202                    if buf.len() < GRPC_MESSAGE_HEADER_LEN {
203                        break None;
204                    }
205
206                    // Extract the compressed flag and length prefix.
207                    let (is_compressed, message_len) = {
208                        let header = &buf[..GRPC_MESSAGE_HEADER_LEN];
209
210                        let message_len_raw: u32 = header[1..]
211                            .try_into()
212                            .map(u32::from_be_bytes)
213                            .expect("there must be four bytes remaining in the header slice");
214                        let message_len = message_len_raw
215                            .try_into()
216                            .expect("Vector does not support 16-bit platforms");
217
218                        (header[0] == 1, message_len)
219                    };
220
221                    // Now, if the message is not compressed, then put ourselves into forward mode, where we'll wait for
222                    // the rest of the message to come in -- decoding isn't streaming so there's no benefit there --
223                    // before we emit it.
224                    //
225                    // If the message _is_ compressed, we do roughly the same thing but we shove it into the
226                    // decompressor incrementally because there's no good reason to make both the internal buffer and
227                    // the decompressor buffer expand if we don't have to.
228                    if is_compressed {
229                        // Per the gRPC compression spec, the compressed flag requires a
230                        // negotiated encoding. Reject frames that set it under identity
231                        // (or with no `grpc-encoding` header) rather than silently
232                        // falling back to gzip and masking client/server mismatches.
233                        if scheme.is_none() {
234                            return Err(Status::internal(
235                                "received compressed frame but no compression scheme was negotiated",
236                            ));
237                        }
238
239                        // We skip the header in the buffer because it doesn't matter to the decompressor and we
240                        // recreate it anyways.
241                        buf.advance(GRPC_MESSAGE_HEADER_LEN);
242
243                        state = State::Decompress {
244                            remaining: message_len,
245                        };
246                    } else {
247                        let overall_len = GRPC_MESSAGE_HEADER_LEN + message_len;
248                        state = State::Forward { overall_len };
249                    }
250                }
251                State::Forward { overall_len } => {
252                    // All we're doing at this point is waiting until we have all the bytes for the current gRPC message
253                    // before we emit them to the caller.
254                    if buf.len() < overall_len {
255                        break None;
256                    }
257
258                    // Now that we have all the bytes we need, slice them out of our internal buffer, reset our state,
259                    // and hand the message back to the caller.
260                    let message = buf.split_to(overall_len).freeze();
261                    state = State::WaitingForHeader;
262
263                    bytes_received += overall_len;
264
265                    break Some(message);
266                }
267                State::Decompress { ref mut remaining } => {
268                    if *remaining > 0 {
269                        // We're waiting for `remaining` more bytes to feed to the decompressor before we finalize it and
270                        // generate our new chunk of data. We might have data in our internal buffer, so try and drain that
271                        // first before polling the underlying body for more.
272                        let available = buf.len();
273                        if available > 0 {
274                            // Write the lesser of what the buffer has, or what is remaining for the current message, into
275                            // the decompressor. This is _technically_ synchronous but there's really no way to do it
276                            // asynchronously since we already have the data, and that's the only asynchronous part.
277                            let to_take = cmp::min(available, *remaining);
278                            let d = match &mut decompressor {
279                                Some(d) => d,
280                                slot @ None => {
281                                    let scheme = scheme.as_ref().expect(
282                                        "compressed frames without a negotiated scheme are rejected earlier",
283                                    );
284                                    slot.insert(Decompressor::new(scheme).map_err(|_| {
285                                        Status::internal("failed to initialize decompressor")
286                                    })?)
287                                }
288                            };
289                            if d.write_all(&buf[..to_take]).is_err() {
290                                return Err(Status::internal("failed to write to decompressor"));
291                            }
292
293                            *remaining -= to_take;
294                            buf.advance(to_take);
295                        } else {
296                            break None;
297                        }
298                    } else {
299                        // We don't need any more data, so consume the decompressor, finalize it by updating the length
300                        // prefix, and then pass it back to the caller.
301                        let result = decompressor
302                            .take()
303                            .expect("consumed decompressor when no decompressor was present")
304                            .finish();
305
306                        // The only I/O errors that occur during `finish` should be I/O errors from writing to the internal
307                        // buffer, but `Vec<T>` is infallible in this regard, so this should be impossible without having
308                        // first panicked due to memory exhaustion.
309                        let mut buf = result.map_err(|_| {
310                            Status::internal(
311                                "reached impossible error during decompressor finalization",
312                            )
313                        })?;
314                        bytes_received += buf.len();
315
316                        // Write the length of our decompressed message in the pre-allocated slot for the message's length prefix.
317                        let message_len_actual = buf.len() - GRPC_MESSAGE_HEADER_LEN;
318                        let message_len = u32::try_from(message_len_actual).map_err(|_| {
319                            Status::out_of_range("messages greater than 4GB are not supported")
320                        })?;
321
322                        let message_len_bytes = message_len.to_be_bytes();
323                        let message_len_slot = &mut buf[1..GRPC_MESSAGE_HEADER_LEN];
324                        message_len_slot.copy_from_slice(&message_len_bytes[..]);
325
326                        // Reset our state before returning the decompressed message.
327                        state = State::WaitingForHeader;
328
329                        break Some(buf.into());
330                    }
331                }
332            }
333        };
334
335        if let Some(message) = maybe_message {
336            // We got a decompressed (or passthrough) message chunk, so just forward it to the destination.
337            if destination.send_data(message).await.is_err() {
338                return Err(Status::internal("destination body abnormally closed"));
339            }
340        }
341    }
342
343    // When we've exhausted all the message chunks, we try sending any trailers that came in on the underlying body.
344    let result = source.trailers().await;
345    let maybe_trailers =
346        result.map_err(|_| Status::internal("error reading trailers from underlying body"))?;
347    if let Some(trailers) = maybe_trailers
348        && destination.send_trailers(trailers).await.is_err()
349    {
350        return Err(Status::internal("destination body abnormally closed"));
351    }
352
353    Ok(bytes_received)
354}
355
356async fn drive_request<F, E>(
357    source: Body,
358    destination: Sender,
359    inner: F,
360    bytes_received: Registered<BytesReceived>,
361    scheme: Option<CompressionScheme>,
362) -> Result<Response<BoxBody>, E>
363where
364    F: Future<Output = Result<Response<BoxBody>, E>>,
365    E: std::fmt::Display,
366{
367    let body_decompression = drive_body_decompression(source, destination, scheme);
368
369    pin!(inner);
370    pin!(body_decompression);
371
372    let mut body_eof = false;
373    let mut body_bytes_received = 0;
374
375    let mut result = loop {
376        select! {
377            biased;
378
379            // Drive the inner future, as this will be consuming the message chunks we give it.
380            result = &mut inner => break result,
381
382            // Drive the core decompression loop, reading chunks from the underlying body, decompressing them if needed,
383            // and eventually handling trailers at the end, if they're present.
384            result = &mut body_decompression, if !body_eof => match result {
385                Err(e) => break Ok(e.to_http()),
386                Ok(bytes_received) => {
387                    body_bytes_received = bytes_received;
388                    body_eof = true;
389                },
390            }
391        }
392    };
393
394    // If the response indicates success, then emit the necessary metrics
395    // otherwise emit the error.
396    match &result {
397        Ok(res) if res.status().is_success() => {
398            bytes_received.emit(ByteSize(body_bytes_received));
399        }
400        Ok(res) => {
401            emit!(GrpcError {
402                error: format!("Received {}", res.status())
403            });
404        }
405        Err(error) => {
406            emit!(GrpcError { error: &error });
407        }
408    };
409
410    // Advertise the set of compression schemes this layer can accept to the client.
411    // Since this layer is the single owner of compression negotiation, individual
412    // services no longer call `.accept_compressed(..)` and therefore tonic would not
413    // set this header itself.
414    if let Ok(res) = result.as_mut() {
415        res.headers_mut().insert(
416            GRPC_ACCEPT_ENCODING_HEADER,
417            HeaderValue::from_str(&GRPC_ACCEPT_ENCODING_VALUE)
418                .expect("advertised encoding value must be valid ASCII"),
419        );
420    }
421
422    result
423}
424
425#[derive(Clone)]
426pub struct DecompressionAndMetrics<S> {
427    inner: S,
428    bytes_received: Registered<BytesReceived>,
429}
430
431impl<S> Service<Request<Body>> for DecompressionAndMetrics<S>
432where
433    S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
434    S::Future: Send + 'static,
435    S::Error: std::fmt::Display,
436{
437    type Response = Response<BoxBody>;
438    type Error = S::Error;
439    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
440
441    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
442        self.inner.poll_ready(cx)
443    }
444
445    fn call(&mut self, req: Request<Body>) -> Self::Future {
446        match CompressionScheme::from_encoding_header(&req) {
447            // There was a header for the encoding, but it was either invalid data or a scheme we don't support.
448            Err(status) => {
449                emit!(GrpcInvalidCompressionSchemeError { status: &status });
450                Box::pin(async move { Ok(status.to_http()) })
451            }
452
453            // The request either isn't using compression, or it has indicated compression may be used and we know we
454            // can support decompression based on the indicated compression scheme... so wrap the body to decompress, if
455            // need be, and then track the bytes that flowed through.
456            Ok(scheme) => {
457                let (destination, decompressed_body) = Body::channel();
458                let (mut req_parts, req_body) = req.into_parts();
459                // Since this layer owns compression negotiation and is about to hand the
460                // inner service a fully decompressed body (with the per-message compressed
461                // flag cleared), strip the `grpc-encoding` header so tonic's codegen treats
462                // the request as uncompressed and does not try to validate the encoding
463                // against any per-service `accept_compressed(..)` configuration.
464                if scheme.is_some() {
465                    req_parts.headers.remove(GRPC_ENCODING_HEADER);
466                }
467                let mapped_req = Request::from_parts(req_parts, decompressed_body);
468
469                let inner = self.inner.call(mapped_req);
470
471                drive_request(
472                    req_body,
473                    destination,
474                    inner,
475                    self.bytes_received.clone(),
476                    scheme,
477                )
478                .boxed()
479            }
480        }
481    }
482}
483
484/// A layer for decompressing Tonic request payloads and emitting telemetry for the payload sizes.
485///
486/// In some cases, we configure `tonic` to use compression on requests to save CPU and throughput when sending those
487/// large requests. In the case of Vector-to-Vector communication, this means the Vector v2 source may deal with
488/// compressed requests. The code already transparently handles decompression, but as part of our component
489/// specification, we have specific goals around what event representations we pay attention to.
490///
491/// In the case of tracking bytes sent/received, we always want to track the number of bytes received _after_
492/// decompression to faithfully represent the amount of data being processed by Vector. This poses a problem with the
493/// out-of-the-box `tonic` codegen as there is no hook whatsoever to inspect the raw request payload (after
494/// decompression, if it was compressed at all) prior to the payload being decoded as a Protocol Buffers payload.
495///
496/// This layer wraps the incoming body in our own body type, which allows us to do two things: decompress the payload
497/// before it enters the decoding phase, and emit metrics based on the decompressed payload.
498///
499/// Since we can see the decompressed bytes, and also know if the underlying service responded successfully -- i.e. the
500/// request was valid, and was processed -- we can now report the number of bytes (after decompression) that were
501/// received _and_ processed correctly.
502///
503/// The supported compression schemes are gzip and zstd.
504#[derive(Clone, Default)]
505pub struct DecompressionAndMetricsLayer;
506
507impl<S> Layer<S> for DecompressionAndMetricsLayer {
508    type Service = DecompressionAndMetrics<S>;
509
510    fn layer(&self, inner: S) -> Self::Service {
511        DecompressionAndMetrics {
512            inner,
513            bytes_received: register!(BytesReceived::from(Protocol::from("grpc"))),
514        }
515    }
516}