@@ 0,0 1,169 @@
+use std::error::Error;
+use std::{fmt::Display, path::PathBuf};
+
+use serde::{Deserialize, Serialize};
+
+// NOTE: We are skipping the query and fragment part of the URI for now, as those are not necessary
+// for our WebFinger use cases just now.
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub struct Uri {
+ scheme: String,
+ authority: String,
+ path: PathBuf,
+}
+
+impl TryFrom<String> for Uri {
+ type Error = UriValidationError;
+
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ let mut parts = s.split("://");
+
+ let scheme = parts
+ .next()
+ .ok_or(UriValidationError::MissingScheme)
+ .and_then(validate_scheme)?
+ .into();
+
+ let rest = parts.next().ok_or(UriValidationError::MissingDomain)?;
+
+ if parts.next().is_some() {
+ return Err(UriValidationError::InvalidScheme);
+ }
+
+ let mut parts = rest.split("/");
+ let authority = parts
+ .next()
+ .ok_or(UriValidationError::MissingDomain)
+ .and_then(validate_authority)?
+ .into();
+
+ Ok(Uri {
+ scheme,
+ authority,
+ path: parts.collect(),
+ })
+ }
+}
+
+impl TryFrom<&str> for Uri {
+ type Error = UriValidationError;
+
+ fn try_from(s: &str) -> Result<Self, Self::Error> {
+ Uri::try_from(s.to_string())
+ }
+}
+
+impl Display for Uri {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "{}://{}/{}",
+ self.scheme,
+ self.authority,
+ self.path.display()
+ )
+ }
+}
+
+impl<'de> Deserialize<'de> for Uri {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ Uri::try_from(String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
+ }
+}
+
+impl Serialize for Uri {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ self.to_string().serialize(serializer)
+ }
+}
+
+fn validate_scheme(s: &str) -> Result<&str, UriValidationError> {
+ if s.is_empty() {
+ return Err(UriValidationError::MissingScheme);
+ }
+
+ if s.contains("://") {
+ return Err(UriValidationError::InvalidScheme);
+ }
+
+ if !s
+ .bytes()
+ .all(|c| c.is_ascii_alphabetic() || c == b'+' || c == b'-')
+ {
+ return Err(UriValidationError::InvalidScheme);
+ }
+
+ Ok(s)
+}
+
+fn validate_authority(s: &str) -> Result<&str, UriValidationError> {
+ if s.is_empty() {
+ return Err(UriValidationError::MissingDomain);
+ }
+
+ if s.contains("/") {
+ return Err(UriValidationError::InvalidDomain);
+ }
+
+ Ok(s)
+}
+
+#[derive(Debug)]
+pub enum UriValidationError {
+ MissingScheme,
+ InvalidScheme,
+ MissingDomain,
+ InvalidDomain,
+ InvalidPath,
+}
+
+impl Display for UriValidationError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ UriValidationError::MissingScheme => write!(f, "missing scheme"),
+ UriValidationError::InvalidScheme => write!(f, "invalid scheme"),
+ UriValidationError::MissingDomain => write!(f, "missing domain"),
+ UriValidationError::InvalidDomain => write!(f, "invalid domain"),
+ UriValidationError::InvalidPath => write!(f, "invalid path"),
+ }
+ }
+}
+
+impl Error for UriValidationError {}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_uri_deserialize() {
+ let uri: Uri = serde_json::from_str("\"https://example.com/foo/bar\"").unwrap();
+ assert_eq!(
+ uri,
+ Uri {
+ scheme: "https".to_string(),
+ authority: "example.com".to_string(),
+ path: "foo/bar".into(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_uri_serialize() {
+ let uri = Uri {
+ scheme: "https".to_string(),
+ authority: "example.com".to_string(),
+ path: "foo/bar".into(),
+ };
+ assert_eq!(
+ serde_json::to_string(&uri).unwrap(),
+ "\"https://example.com/foo/bar\""
+ );
+ }
+}