Skip to main content

grant/config/
role_database.rs

1use super::role::RoleValidate;
2use super::sql_utils::escape_identifier;
3use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// Role Database Level.
8///
9/// For example:
10///
11/// ```yaml
12/// - name: role_database_level
13///   type: database
14///   grants:
15///     - CREATE
16///     - TEMP
17///   databases:
18///     - db1
19///     - db2
20/// ```
21#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
22pub struct RoleDatabaseLevel {
23    pub name: String,
24    pub grants: Vec<String>,
25    pub databases: Vec<String>,
26}
27
28impl RoleDatabaseLevel {
29    /// Generate role database to SQL.
30    ///
31    /// ```sql
32    /// { GRANT | REVOKE } { { CREATE | TEMPORARY | TEMP } [,...] | ALL [ PRIVILEGES ] }
33    /// ON DATABASE db_name [, ...]
34    /// TO { username [ WITH GRANT OPTION ] | GROUP group_name | PUBLIC } [, ...]
35    /// ```
36    pub fn to_sql(&self, user: &str) -> String {
37        // grant all if no grants specified or contains "ALL"
38        let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
39            "ALL PRIVILEGES".to_string()
40        } else {
41            self.grants.join(", ")
42        };
43
44        // escape database and user identifiers to prevent SQL injection
45        let escaped_databases = self
46            .databases
47            .iter()
48            .map(|db| escape_identifier(db))
49            .collect::<Vec<_>>()
50            .join(", ");
51        let escaped_user = escape_identifier(user);
52
53        // grant on databases to user
54        let sql = format!(
55            "GRANT {} ON DATABASE {} TO {};",
56            grants, escaped_databases, escaped_user
57        );
58
59        sql
60    }
61}
62
63impl RoleValidate for RoleDatabaseLevel {
64    fn validate(&self) -> Result<()> {
65        if self.name.is_empty() {
66            return Err(anyhow!("role name is empty"));
67        }
68
69        if self.databases.is_empty() {
70            return Err(anyhow!("role databases is empty"));
71        }
72
73        // Check valid grants: CREATE, TEMP, TEMPORARY, ALL
74        let valid_grants = vec!["CREATE", "TEMP", "TEMPORARY", "ALL"];
75        let mut grants = HashSet::new();
76        for grant in &self.grants {
77            if !valid_grants.contains(&&grant[..]) {
78                return Err(anyhow!(
79                    "invalid grant: {}, expected: {:?}",
80                    grant,
81                    valid_grants
82                ));
83            }
84            grants.insert(grant.to_string());
85        }
86
87        if self.grants.is_empty() {
88            return Err(anyhow!("role grants is empty"));
89        }
90
91        Ok(())
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a role named `test` from string slices, to keep the cases short.
    fn role(grants: &[&str], databases: &[&str]) -> RoleDatabaseLevel {
        RoleDatabaseLevel {
            name: "test".to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            databases: databases.iter().map(|d| d.to_string()).collect(),
        }
    }

    #[test]
    fn test_role_database_level() {
        let role = RoleDatabaseLevel {
            name: "role_database_level".to_string(),
            grants: vec!["CREATE".to_string(), "TEMP".to_string()],
            databases: vec!["db1".to_string(), "db2".to_string()],
        };

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT CREATE, TEMP ON DATABASE \"db1\", \"db2\" TO \"user\";"
        );
    }

    #[test]
    fn test_sql_injection_prevention() {
        let role = RoleDatabaseLevel {
            name: "test".to_string(),
            grants: vec!["CREATE".to_string()],
            databases: vec!["db1\"; DROP DATABASE postgres; --".to_string()],
        };

        let sql = role.to_sql("user\"; DROP USER postgres; --");
        // Verify the injection is properly escaped
        assert!(sql.contains("\"db1\"\"; DROP DATABASE postgres; --\""));
        assert!(sql.contains("\"user\"\"; DROP USER postgres; --\""));
    }

    // An explicit "ALL" anywhere in the list collapses to ALL PRIVILEGES.
    #[test]
    fn test_all_grant_renders_all_privileges() {
        let role = role(&["CREATE", "ALL"], &["db1"]);

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT ALL PRIVILEGES ON DATABASE \"db1\" TO \"user\";"
        );
    }

    // `to_sql` treats an empty grant list as ALL PRIVILEGES, even though
    // `validate` rejects such a role.
    #[test]
    fn test_empty_grants_render_all_privileges() {
        let role = role(&[], &["db1"]);

        assert_eq!(
            role.to_sql("user"),
            "GRANT ALL PRIVILEGES ON DATABASE \"db1\" TO \"user\";"
        );
    }

    #[test]
    fn test_validate_rejects_empty_name() {
        let mut role = role(&["CREATE"], &["db1"]);
        role.name = String::new();

        let err = role.validate().unwrap_err();
        assert_eq!(err.to_string(), "role name is empty");
    }

    #[test]
    fn test_validate_rejects_empty_databases() {
        let role = role(&["CREATE"], &[]);

        let err = role.validate().unwrap_err();
        assert_eq!(err.to_string(), "role databases is empty");
    }

    #[test]
    fn test_validate_rejects_empty_grants() {
        let role = role(&[], &["db1"]);

        let err = role.validate().unwrap_err();
        assert_eq!(err.to_string(), "role grants is empty");
    }

    // Table-level privileges such as SELECT are not valid on a database, and
    // matching is case-sensitive.
    #[test]
    fn test_validate_rejects_unknown_grant() {
        let err = role(&["CREATE", "SELECT"], &["db1"])
            .validate()
            .unwrap_err();
        assert!(err.to_string().starts_with("invalid grant: SELECT"));

        assert!(role(&["create"], &["db1"]).validate().is_err());
    }
}