~nickbp/originz

ref: a8d53481eaf10a685fbb4a4e98c49b13dfbc95ff originz/src/cache/redis.rs -rw-r--r-- 9.6 KiB
a8d53481Nick Parker Initialize scratch buffer size, not just capacity 1 year, 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
use std::convert::TryFrom;
use std::time::Duration;

use anyhow::{Context, Result};
use async_trait::async_trait;
use bytes::BytesMut;
use lazy_static::lazy_static;
use redis::{self, IntoConnectionInfo};
use rkyv::{Deserialize, Fallible, Infallible};
use rkyv::ser::{ScratchSpace, Serializer};
use tracing::{debug, info, trace, warn};

use crate::cache::DnsCache;
use crate::codec::message::{self, RequestInfo};
use crate::specs::message::Message;
use crate::specs::version_generated::VERSION_HASH;

lazy_static! {
    /// Lua script returning `[GET key, TTL key]` for `KEYS[1]`, so the cached bytes and their
    /// remaining lifetime come back from a single script invocation instead of two commands.
    static ref GET_WITH_TTL: redis::Script = redis::Script::new(concat!(
        "return {",
        "redis.call('get',KEYS[1]), ",
        "redis.call('ttl',KEYS[1])",
        "}"
    ));
}

/// Redis-backed DNS cache client: looks up and stores serialized DNS responses in a remote
/// Redis instance.
pub struct Cache {
    /// Reusable output buffer that serialized responses are written into before `SETEX`.
    store_buf: BytesMut,
    /// Pre-sized scratch region handed to the rkyv serializer on every store.
    scratch_buf: BytesMut,
    /// Synchronous Redis connection used for every read, write and delete.
    redis_conn: redis::Connection,
}

impl Cache {
    /// Size of the scratch region handed to the rkyv serializer.
    ///
    /// NOTE(review): rkyv's `BufferScratch` appears to use the provided buffer as a fixed-size
    /// region rather than growing it, so a response needing more scratch than this may fail to
    /// serialize. Confirm before relying on larger responses being cacheable.
    const SCRATCH_BUF_SIZE: usize = 1024;

    /// Initial capacity of the serialized-output buffer; it grows on demand when extended.
    const STORE_BUF_CAPACITY: usize = 1024;

    /// Connects to the Redis server described by `url` and prepares reusable buffers.
    ///
    /// When `timeout` is at least one millisecond, it is applied as both the read and write
    /// timeout of the connection. A timeout that rounds down to zero milliseconds leaves the
    /// connection's timeouts unset.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - `url` is not a valid Redis URL.
    /// - The client cannot be created.
    /// - The initial connection fails.
    /// - Either timeout cannot be applied.
    pub fn new(url: &str, timeout: Duration) -> Result<Cache> {
        // We avoid logging valid url string since it may contain a password
        let conn_info = url
            .into_connection_info()
            .with_context(|| format!("Failed to parse '{}' as Redis URL", url))?;
        info!("Connecting to Redis with addr={} timeout={}ms", conn_info.addr, timeout.as_millis());
        let redis_client = redis::Client::open(conn_info.clone())
            .with_context(|| format!("Failed to create Redis client for {}", conn_info.addr))?;
        let redis_conn = redis_client
            .get_connection()
            .with_context(|| format!("Redis connection failed for {}", conn_info.addr))?;
        if timeout.as_millis() != 0 {
            redis_conn
                .set_write_timeout(Some(timeout))
                .with_context(|| format!("Failed to set Redis write timeout for {}", conn_info.addr))?;
            redis_conn
                .set_read_timeout(Some(timeout))
                .with_context(|| format!("Failed to set Redis read timeout for {}", conn_info.addr))?;
        }
        // Scratch wants buffer to already have some size (not just capacity)
        let mut scratch_buf = BytesMut::with_capacity(Self::SCRATCH_BUF_SIZE);
        scratch_buf.resize(scratch_buf.capacity(), 0);
        Ok(Cache {
            redis_conn,
            // Set up a reasonable allocation - will expand automatically if needed.
            store_buf: BytesMut::with_capacity(Self::STORE_BUF_CAPACITY),
            scratch_buf,
        })
    }

