#![deny(warnings)]
use std::convert::TryFrom;
use std::path::{Path, PathBuf};
use std::time::Duration;
use std::vec::Vec;
use anyhow::{Context, Result};
use hyper::client::HttpConnector;
use hyper::{Body, Client, Uri};
use hyper_rustls::HttpsConnector;
use rustls::ClientConfig;
use sha2::{Digest, Sha256};
use tracing::{debug, warn};
use tracing_attributes::instrument;
use crate::client::hyper::Resolver;
use crate::filter::{downloader, path, reader};
use crate::http::Fetcher;
/// An iterator over a domain and each of its parent domains, most specific first.
/// For example, www.domain.com => [www.domain.com, domain.com, com]
///
/// An empty domain yields nothing. The yielded items are slices of the provided
/// domain, so iterating does not allocate.
struct DomainParentIter<'a> {
    /// The full domain being walked.
    domain: &'a str,
    /// Byte offset into `domain` where the next item to be returned starts.
    start_idx: usize,
}

impl<'a> DomainParentIter<'a> {
    /// Creates an iterator positioned at the full `domain`.
    ///
    /// Accepts any string slice; a `&String` argument still works through deref coercion.
    fn new(domain: &'a str) -> DomainParentIter<'a> {
        DomainParentIter {
            domain,
            start_idx: 0,
        }
    }
}

impl<'a> Iterator for DomainParentIter<'a> {
    type Item = &'a str;

    /// Returns the remaining suffix of the domain, then advances past its first label.
    fn next(&mut self) -> Option<&'a str> {
        // Copy the reference out so the returned slice borrows from the domain, not from `self`.
        let domain: &'a str = self.domain;
        if self.start_idx >= domain.len() {
            // Seeked past end of domain string, nothing left
            return None;
        }

        // This result: everything from start_idx onwards
        let remainder = &domain[self.start_idx..];

        // Update start for next result
        self.start_idx = match remainder.find('.') {
            // The dot index is relative to `remainder`, which starts at start_idx.
            // Add 1 to seek past the '.' itself.
            Some(dot_idx) => self.start_idx + dot_idx + 1,
            // No more separators: this was the last label.
            None => domain.len(),
        };
        Some(remainder)
    }
}
/// Looks up hostnames against a set of override files and block files.
pub struct Filter {
    /// Entries loaded via `update_override`. Checked before `blocks` for each candidate domain.
    overrides: Vec<reader::FileEntries>,
    /// Entries loaded via `update_block`.
    blocks: Vec<reader::FileEntries>,
    /// Directory passed to `update_url` as the destination for downloaded block files.
    download_dir: PathBuf,
    /// HTTP(S) client used by `update_block` to fetch block files given as URLs.
    fetch_client: Client<HttpsConnector<HttpConnector<Resolver>>, Body>,
}

impl Filter {
    /// Creates a filter with no overrides or blocks loaded yet.
    ///
    /// Fails if the fetch client cannot be built from `resolver`.
    pub fn new(download_dir: PathBuf, resolver: Resolver) -> Result<Filter> {
        build_fetch_client(resolver, 10000).map(|fetch_client| Filter {
            overrides: Vec::new(),
            blocks: Vec::new(),
            download_dir,
            fetch_client,
        })
    }

    /// Reads the override file at `override_path` and records its entries.
    ///
    /// The path is used as both the source path and the local path of the file.
    pub fn update_override(self: &mut Filter, override_path: &String) -> Result<()> {
        let info = reader::FileInfo {
            source_path: override_path.clone(),
            local_path: override_path.clone(),
            filter_type: reader::FilterType::OVERRIDE,
        };
        let loaded = reader::read(info)?;
        // An earlier version of the same file (if any) is handled by upsert_entries.
        upsert_entries(&mut self.overrides, loaded);
        Ok(())
    }

