ruma_identifiers_validation/server_name.rs

1use crate::error::Error;
2
3pub fn validate(server_name: &str) -> Result<(), Error> {
4    use std::net::Ipv6Addr;
5
6    if server_name.is_empty() {
7        return Err(Error::InvalidServerName);
8    }
9
10    let end_of_host = if server_name.starts_with('[') {
11        let Some(end_of_ipv6) = server_name.find(']') else {
12            return Err(Error::InvalidServerName);
13        };
14
15        if server_name[1..end_of_ipv6].parse::<Ipv6Addr>().is_err() {
16            return Err(Error::InvalidServerName);
17        }
18
19        end_of_ipv6 + 1
20    } else {
21        #[allow(clippy::unnecessary_lazy_evaluations)]
22        let end_of_host = server_name.find(':').unwrap_or_else(|| server_name.len());
23
24        if end_of_host == 0 {
25            return Err(Error::InvalidServerName);
26        }
27
28        if server_name[..end_of_host]
29            .bytes()
30            .any(|byte| !(byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'.'))
31        {
32            return Err(Error::InvalidServerName);
33        }
34
35        end_of_host
36    };
37
38    if server_name.len() != end_of_host
39        && (
40            // hostname is followed by something other than ":port"
41            server_name.as_bytes()[end_of_host] != b':'
42            // the remaining characters after ':' are not a valid port
43            || server_name[end_of_host + 1..].parse::<u16>().is_err()
44        )
45    {
46        Err(Error::InvalidServerName)
47    } else {
48        Ok(())
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::validate;
    use crate::user_id;

    /// A bare `:port` with no hostname in front of it must not validate as a
    /// server name.
    #[test]
    fn rejects_hostless_server_name_with_port() {
        let outcome = validate(":8448");
        assert!(outcome.is_err());
    }

    /// The same hostless server name (`:8448`) must also make the whole user ID
    /// `@alice::8448` invalid.
    #[test]
    fn rejects_user_id_with_hostless_server_name() {
        let outcome = user_id::validate("@alice::8448");
        assert!(outcome.is_err());
    }
}