    /// Builds the Redis key under which the response for `request_info` is stored.
    ///
    /// The key embeds `VERSION_HASH`, the record type and the queried name. Entries written by a
    /// build with a different `VERSION_HASH` therefore use different keys and are never looked
    /// up here. Presumably this exists to avoid reading archives with an incompatible layout.
    fn key(&self, request_info: &RequestInfo) -> String {
        format!(
            "kapiti_rykv__{}__{:?}__{}",
            VERSION_HASH, request_info.resource_type, request_info.name
        )
    }
}

#[async_trait]
impl DnsCache for Cache {
    /// Queries Redis for a cached response.
    async fn fetch(&mut self, request_info: RequestInfo) -> Result<Option<Message>> {
        let cache_key = self.key(&request_info);
        let response_option: Option<Vec<redis::Value>> = GET_WITH_TTL
            .key(cache_key.clone())
            .invoke(&mut self.redis_conn)
            .with_context(|| format!("Reading cached response failed for {}", request_info.name))?;

        if let Some(response_vec) = response_option {
            return match (response_vec.get(0), response_vec.get(1)) {
                (Some(redis::Value::Data(bytes)), Some(redis::Value::Int(raw_redis_ttl))) => {
                    // Just in case, maybe data wasn't garbage collected yet?
                    if raw_redis_ttl <= &0 {
                        return Ok(None);
                    }
                    let redis_ttl: u32 = u32::try_from(*raw_redis_ttl).with_context(|| {
                        format!("Invalid TTL={} for cache_key='{}'", raw_redis_ttl, cache_key)
                    })?;

                    // Check the returned data and get a nice validation error on any issues.
                    // Skipping this check (unsafe method) meanwhile risks segfaulting the whole process.
                    match rkyv::check_archived_root::<Message>(bytes) {
                        Err(e) => {
                            warn!("Ignoring and deleting corrupt Redis data at key='{}' (ttl={:?}, bytes={:?}): {}", cache_key, redis_ttl, bytes.len(), e);
                            // Try to delete the bad data:
                            let delete_result = redis::cmd("DEL")
                                .arg(&cache_key)
                                .query::<redis::Value>(&mut self.redis_conn);
                            if let Err(e) = delete_result {
                                warn!("Ignoring failed delete of corrupt Redis data at key='{}': {}", cache_key, e);
                            }
                            Ok(None)
                        },
                        Ok(archived) => {
                            let mut redis_response = archived.deserialize(&mut Infallible)?;
                            debug!("Cached response for cache_key='{}': (ttl={:?}, bytes={:?}) {}", cache_key, redis_ttl, bytes.len(), redis_response);

                            if let Err(e) = message::update_cached_response(&mut redis_response, &request_info, redis_ttl) {
                                warn!("Ignoring Redis response at cache_key='{}': {}", cache_key, e);
                                Ok(None)
                            } else {
                                Ok(Some(redis_response))
                            }
                        },
                    }

                }
                (Some(redis::Value::Nil), Some(redis::Value::Int(-2))) => {
                    debug!(
                        "Redis didn't have cached {:?} result for {}",
                        request_info.resource_type,
                        request_info.name
                    );
                    Ok(None)
                }
                (_other_msg, _other_ttl) => {
                    warn!(
                        "Unexpected data in Redis lookup response, bad connection?: {:?}",
                        response_vec
                    );
                    // Give up on cache and query direct
                    Ok(None)
                }
            };
        } else {
            // Shouldn't happen, but may as well handle it
            trace!(
                "Redis didn't have cache {:?} result for {}",
                request_info.resource_type,
                request_info.name
            );
            Ok(None)
        }
    }