    /// Resolves `hosts_entry` to a local file (see `update_url`), reads it,
    /// and records its entries as a block list.
    ///
    /// A file that yields no entries is ignored.
    #[instrument(skip(self))] // skip non-Debug stuff
    pub async fn update_block(self: &mut Filter, hosts_entry: &String) -> Result<()> {
        let local_path = update_url(&self.fetch_client, &self.download_dir, hosts_entry).await?;
        let loaded = reader::read(reader::FileInfo {
            source_path: hosts_entry.clone(),
            local_path,
            filter_type: reader::FilterType::BLOCK,
        })?;
        if loaded.is_empty() {
            return Ok(());
        }
        // Design note: entries are deliberately NOT deduplicated across block files.
        // Doing so would save memory, but breaks granular per-file updates: if file A carried
        // a hostname that we therefore dropped from file B, and A later stops mentioning it,
        // the hostname would have to be reintroduced into B.
        // A future alternative is to rebuild + dedupe one monolithic tree whenever ANY file
        // changes, at the cost of making it harder to tell which file drove a filter decision.
        upsert_entries(&mut self.blocks, loaded);
        Ok(())
    }

    /// Finds the entry that applies to `host`, along with the info of the file it came from.
    ///
    /// Returns `None` when neither `host` nor any of its parent domains is listed.
    pub fn check(self: &Filter, host: &String) -> Option<(&reader::FileInfo, &reader::FileEntry)> {
        // Walk the domain and its ancestors, most specific first, consulting every file for
        // each candidate before moving on to the next ancestor.
        // For example all files are checked for 'www.example.com', then all again for 'example.com'.
        // That way file B listing 'www.example.com' beats file A listing only 'example.com'.
        // For a given candidate, overrides are consulted before blocks, and within each list
        // the earliest file wins: if file A says "127.0.0.1" and file B says "172.16.0.1"
        // for the same domain, "127.0.0.1" is returned.
        for candidate in DomainParentIter::new(host) {
            let key = candidate.to_string();
            for file in self.overrides.iter().chain(self.blocks.iter()) {
                if let Some(entry) = file.get(&key) {
                    // Hit: upstream decides what to do based on the file info and entry value.
                    return Some((file.info(), entry));
                }
            }
        }
        None
    }
}
/// Build an HTTP(S) client whose connector resolves hostnames via the provided `resolver`,
/// i.e. our configured source DNS server.
/// In particular, this AVOIDS querying the system DNS, which may just loop back to us.
///
/// `timeout_ms` is applied to both the connect timeout and the happy-eyeballs timeout.
/// Returns an error only if no native TLS root certificates could be loaded at all.
fn build_fetch_client(
    resolver: Resolver,
    timeout_ms: u64,
) -> Result<Client<HttpsConnector<HttpConnector<Resolver>>, Body>> {
    let timeout = Duration::from_millis(timeout_ms);

    let mut http_connector = HttpConnector::<_>::new_with_resolver(resolver);
    http_connector.set_connect_timeout(Some(timeout));
    http_connector.set_happy_eyeballs_timeout(Some(timeout));
    http_connector.set_keepalive(Some(Duration::from_secs(90)));
    // Required or else we get errors when trying to pass through https urls, see also HttpsConnector::new_():
    http_connector.enforce_http(false);

    // TLS configuration for the HTTPS connector that wraps the HTTP connector.
    // The result allows HTTPS but doesn't require it.
    let mut https_config = ClientConfig::new();
    https_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    https_config.root_store = match rustls_native_certs::load_native_certs() {
        Ok(certs) => certs,
        // Partial success: keep going with whatever certificates did load.
        Err((Some(certs), e)) => {
            warn!(
                "Some TLS certificates failed to load, trying to continue without them: {:?}",
                e
            );
            certs
        }
        // Nothing usable was loaded.
        Err((None, e)) => {
            return Err(e).with_context(|| "Failed to load native TLS cert store");
        }
    };
    https_config.ct_logs = Some(&ct_logs::LOGS);

    let https_connector = HttpsConnector::from((http_connector, https_config));
    Ok(Client::builder().build::<HttpsConnector<_>, Body>(https_connector))
}
/// Resolves `hosts_path` to the path of a local file that can be read.
///
/// - If `hosts_path` does not parse as a URI, or parses without a scheme, it is treated as a
///   local filesystem path and returned unchanged.
/// - Otherwise it is passed to `downloader::update_file` with a destination inside
///   `download_dir`, and that destination path is returned.
///
/// Errors from the downloader are propagated, as is a destination path that is not valid UTF-8.
async fn update_url(
    fetch_client: &Client<HttpsConnector<HttpConnector<Resolver>>, Body>,
    download_dir: &PathBuf,
    hosts_path: &String,
) -> Result<String> {
    let host_uri = match Uri::try_from(hosts_path) {
        Ok(uri) => uri,
        // Couldn't parse as URL, assume it's a local file
        Err(_) => {
            debug!("file: {}", hosts_path);
            return Ok(hosts_path.to_string());
        }
    };
    if host_uri.scheme().is_none() {
        // Filesystem paths can get parsed as URLs with no scheme
        return Ok(hosts_path.to_string());
    }

    // Parsed as a URL with a scheme, try to download
    let fetcher = Fetcher::new(10 * 1024 * 1024, None);

    // The destination filename is the SHA-256 of the exact URL string we were provided.
    // This is an easy way to avoid filename collisions between URLs such as
    // example1.com/hosts vs example2.com/hosts.
    // If the user changes the URL string then the SHA changes too, which is perfect for
    // "cache invalidation" purposes.
    let url_digest = Sha256::digest(hosts_path.as_bytes());
    let file_name = format!("{:x}.sha256.{}", url_digest, path::ZSTD_EXTENSION);
    let download_path = Path::new(download_dir).join(file_name);

    downloader::update_file(fetch_client, &fetcher, hosts_path, download_path.as_path()).await?;

    let download_path_str = download_path
        .to_str()
        .with_context(|| format!("busted download path: {:?}", download_path))?;
    Ok(download_path_str.to_string())
}
/// Inserts `new_entry` into `entries`, replacing any existing entry for the same file.
///
/// Two entries refer to the same file when their `info.local_path` values are equal.
///
/// - Existing entry found, `new_entry` non-empty: the existing entry is replaced in place,
///   keeping its position in the list.
/// - Existing entry found, `new_entry` empty: the existing entry is removed.
/// - No existing entry, `new_entry` non-empty: `new_entry` is appended.
/// - No existing entry, `new_entry` empty: nothing happens.
fn upsert_entries(entries: &mut Vec<reader::FileEntries>, new_entry: reader::FileEntries) {
    // Before adding new entry, check for existing entry to be replaced/updated.
    let existing_idx = entries
        .iter()
        .position(|entry| entry.info.local_path == new_entry.info.local_path);

    match (existing_idx, new_entry.is_empty()) {
        // The file no longer has any content: drop the stale version.
        (Some(idx), true) => {
            entries.remove(idx);
        }
        // Overwrite the stale version. This must be an assignment and not `Vec::insert`,
        // which would leave the stale version in the list right behind the new one.
        (Some(idx), false) => entries[idx] = new_entry,
        // Nothing to replace and nothing to add.
        (None, true) => {}
        // Add new entry
        (None, false) => entries.push(new_entry),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a `DomainParentIter` over `domain` until it reports `None`,
    /// returning everything it yielded, in order.
    fn parents_of(domain: &str) -> Vec<String> {
        let owned = domain.to_string();
        let parents: Vec<String> = DomainParentIter::new(&owned)
            .map(|parent| parent.to_string())
            .collect();
        parents
    }

    #[test]
    fn iter_empty() {
        assert_eq!(Vec::<String>::new(), parents_of(""));
    }

    #[test]
    fn iter_com() {
        assert_eq!(vec!["com"], parents_of("com"));
    }

    #[test]
    fn iter_domaincom() {
        assert_eq!(vec!["domain.com", "com"], parents_of("domain.com"));
    }

    #[test]
    fn iter_wwwdomaincom() {
        assert_eq!(
            vec!["www.domain.com", "domain.com", "com"],
            parents_of("www.domain.com")
        );
    }

    #[test]
    fn iter_wwwngeeknz() {
        assert_eq!(
            vec!["www.n.geek.nz", "n.geek.nz", "geek.nz", "nz"],
            parents_of("www.n.geek.nz")
        );
    }

    #[test]
    fn iter_averylongteststringwithmanysegments() {
        assert_eq!(
            vec![
                "a.very-long.test.string.with-many.segments",
                "very-long.test.string.with-many.segments",
                "test.string.with-many.segments",
                "string.with-many.segments",
                "with-many.segments",
                "segments",
            ],
            parents_of("a.very-long.test.string.with-many.segments")
        );
    }
}