Skip to main content

grant/config/
role_schema.rs

1use super::role::RoleValidate;
2use super::sql_utils::escape_identifier;
3use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// Role Schema Level.
8///
9/// For example:
10///
11/// ```yaml
12/// - name: role_schema_level
13///   type: SCHEMA
14///   grants:
15///     - CREATE
16///     - TEMP
17///   schemas:
18///     - schema1
19///     - schema2
20/// ```
21///
22///  The above example will grant CREATE and TEMP privileges on schema1 and schema2.
23#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
24pub struct RoleSchemaLevel {
25    pub name: String,
26    pub grants: Vec<String>,
27    pub schemas: Vec<String>,
28}
29
30impl RoleSchemaLevel {
31    /// Generate role schema to sql.
32    ///
33    /// ```sql
34    /// { GRANT | REVOKE } { { CREATE | USAGE } [,...] | ALL [ PRIVILEGES ] }
35    /// ON SCHEMA schema_name [, ...]
36    /// TO { username [ WITH GRANT OPTION ] | GROUP group_name | PUBLIC } [, ...]
37    /// ```
38    pub fn to_sql(&self, user: &str) -> String {
39        // grant all privileges if no grants are specified or if grants contains "ALL"
40        let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
41            "ALL PRIVILEGES".to_string()
42        } else {
43            self.grants.join(", ")
44        };
45
46        // escape schema and user identifiers to prevent SQL injection
47        let escaped_schemas = self
48            .schemas
49            .iter()
50            .map(|s| escape_identifier(s))
51            .collect::<Vec<_>>()
52            .join(", ");
53        let escaped_user = escape_identifier(user);
54
55        // grant on schemas to user
56        let sql = format!(
57            "GRANT {} ON SCHEMA {} TO {};",
58            grants, escaped_schemas, escaped_user
59        );
60
61        sql
62    }
63}
64
65impl RoleValidate for RoleSchemaLevel {
66    fn validate(&self) -> Result<()> {
67        if self.name.is_empty() {
68            return Err(anyhow!("role name is empty"));
69        }
70
71        if self.schemas.is_empty() {
72            return Err(anyhow!("role schemas is empty"));
73        }
74
75        // Check valid grants: CREATE, USAGE, ALL
76        let valid_grants = vec!["CREATE", "USAGE", "ALL"];
77        let mut grants = HashSet::new();
78        for grant in &self.grants {
79            if !valid_grants.contains(&&grant[..]) {
80                return Err(anyhow!(
81                    "invalid grant: {}, expected: {:?}",
82                    grant,
83                    valid_grants
84                ));
85            }
86            grants.insert(grant.to_string());
87        }
88
89        if self.grants.is_empty() {
90            return Err(anyhow!("role grants is empty"));
91        }
92
93        Ok(())
94    }
95}
96
// Test
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a role with the given grants and schemas.
    fn role(grants: &[&str], schemas: &[&str]) -> RoleSchemaLevel {
        RoleSchemaLevel {
            name: "role_schema_level".to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            schemas: schemas.iter().map(|s| s.to_string()).collect(),
        }
    }

    #[test]
    fn test_role_schema_level() {
        let role_schema_level = role(&["CREATE", "USAGE"], &["schema1", "schema2"]);

        // Validation must actually succeed, not be silently discarded.
        assert!(role_schema_level.validate().is_ok());

        let sql = role_schema_level.to_sql("user");
        assert_eq!(
            sql,
            "GRANT CREATE, USAGE ON SCHEMA \"schema1\", \"schema2\" TO \"user\";"
        );
    }

    #[test]
    fn test_all_collapses_to_all_privileges() {
        let with_all = role(&["CREATE", "ALL"], &["schema1"]);
        assert!(with_all.validate().is_ok());
        assert_eq!(
            with_all.to_sql("user"),
            "GRANT ALL PRIVILEGES ON SCHEMA \"schema1\" TO \"user\";"
        );

        // `to_sql` also treats an empty grant list as ALL, even though
        // `validate` rejects that configuration.
        assert_eq!(
            role(&[], &["schema1"]).to_sql("user"),
            "GRANT ALL PRIVILEGES ON SCHEMA \"schema1\" TO \"user\";"
        );
    }

    #[test]
    fn test_validate_rejects_invalid_config() {
        // TEMP is not a schema privilege.
        assert!(role(&["CREATE", "TEMP"], &["schema1"]).validate().is_err());
        assert!(role(&[], &["schema1"]).validate().is_err());
        assert!(role(&["CREATE"], &[]).validate().is_err());

        let mut unnamed = role(&["CREATE"], &["schema1"]);
        unnamed.name.clear();
        assert!(unnamed.validate().is_err());
    }

    #[test]
    fn test_sql_injection_prevention() {
        let role = RoleSchemaLevel {
            name: "test".to_string(),
            grants: vec!["CREATE".to_string()],
            schemas: vec!["schema\"; DROP SCHEMA public; --".to_string()],
        };

        let sql = role.to_sql("user\"; DROP USER postgres; --");
        // Verify the injection is properly escaped
        assert!(sql.contains("\"schema\"\"; DROP SCHEMA public; --\""));
        assert!(sql.contains("\"user\"\"; DROP USER postgres; --\""));
    }
}