// grant/config/role_schema.rs
use super::role::RoleValidate;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
5
6#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
23pub struct RoleSchemaLevel {
24 pub name: String,
25 pub grants: Vec<String>,
26 pub schemas: Vec<String>,
27}
28
29impl RoleSchemaLevel {
30 fn escape_identifier(ident: &str) -> String {
32 format!("\"{}\"", ident.replace("\"", "\"\""))
35 }
36
37 pub fn to_sql(&self, user: &str) -> String {
45 let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
47 "ALL PRIVILEGES".to_string()
48 } else {
49 self.grants.join(", ")
50 };
51
52 let escaped_schemas = self
54 .schemas
55 .iter()
56 .map(|s| Self::escape_identifier(s))
57 .collect::<Vec<_>>()
58 .join(", ");
59 let escaped_user = Self::escape_identifier(user);
60
61 let sql = format!(
63 "GRANT {} ON SCHEMA {} TO {};",
64 grants, escaped_schemas, escaped_user
65 );
66
67 sql
68 }
69}
70
71impl RoleValidate for RoleSchemaLevel {
72 fn validate(&self) -> Result<()> {
73 if self.name.is_empty() {
74 return Err(anyhow!("role name is empty"));
75 }
76
77 if self.schemas.is_empty() {
78 return Err(anyhow!("role schemas is empty"));
79 }
80
81 let valid_grants = vec!["CREATE", "USAGE", "ALL"];
83 let mut grants = HashSet::new();
84 for grant in &self.grants {
85 if !valid_grants.contains(&&grant[..]) {
86 return Err(anyhow!(
87 "invalid grant: {}, expected: {:?}",
88 grant,
89 valid_grants
90 ));
91 }
92 grants.insert(grant.to_string());
93 }
94
95 if self.grants.is_empty() {
96 return Err(anyhow!("role grants is empty"));
97 }
98
99 Ok(())
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a role with the given grants on `schema1` and `schema2`.
    fn role_with_grants(grants: &[&str]) -> RoleSchemaLevel {
        RoleSchemaLevel {
            name: "role_schema_level".to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            schemas: vec!["schema1".to_string(), "schema2".to_string()],
        }
    }

    #[test]
    fn test_role_schema_level() {
        let role_schema_level = role_with_grants(&["CREATE", "USAGE"]);

        // Assert the outcome instead of discarding it with `.ok()`.
        assert!(role_schema_level.validate().is_ok());

        let sql = role_schema_level.to_sql("user");
        assert_eq!(
            sql,
            "GRANT CREATE, USAGE ON SCHEMA \"schema1\", \"schema2\" TO \"user\";"
        );
    }

    #[test]
    fn test_unsupported_grant_is_rejected_but_rendered_verbatim() {
        let role_schema_level = role_with_grants(&["CREATE", "TEMP"]);

        // `TEMP` is not in the list accepted by `validate`.
        assert!(role_schema_level.validate().is_err());

        // `to_sql` does not validate; it joins the grants as given.
        let sql = role_schema_level.to_sql("user");
        assert_eq!(
            sql,
            "GRANT CREATE, TEMP ON SCHEMA \"schema1\", \"schema2\" TO \"user\";"
        );
    }

    #[test]
    fn test_empty_grants_are_rejected() {
        assert!(role_with_grants(&[]).validate().is_err());
    }

    #[test]
    fn test_sql_injection_prevention() {
        let role = RoleSchemaLevel {
            name: "test".to_string(),
            grants: vec!["CREATE".to_string()],
            schemas: vec!["schema\"; DROP SCHEMA public; --".to_string()],
        };

        let sql = role.to_sql("user\"; DROP USER postgres; --");
        // Embedded quotes are doubled, keeping each payload inside its identifier.
        assert!(sql.contains("\"schema\"\"; DROP SCHEMA public; --\""));
        assert!(sql.contains("\"user\"\"; DROP USER postgres; --\""));
    }
}