    /// Stores an upstream server response Message to the cache.
    async fn store(&mut self, request_info: RequestInfo, response: Message) -> Result<()> {
        match message::get_min_ttl_secs(&response) {
            // This can happen with e.g. a SERVFAIL response
            None => debug!(
                "Skipping storage of {:?} response for {} with missing resources",
                request_info.resource_type, request_info.name
            ),
            // Redis will complain if we try to set with 0 TTL, so don't set anything at all
            // ("invalid expire time in setex")
            Some(0) => debug!(
                "Skipping storage of {:?} result for {} with TTL=0s",
                request_info.resource_type, request_info.name
            ),
            Some(response_min_ttl_secs) => {
                // Serialize message data
                {
                    let mut serializer = BytesMutSerializer::new(&mut self.store_buf, &mut self.scratch_buf);
                    serializer.serialize_value(&response)?;
                }

                let cache_key = self.key(&request_info);
                let store_result = redis::cmd("SETEX")
                    .arg(&cache_key)
                    .arg(response_min_ttl_secs as usize)
                    .arg(&self.store_buf[..])
                    .query::<redis::Value>(&mut self.redis_conn);

                // Clean up buffer for next usage (just updates len=0)
                let size = self.store_buf.len();
                self.store_buf.clear();

                // If cache storage fails, log a warning but don't kill the query.
                match store_result {
                    Ok(_) => debug!(
                        "Stored {} byte response for {:?} {} request to cache_key='{}' with TTL={}s",
                        size, request_info.resource_type, request_info.name, cache_key, response_min_ttl_secs
                    ),
                    Err(e) => warn!(
                        "Failed to store response for {:?} {} request to cache_key='{}', continuing anyway: {:?}",
                        request_info.resource_type, request_info.name, cache_key, e
                    ),
                }
            }
        }
        Ok(())
    }
}

/// rkyv serializer that writes its archive into a caller-owned `BytesMut`.
///
/// Scratch allocations are served from a second caller-owned `BytesMut`.
struct BytesMutSerializer<'a> {
    /// Scratch allocator backed by the borrowed scratch buffer.
    scratch: rkyv::ser::serializers::BufferScratch<&'a mut BytesMut>,
    /// Destination buffer; serialized bytes are appended to its end.
    store: &'a mut BytesMut,
}

impl <'a> BytesMutSerializer<'a> {
    pub fn new(store: &'a mut BytesMut, scratch: &'a mut BytesMut) -> Self {
        Self { store, scratch: rkyv::ser::serializers::BufferScratch::new(scratch) }
    }
}

/// All serializer failures (including scratch exhaustion) are reported as `std::io::Error`.
/// That type converts cleanly into `anyhow::Error` via `?`.
impl Fallible for BytesMutSerializer<'_> {
    type Error = std::io::Error;
}

impl Serializer for BytesMutSerializer<'_> {
    /// Current write position, i.e. the number of bytes already in the store buffer.
    fn pos(&self) -> usize {
        self.store.len()
    }

    /// Appends `bytes` to the store buffer, which grows as needed; never fails.
    fn write(&mut self, bytes: &[u8]) -> core::result::Result<(), Self::Error> {
        self.store.extend_from_slice(bytes);
        Ok(())
    }
}

/// Wraps a scratch-space failure as an `InvalidData` I/O error, the serializer's error type.
fn scratch_error_to_io<E>(err: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    std::io::Error::new(std::io::ErrorKind::InvalidData, err)
}

impl ScratchSpace for BytesMutSerializer<'_> {
    /// Allocates scratch memory for `layout` from the wrapped `BufferScratch`.
    ///
    /// # Safety
    /// Same contract as `ScratchSpace::push_scratch`. This impl only forwards the call and
    /// converts the error type.
    unsafe fn push_scratch(
        &mut self,
        layout: core::alloc::Layout,
    ) -> core::result::Result<core::ptr::NonNull<[u8]>, Self::Error> {
        self.scratch.push_scratch(layout).map_err(scratch_error_to_io)
    }

    /// Releases scratch memory previously handed out for `ptr`/`layout`.
    ///
    /// # Safety
    /// Same contract as `ScratchSpace::pop_scratch`. This impl only forwards the call and
    /// converts the error type.
    unsafe fn pop_scratch(
        &mut self,
        ptr: core::ptr::NonNull<u8>,
        layout: core::alloc::Layout,
    ) -> core::result::Result<(), Self::Error> {
        self.scratch.pop_scratch(ptr, layout).map_err(scratch_error_to_io)
    }
}