#![deny(warnings, rust_2018_idioms)]
use std::convert::TryFrom;
use std::io::{self, Write};
use anyhow::{anyhow, bail, Context, Result};
use flate2::write::GzDecoder;
use http::request::Builder;
use hyper::body::HttpBody;
use hyper::header;
use hyper::{Body, HeaderMap, Method, Request, Response, Uri, Version};
use tracing::trace;
/// Shared code for downloading content from HTTP(s) servers.
/// Adds support for gzipped responses, restricting the download size, and checking matching content type.
pub struct Fetcher {
    /// Hard cap on the payload size (decompressed size for gzipped bodies);
    /// downloads that would exceed this are aborted with an error.
    max_length_bytes: usize,
    /// If set, the response's Content-Type must start with this value
    /// (prefix match, so parameters like "; charset=UTF-8" are tolerated).
    expect_content_type: Option<String>,
    /// HTTP protocol version used for outgoing requests; `new` defaults this
    /// to HTTP/1.1, `use_http_2` switches it to HTTP/2.
    http_version: Version,
}
impl Fetcher {
    /// Creates a fetcher that enforces `max_length_bytes` on downloaded payloads and,
    /// when `expect_content_type` is set, requires a matching response Content-Type.
    /// Requests default to HTTP/1.1; see [`Fetcher::use_http_2`].
    pub fn new(max_length_bytes: usize, expect_content_type: Option<String>) -> Fetcher {
        Fetcher {
            max_length_bytes,
            expect_content_type,
            http_version: Version::HTTP_11,
        }
    }

    /// Switches outgoing requests to HTTP/2 (builder-style; consumes and returns `self`).
    pub fn use_http_2(mut self) -> Self {
        self.http_version = Version::HTTP_2;
        self
    }

    /// Builds a request that advertises support for gzipped payloads and acceptance of a specified content type.
    pub fn request_builder(&self, method: &Method, url: &Uri) -> Builder {
        let mut builder = Request::builder()
            .version(self.http_version)
            .method(method)
            .uri(url)
            .header(header::ACCEPT_ENCODING, "gzip")
            // Server may complain if we don't specify something
            .header(header::USER_AGENT, "rust/hyper");
        if let Some(expect_content_type) = &self.expect_content_type {
            builder = builder.header(header::ACCEPT, expect_content_type)
        }
        builder
    }

    /// Builds an empty-body request for `url` via [`Fetcher::request_builder`].
    ///
    /// # Errors
    /// Fails when `url` is not a valid URI or the request cannot be assembled.
    pub fn build_request(&self, method: &Method, url: &str) -> Result<Request<Body>> {
        let uri = Uri::try_from(url)?;
        self.request_builder(method, &uri)
            .body(Body::empty())
            .with_context(|| format!("Failed to build {} {} request", method, url))
    }

    /// Writes the payload from the response `resp` to the provided output `out`.
    /// Supports decompressing gzipped payload data, and will enforce any maximum size and/or content type.
    /// `source` can be used for logging the source URL of the response.
    ///
    /// # Errors
    /// Fails on a non-2xx status, an invalid or oversized Content-Length, a
    /// mismatched Content-Type, a download error, or a body exceeding the
    /// configured maximum size.
    pub async fn write_response<W: Write>(
        &self,
        source: &String,
        out: &mut W,
        resp: &mut Response<Body>,
    ) -> Result<()> {
        if !resp.status().is_success() {
            bail!("{:?} HTTP error: {}", source, resp.status());
        }
        let headers = resp.headers();
        trace!("{} headers: {:?}", source, headers);
        // Servers may not return a Content-Length when gzip is enabled.
        // We enforce size when streaming the body content anyway.
        let content_length_opt =
            match header_to_str_opt(headers, &header::CONTENT_LENGTH, source)? {
                Some(len_str) => {
                    let len = len_str.parse::<usize>().with_context(|| {
                        format!(
                            "{:?} response Content-Length cannot be converted to usize: {}",
                            source, len_str
                        )
                    })?;
                    // Check advertised size, but verify when reading body.
                    if len > self.max_length_bytes {
                        bail!(
                            "{:?} response Content-Length header exceeds maximum {}: {}",
                            source,
                            self.max_length_bytes,
                            len
                        );
                    }
                    Some(len)
                }
                None => None,
            };
        if let Some(expect_content_type) = &self.expect_content_type {
            let content_type = header_to_str(headers, &header::CONTENT_TYPE, source)?;
            // Content-Type: "text/csv; charset=UTF-8; header=present"
            if !content_type.starts_with(expect_content_type) {
                bail!(
                    "{:?} has Content-Type {:?}, expected starts_with({})",
                    source,
                    content_type,
                    expect_content_type
                );
            }
        }
        // Content-Encoding takes precedence; fall back to Transfer-Encoding only
        // when Content-Encoding is absent.
        let encoding_opt = match header_to_str_opt(headers, &header::CONTENT_ENCODING, source)? {
            Some(encoding) => Some(encoding),
            None => header_to_str_opt(headers, &header::TRANSFER_ENCODING, source)?,
        };
        let gzip = encoding_opt.map_or(false, |encoding| Self::is_gzip(&encoding));
        if gzip {
            self.write_body_gzip(resp, out, content_length_opt, source)
                .await
        } else {
            self.write_body_plain(resp, out, content_length_opt, source)
                .await
        }
    }

