~nickbp/kapiti

ref: 9d460d0b7ab39ac1e39576b34fa6217049c65edb kapiti/src/cache/retainer.rs -rw-r--r-- 5.0 KiB
9d460d0bNick Parker Update dependencies to latest along with corresponding code changes 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use async_trait::async_trait;
use tracing::{debug, info, warn};

use crate::cache::DnsCache;
use crate::codec::message::{self, RequestInfo};
use crate::specs::message::Message;

/// Client for fetching and storing DNS lookup results in a local Retainer cache
pub struct Cache {
    // Shared handle to the in-process retainer cache, keyed by
    // "<resource_type>__<name>" (see Cache::key).
    cache: Arc<retainer::cache::Cache<String, Message>>,
    // Background expiration-sweep task. Held (never read) so the task stays
    // alive for the lifetime of this Cache — dropping a smol::Task cancels it.
    _monitor: smol::Task<()>,
    // Maximum entry count before the cache is cleared wholesale on the next
    // store(); 0 disables the limit (see store()).
    max_records: usize,
}

impl Cache {
    /// Creates a new in-process cache that holds at most `max_records` entries,
    /// spawning a background task that periodically evicts expired entries.
    ///
    /// `max_records == 0` disables the size limit (see `store()` in the
    /// `DnsCache` impl, which checks it before inserting).
    pub fn new(max_records: usize) -> Cache {
        info!("Using local retainer cache with max_records={}", max_records);
        let cache = Arc::new(retainer::cache::Cache::new());
        let cache_clone = cache.clone();
        // Set up background job to clean expired entries
        let monitor = smol::spawn(async move {
            // Every 3s, check 4 entries for expiration.
            // If >25% were expired, then repeat the check for another 4 entries
            cache_clone.monitor(4, 0.25, Duration::from_secs(3)).await
        });
        Cache {
            cache,
            _monitor: monitor,
            max_records,
        }
    }

    /// Builds the cache key for a request: "<resource_type>__<name>".
    ///
    /// Takes `&self`: key construction is read-only, so there is no reason to
    /// demand an exclusive borrow (the original `self: &mut Cache` receiver
    /// was needlessly restrictive; existing `self.key(...)` call sites are
    /// unaffected).
    fn key(&self, request_info: &RequestInfo) -> String {
        format!("{:?}__{}", request_info.resource_type, request_info.name)
    }
}

#[async_trait]
impl DnsCache for Cache {
    /// Queries the local retainer cache for a cached response.
    /// (NOTE(review): the original comment said "Redis", but this type wraps the
    /// in-process retainer cache — likely copy-pasted from another backend.)
    ///
    /// Returns `Ok(Some(message))` on a hit, with the message's TTLs updated to
    /// the remaining cache lifetime; `Ok(None)` on a miss or when the cached
    /// entry can't be rewritten; `Err` if an entry unexpectedly lacks an
    /// expiration (store() always inserts with a TTL, so that would be a bug).
    async fn fetch(&mut self, request_info: RequestInfo) -> Result<Option<Message>> {
        let cache_key = self.key(&request_info);
        // TODO for now, only remove() exposes the full CacheEntry, while get() only exposes the value
        //      so lets 'remove' the entry to get the expiration, then put it back.
        //      the cache task is operating on a single thread/loop, so this should be safe
        match self.cache.remove(&cache_key).await {
            Some(entry) => {
                let expiration = entry.expiration().with_context(|| {
                    format!("Cache entry lacks expiration for cache_key='{}'", cache_key)
                })?;
                // Time left until the entry would have expired; this becomes the
                // TTL reported to the client below.
                let cache_ttl = expiration.remaining();
                debug!("Cached response for cache_key='{}': (ttl={:?}) {}", cache_key, cache_ttl, entry.value());
                // Clone so we can rewrite the response for this caller without
                // touching the copy that goes back into the cache.
                let mut updated_response = (*entry.value()).clone();
                if let Err(e) = message::update_cached_response(&mut updated_response, &request_info, cache_ttl.as_secs() as u32) {
                    warn!("Ignoring cache response at cache_key='{}': {}", cache_key, e);
                    // Skip putting it back since it's apparently bad anyway.
                    Ok(None)
                } else {
                    // Put the original value back since we just removed it to get the full CacheEntry
                    // TODO drop clone: only here due to "shared reference" via entry
                    //      but this whole thing is a hack anyway so let's just clone for now
                    self.cache.insert(cache_key, (*entry.value()).clone(), *expiration.instant()).await;
                    Ok(Some(updated_response))
                }
            },
            None => {
                debug!(
                    "Cache didn't have cached {:?} result for {}",
                    request_info.resource_type,
                    request_info.name
                );
                Ok(None)
            }
        }
    }

    /// Stores an upstream server response Message to the cache.
    ///
    /// The entry's TTL is the minimum TTL across the response's records.
    /// Responses with no resources (e.g. SERVFAIL) or a 0s TTL are skipped.
    async fn store(&mut self, request_info: RequestInfo, response: Message) -> Result<()> {
        // Before storing, check if the cache has reached the max record count, and clear it if so.
        // This avoids unbounded storage and favors new retrievals over old cached values with long TTLs.
        // max_records == 0 means "no limit".
        if self.max_records > 0 {
            let cache_len = self.cache.len().await;
            if cache_len >= self.max_records {
                warn!("Clearing local cache: size={} max_records={}", cache_len, self.max_records);
                self.cache.clear().await;
            }
        }

        match message::get_min_ttl_secs(&response) {
            // This can happen with e.g. a SERVFAIL response
            None => debug!(
                "Skipping storage of {:?} response for {} with missing resources",
                request_info.resource_type, request_info.name
            ),
            // Not much point in storing something that expires right away
            // (Might not happen in practice?)
            Some(0) => debug!(
                "Skipping storage of {:?} result for {} with TTL=0s",
                request_info.resource_type, request_info.name
            ),
            Some(response_min_ttl) => {
                let cache_key = self.key(&request_info);
                debug!(
                    "Stored response for {:?} {} request to cache_key='{}' with TTL={}s",
                    request_info.resource_type, request_info.name, cache_key, response_min_ttl
                );
                // Insert with the minimum record TTL so the cache entry expires
                // no later than the shortest-lived record it contains.
                self.cache.insert(cache_key, response, Duration::from_secs(response_min_ttl as u64)).await;
            }
        }
        Ok(())
    }
}