use std::convert::TryFrom;
use std::fs;
use std::path::{Path, PathBuf};
use std::vec::Vec;

use anyhow::{bail, Context, Result};
use hyper::{Body, Client, Uri};
use sha2::{Digest, Sha256};
use tracing::debug;
use tracing_attributes::instrument;

use crate::filter::{downloader, path, reader};
use crate::{http, hyper_smol, resolver};

/// An iterator that goes over the parent domains of a provided child domain.
/// For example, www.domain.com => [www.domain.com, domain.com, com]
struct DomainParentIter<'a> {
    domain: &'a String,
    start_idx: usize,
}

impl<'a> DomainParentIter<'a> {
    fn new(domain: &'a String) -> DomainParentIter<'a> {
        DomainParentIter {
            domain,
            start_idx: 0,
        }
    }
}

impl<'a> Iterator for DomainParentIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        if self.start_idx >= self.domain.len() {
            // Moved past the end of the domain string, nothing left
            None
        } else {
            // Collect this result: everything from start_idx
            let remainder = &self.domain[self.start_idx..];
            // Update start for the next result
            match remainder.find('.') {
                Some(idx) => {
                    // idx is relative to remainder, which starts at start_idx.
                    // Add 1 to seek past the '.' itself.
                    self.start_idx += idx + 1;
                }
                None => {
                    self.start_idx = self.domain.len();
                }
            }
            Some(remainder)
        }
    }
}

pub struct Filter {
    overrides: Vec<reader::FilterEntries>,
    blocks: Vec<reader::FilterEntries>,
    filters_dir: PathBuf,
    // The connector type parameter is assumed here; it should match whatever
    // hyper_smol::client_kapiti() returns.
    fetch_client: Client<hyper_smol::SmolConnector, Body>,
}

impl Filter {
    pub fn new(filters_dir: PathBuf, resolver: resolver::Resolver) -> Result<Filter> {
        let fetch_client = hyper_smol::client_kapiti(resolver, false, false, 4096);
        if !filters_dir.exists() {
            fs::create_dir(&filters_dir).with_context(|| {
                format!(
                    "Failed to create filter download directory: {:?}",
                    filters_dir
                )
            })?;
        } else if filters_dir.is_file() {
            bail!(
                "Configured filter download directory is a regular file: {:?}",
                filters_dir
            );
        }
        Ok(Filter {
            overrides: vec![],
            blocks: vec![],
            filters_dir,
            fetch_client,
        })
    }

    pub fn update_override(self: &mut Filter, override_path: &String) -> Result<()> {
        let file_entries = reader::read(
            reader::FilterType::OVERRIDE,
            reader::FileInfo {
                source_path: override_path.clone(),
                local_path: override_path.clone(),
            },
        )?;
        // Before adding the new entry, check for an existing entry to be replaced/updated.
        upsert_entries(&mut self.overrides, file_entries);
        Ok(())
    }

    #[instrument(skip(self))] // skip non-Debug stuff
    pub async fn update_block(
        self: &mut Filter,
        hosts_entry: &String,
        timeout_ms: u64,
    ) -> Result<()> {
        let download_path_str = update_url(
            &self.fetch_client,
            &self.filters_dir,
            &hosts_entry,
            timeout_ms,
        )
        .await?;
        let file_entries = reader::read(
            reader::FilterType::BLOCK,
            reader::FileInfo {
                source_path: hosts_entry.clone(),
                local_path: download_path_str,
            },
        )?;
        if !file_entries.is_empty() {
            // Note: In theory we could dedupe entries across different blocks to save memory.
            // However this causes problems if we want to granularly update individual files.
            // For example, if file A had a hostname that we omit from file B, and then file A is
            // updated to no longer mention that hostname, we'd want to reintroduce it into file B.
            // So for now the marginal gain likely isn't worth the complexity, but in the future we
            // could rebuild + dedupe a single monolithic tree each time ANY file is updated.
            // BUT this makes it harder to attribute filter decisions since everything will be merged.
            upsert_entries(&mut self.blocks, file_entries);
        }
        Ok(())
    }

    pub fn set_hardcoded_block(self: &mut Filter, block_names: &[&str]) -> Result<()> {
        let hardcoded_entries = reader::block_hardcoded(block_names)?;
        upsert_entries(&mut self.blocks, hardcoded_entries);
        Ok(())
    }

    pub fn check(
        self: &Filter,
        host: &String,
    ) -> Option<(&Option<reader::FileInfo>, &reader::FilterEntry)> {
        // Go over domains in ancestor order, checking all blocks for each ancestor.
        // For example, check all files for 'www.example.com', then each again for 'example.com'.
        // This allows file B with 'www.example.com' to take precedence over file A with 'example.com'.
        // Meanwhile, if two files mention the exact same domain, then the first file in the list wins.
        // So if file A says "127.0.0.1" and file B says "172.16.0.1" then "127.0.0.1" wins.
        for domain_str in DomainParentIter::new(&host) {
            let domain = domain_str.to_string();
            for override_entry in &self.overrides {
                match override_entry.get(&domain) {
                    // Found in an override file: tell upstream to let it through or use the provided override value
                    Some(entry) => return Some((&override_entry.info, entry)),
                    None => {}
                }
            }
            for block in &self.blocks {
                match block.get(&domain) {
                    // Found in a block file: tell upstream to block it or use the filter-provided override
                    Some(entry) => return Some((&block.info, entry)),
                    None => {}
                }
            }
        }
        return None;
    }
}
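
// Usage sketch (illustrative, not from the original source): how a caller might wire up a
// Filter. The filesystem path, URL, hostname, and the way `resolver::Resolver` is obtained
// are assumptions; only the Filter method signatures come from this file.
//
//     let mut filter = Filter::new(PathBuf::from("/tmp/kapiti-filters"), resolver)?;
//     filter.set_hardcoded_block(&["ads.example.com"])?;
//     filter.update_override(&"/etc/kapiti/overrides.hosts".to_string())?;
//     filter.update_block(&"https://example.com/hosts.txt".to_string(), 5000).await?;
//     match filter.check(&"ads.example.com".to_string()) {
//         Some((file_info, entry)) => { /* respond per `entry`, attributing to `file_info` */ }
//         None => { /* pass the query upstream as usual */ }
//     }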

async fn update_url(
    fetch_client: &Client<hyper_smol::SmolConnector, Body>,
    filters_dir: &PathBuf,
    hosts_path: &String,
    timeout_ms: u64,
) -> Result<String> {
    if let Ok(host_uri) = Uri::try_from(hosts_path) {
        // Parsed as a URL, try to download
        if host_uri.scheme() == None {
            // Filesystem paths can get parsed as URLs with no scheme
            return Ok(hosts_path.to_string());
        }
        let fetcher = http::Fetcher::new(10 * 1024 * 1024, None);

        // We download files to the exact SHA of the URL string we were provided.
        // This is an easy way to avoid filename collisions in URLs: example1.com/hosts vs example2.com/hosts.
        // If the user changes the URL string then that changes the SHA, which is perfect for "cache invalidation" purposes.
        let hosts_path_sha = Sha256::digest(hosts_path.as_bytes());
        let download_path = Path::new(filters_dir).join(format!(
            "{:x}.sha256.{}",
            hosts_path_sha,
            path::ZSTD_EXTENSION
        ));
        downloader::update_file(
            fetch_client,
            &fetcher,
            &hosts_path.to_string(),
            download_path.as_path(),
            timeout_ms,
        )
        .await?;
        Ok(download_path
            .to_str()
            .with_context(|| format!("busted download path: {:?}", download_path))?
            .to_string())
    } else {
        // Couldn't parse as a URL, assume it's a local file
        debug!("file: {}", hosts_path);
        Ok(hosts_path.to_string())
    }
}

fn upsert_entries(entries: &mut Vec<reader::FilterEntries>, new_entry: reader::FilterEntries) {
    if let Some(new_file_info) = &new_entry.info {
        // Before adding a new file entry, check for an existing file entry to be replaced/updated.
        for i in 0..entries.len() {
            let entry = entries.get(i).expect("incoherent vector size");
            if let Some(existing_file_info) = &entry.info {
                if existing_file_info.local_path == new_file_info.local_path {
                    // Delete or replace the existing version
                    if new_entry.is_empty() {
                        entries.remove(i);
                    } else {
                        // Assign in place so the stale version doesn't linger behind the new one
                        entries[i] = new_entry;
                    }
                    return;
                }
            }
        }
    }

    // Not an update to an existing entry: add it, if it has anything in it
    if !new_entry.is_empty() {
        entries.push(new_entry);
    }
}
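
// Worked example (illustrative, not from the original source) of the precedence rules
// implemented by Filter::check() and upsert_entries():
//
// - Suppose block file A maps "example.com" and block file B maps "www.example.com". A lookup
//   for "www.example.com" walks DomainParentIter ancestors longest-first, so file B's exact
//   "www.example.com" entry matches before file A's broader "example.com" entry.
// - If files A and B both map the exact same name, the first file in the list (A) wins,
//   because check() scans `overrides` and then `blocks` in insertion order.
// - Re-reading a file with the same local_path replaces its previous entries in place via
//   upsert_entries(), and an update that yields no entries removes that file from the list.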

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn iter_empty() {
        let domain = "".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_com() {
        let domain = "com".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(Some("com"), iter.next());
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_domaincom() {
        let domain = "domain.com".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(Some("domain.com"), iter.next());
        assert_eq!(Some("com"), iter.next());
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_wwwdomaincom() {
        let domain = "www.domain.com".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(Some("www.domain.com"), iter.next());
        assert_eq!(Some("domain.com"), iter.next());
        assert_eq!(Some("com"), iter.next());
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_wwwngeeknz() {
        let domain = "www.n.geek.nz".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(Some("www.n.geek.nz"), iter.next());
        assert_eq!(Some("n.geek.nz"), iter.next());
        assert_eq!(Some("geek.nz"), iter.next());
        assert_eq!(Some("nz"), iter.next());
        assert_eq!(None, iter.next());
    }

    #[test]
    fn iter_averylongteststringwithmanysegments() {
        let domain = "a.very-long.test.string.with-many.segments".to_string();
        let mut iter = DomainParentIter::new(&domain);
        assert_eq!(
            Some("a.very-long.test.string.with-many.segments"),
            iter.next()
        );
        assert_eq!(
            Some("very-long.test.string.with-many.segments"),
            iter.next()
        );
        assert_eq!(Some("test.string.with-many.segments"), iter.next());
        assert_eq!(Some("string.with-many.segments"), iter.next());
        assert_eq!(Some("with-many.segments"), iter.next());
        assert_eq!(Some("segments"), iter.next());
        assert_eq!(None, iter.next());
    }
}
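
// Additional sketch test (not from the original source): demonstrates the SHA-256-based
// download filename scheme described in update_url(). The URL strings are illustrative;
// only the hashing and hex formatting mirror the code above.
#[cfg(test)]
mod naming_sketch_tests {
    use sha2::{Digest, Sha256};

    #[test]
    fn download_name_tracks_url_string() {
        // The empty string is used here only because its SHA-256 is a well-known constant.
        let empty_sha = format!("{:x}", Sha256::digest("".as_bytes()));
        assert_eq!(
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
            empty_sha
        );

        // Different URL strings (even ones pointing at identically-named files) hash to
        // different local filenames, while the same string always maps to the same name.
        let a = format!("{:x}", Sha256::digest("https://example1.com/hosts".as_bytes()));
        let b = format!("{:x}", Sha256::digest("https://example2.com/hosts".as_bytes()));
        assert_ne!(a, b);
        assert_eq!(
            a,
            format!("{:x}", Sha256::digest("https://example1.com/hosts".as_bytes()))
        );
    }
}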