    /// Returns true when an encoding header value advertises gzip.
    /// Both Content-Encoding and Transfer-Encoding may carry a comma-separated
    /// list of codings (e.g. "gzip, chunked"), and coding names are
    /// case-insensitive (RFC 7230), so check each element rather than the
    /// whole string for equality.
    fn is_gzip(encoding: &str) -> bool {
        encoding
            .split(',')
            .any(|coding| coding.trim().eq_ignore_ascii_case("gzip"))
    }

    /// Streams a gzip-compressed body from `resp` into `out`, decompressing on
    /// the fly. Enforces the advertised Content-Length (when present) on the
    /// compressed bytes and `max_length_bytes` on the decompressed bytes.
    async fn write_body_gzip<W: Write>(
        &self,
        resp: &mut Response<Body>,
        out: &mut W,
        content_length_opt: Option<usize>,
        source: &str,
    ) -> Result<()> {
        let mut downloaded: usize = 0;
        // Gzip compression for response input
        // Keep track of uncompressed bytes to avoid e.g. a malicious payload filling the disk
        let mut decoder = GzDecoder::new(CountingWriter::new(out));
        while let Some(next) = resp.data().await {
            let chunk = next.with_context(|| format!("{:?} failed to download body", source))?;
            if let Some(content_length) = content_length_opt {
                if downloaded + chunk.len() > content_length {
                    bail!(
                        "{:?} compressed response body length exceeds expected Content-Length {}",
                        source,
                        content_length
                    );
                }
            }
            downloaded += chunk.len();
            trace!(
                "got chunk {} => {}/{:?}",
                chunk.len(),
                downloaded,
                content_length_opt
            );
            // Need to use write_all() to actually flush all the input data
            // In the gzip case, regular write() will just consume some of the input
            // and then it's up to us to loop with the remaining input.
            decoder.write_all(&chunk[..])?;
            // Get the uncompressed size (so far) from our underlying CountingWriter
            if decoder.get_ref().count() > self.max_length_bytes {
                bail!(
                    "{:?} uncompressed response body length exceeds max {}",
                    source,
                    self.max_length_bytes
                );
            }
        }
        // finish() can flush further buffered decompressed bytes, so re-check
        // the final size after the decoder is drained.
        let writer = decoder.finish()?;
        if writer.count() > self.max_length_bytes {
            bail!(
                "{:?} uncompressed response body length exceeds max {}",
                source,
                self.max_length_bytes
            );
        }
        Ok(())
    }

    /// Streams an uncompressed body from `resp` into `out`, enforcing the
    /// advertised Content-Length (when present) and `max_length_bytes`.
    async fn write_body_plain<W: Write>(
        &self,
        resp: &mut Response<Body>,
        out: &mut W,
        content_length_opt: Option<usize>,
        source: &str,
    ) -> Result<()> {
        let mut downloaded: usize = 0;
        // No compression for response input
        while let Some(next) = resp.data().await {
            let chunk = next.with_context(|| format!("{:?} failed to download body", source))?;
            if let Some(content_length) = content_length_opt {
                if downloaded + chunk.len() > content_length {
                    bail!(
                        "{:?} response body length exceeds expected Content-Length {}",
                        source,
                        content_length
                    );
                }
            }
            downloaded += chunk.len();
            // The Content-Length header may be absent (or understate the body),
            // so also enforce the configured maximum on bytes actually received.
            if downloaded > self.max_length_bytes {
                bail!(
                    "{:?} response body length exceeds max {}",
                    source,
                    self.max_length_bytes
                );
            }
            trace!(
                "got chunk {} => {}/{:?}",
                chunk.len(),
                downloaded,
                content_length_opt
            );
            // write_all(): a plain write() may be partial and would silently
            // drop the rest of the chunk.
            out.write_all(&chunk[..])?;
        }
        Ok(())
    }
}
fn header_to_str(
headers: &HeaderMap,
header: &header::HeaderName,
origin: &String,
) -> Result<String> {
header_to_str_opt(headers, header, origin)?.with_context(|| {
format!(
"{} response has missing {:?}: {:?}",
origin, header, headers
)
})
}
/// Looks up `header` in `headers`, returning `Ok(None)` when absent and
/// `Ok(Some(value))` when present and representable as a UTF-8 string.
/// `origin` identifies the response source in error messages.
///
/// # Errors
/// Fails when the header value contains bytes that cannot be converted to a
/// visible ASCII string (`HeaderValue::to_str` rejects them).
fn header_to_str_opt(
    headers: &HeaderMap,
    header: &header::HeaderName,
    origin: &str,
) -> Result<Option<String>> {
    match headers.get(header) {
        None => Ok(None),
        Some(header_val) => header_val
            .to_str()
            .map(|header_str| Some(header_str.to_string()))
            .map_err(|e| {
                anyhow!(
                    "Failed to convert {} {:?} to string: {:?}",
                    origin,
                    header,
                    e
                )
            }),
    }
}
/// Pass-through writer that counts the number of bytes that have been written.
/// Used to consistently measure the decompressed size of a download.
struct CountingWriter<W: Write> {
    inner: W,
    count: usize,
}

impl<W: Write> CountingWriter<W> {
    /// Wraps `inner`, starting the byte count at zero.
    fn new(inner: W) -> CountingWriter<W> {
        CountingWriter { inner, count: 0 }
    }

    /// Total number of bytes successfully written through this writer so far.
    fn count(&self) -> usize {
        self.count
    }
}

impl<W: Write> Write for CountingWriter<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let result = self.inner.write(buf);
        // Count only what the inner writer actually accepted: write() may
        // perform a partial write, so adding buf.len() would over-count.
        if let Ok(written) = result {
            self.count += written;
        }
        result
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}