grant/config/role_database.rs

1use super::role::RoleValidate;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// Role Database Level.
7///
8/// For example:
9///
10/// ```yaml
11/// - name: role_database_level
12///   type: database
13///   grants:
14///     - CREATE
15///     - TEMP
16///   databases:
17///     - db1
18///     - db2
19/// ```
20#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
21pub struct RoleDatabaseLevel {
22    pub name: String,
23    pub grants: Vec<String>,
24    pub databases: Vec<String>,
25}
26
27impl RoleDatabaseLevel {
28    /// Escape and quote a PostgreSQL identifier to prevent SQL injection
29    fn escape_identifier(ident: &str) -> String {
30        // PostgreSQL identifiers are quoted with double quotes
31        // Escape double quotes by doubling them
32        format!("\"{}\"", ident.replace("\"", "\"\""))
33    }
34
35    /// Generate role database to SQL.
36    ///
37    /// ```sql
38    /// { GRANT | REVOKE } { { CREATE | TEMPORARY | TEMP } [,...] | ALL [ PRIVILEGES ] }
39    /// ON DATABASE db_name [, ...]
40    /// TO { username [ WITH GRANT OPTION ] | GROUP group_name | PUBLIC } [, ...]
41    /// ```
42    pub fn to_sql(&self, user: &str) -> String {
43        // grant all if no grants specified or contains "ALL"
44        let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
45            "ALL PRIVILEGES".to_string()
46        } else {
47            self.grants.join(", ")
48        };
49
50        // escape database and user identifiers to prevent SQL injection
51        let escaped_databases = self
52            .databases
53            .iter()
54            .map(|db| Self::escape_identifier(db))
55            .collect::<Vec<_>>()
56            .join(", ");
57        let escaped_user = Self::escape_identifier(user);
58
59        // grant on databases to user
60        let sql = format!(
61            "GRANT {} ON DATABASE {} TO {};",
62            grants, escaped_databases, escaped_user
63        );
64
65        sql
66    }
67}
68
69impl RoleValidate for RoleDatabaseLevel {
70    fn validate(&self) -> Result<()> {
71        if self.name.is_empty() {
72            return Err(anyhow!("role name is empty"));
73        }
74
75        if self.databases.is_empty() {
76            return Err(anyhow!("role databases is empty"));
77        }
78
79        // Check valid grants: CREATE, TEMP, TEMPORARY, ALL
80        let valid_grants = vec!["CREATE", "TEMP", "TEMPORARY", "ALL"];
81        let mut grants = HashSet::new();
82        for grant in &self.grants {
83            if !valid_grants.contains(&&grant[..]) {
84                return Err(anyhow!(
85                    "invalid grant: {}, expected: {:?}",
86                    grant,
87                    valid_grants
88                ));
89            }
90            grants.insert(grant.to_string());
91        }
92
93        if self.grants.is_empty() {
94            return Err(anyhow!("role grants is empty"));
95        }
96
97        Ok(())
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a role from string slices to keep the test cases compact.
    fn make_role(grants: &[&str], databases: &[&str]) -> RoleDatabaseLevel {
        RoleDatabaseLevel {
            name: "role_database_level".to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            databases: databases.iter().map(|d| d.to_string()).collect(),
        }
    }

    #[test]
    fn test_role_database_level() {
        let role = make_role(&["CREATE", "TEMP"], &["db1", "db2"]);

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT CREATE, TEMP ON DATABASE \"db1\", \"db2\" TO \"user\";"
        );
    }

    #[test]
    fn test_all_grant_becomes_all_privileges() {
        let role = make_role(&["CREATE", "ALL"], &["db1"]);

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT ALL PRIVILEGES ON DATABASE \"db1\" TO \"user\";"
        );
    }

    #[test]
    fn test_validate_rejects_invalid_roles() {
        let mut unnamed = make_role(&["CREATE"], &["db1"]);
        unnamed.name.clear();
        assert!(unnamed.validate().is_err());

        // No databases.
        assert!(make_role(&["CREATE"], &[]).validate().is_err());
        // No grants.
        assert!(make_role(&[], &["db1"]).validate().is_err());
        // Not a database-level privilege.
        assert!(make_role(&["SELECT"], &["db1"]).validate().is_err());
        // Grants are matched case-sensitively.
        assert!(make_role(&["create"], &["db1"]).validate().is_err());
    }

    #[test]
    fn test_sql_injection_prevention() {
        let role = RoleDatabaseLevel {
            name: "test".to_string(),
            grants: vec!["CREATE".to_string()],
            databases: vec!["db1\"; DROP DATABASE postgres; --".to_string()],
        };

        let sql = role.to_sql("user\"; DROP USER postgres; --");
        // Embedded quotes are doubled, so the payload stays inside the quoted identifiers.
        assert_eq!(
            sql,
            r#"GRANT CREATE ON DATABASE "db1""; DROP DATABASE postgres; --" TO "user""; DROP USER postgres; --";"#
        );
    }
}