grant/config/
role_database.rs1use super::role::RoleValidate;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
21pub struct RoleDatabaseLevel {
22 pub name: String,
23 pub grants: Vec<String>,
24 pub databases: Vec<String>,
25}
26
27impl RoleDatabaseLevel {
28 fn escape_identifier(ident: &str) -> String {
30 format!("\"{}\"", ident.replace("\"", "\"\""))
33 }
34
35 pub fn to_sql(&self, user: &str) -> String {
43 let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
45 "ALL PRIVILEGES".to_string()
46 } else {
47 self.grants.join(", ")
48 };
49
50 let escaped_databases = self
52 .databases
53 .iter()
54 .map(|db| Self::escape_identifier(db))
55 .collect::<Vec<_>>()
56 .join(", ");
57 let escaped_user = Self::escape_identifier(user);
58
59 let sql = format!(
61 "GRANT {} ON DATABASE {} TO {};",
62 grants, escaped_databases, escaped_user
63 );
64
65 sql
66 }
67}
68
69impl RoleValidate for RoleDatabaseLevel {
70 fn validate(&self) -> Result<()> {
71 if self.name.is_empty() {
72 return Err(anyhow!("role name is empty"));
73 }
74
75 if self.databases.is_empty() {
76 return Err(anyhow!("role databases is empty"));
77 }
78
79 let valid_grants = vec!["CREATE", "TEMP", "TEMPORARY", "ALL"];
81 let mut grants = HashSet::new();
82 for grant in &self.grants {
83 if !valid_grants.contains(&&grant[..]) {
84 return Err(anyhow!(
85 "invalid grant: {}, expected: {:?}",
86 grant,
87 valid_grants
88 ));
89 }
90 grants.insert(grant.to_string());
91 }
92
93 if self.grants.is_empty() {
94 return Err(anyhow!("role grants is empty"));
95 }
96
97 Ok(())
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a role with a fixed name from string-slice fixtures.
    fn role_with(grants: &[&str], databases: &[&str]) -> RoleDatabaseLevel {
        RoleDatabaseLevel {
            name: "role_database_level".to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            databases: databases.iter().map(|d| d.to_string()).collect(),
        }
    }

    #[test]
    fn test_role_database_level() {
        let role = RoleDatabaseLevel {
            name: "role_database_level".to_string(),
            grants: vec!["CREATE".to_string(), "TEMP".to_string()],
            databases: vec!["db1".to_string(), "db2".to_string()],
        };

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT CREATE, TEMP ON DATABASE \"db1\", \"db2\" TO \"user\";"
        );
    }

    #[test]
    fn test_all_grant_renders_all_privileges() {
        let role = role_with(&["ALL"], &["db1"]);

        assert!(role.validate().is_ok());
        assert_eq!(
            role.to_sql("user"),
            "GRANT ALL PRIVILEGES ON DATABASE \"db1\" TO \"user\";"
        );
    }

    #[test]
    fn test_temporary_is_accepted() {
        assert!(role_with(&["TEMPORARY"], &["db1"]).validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_empty_name() {
        let mut role = role_with(&["CREATE"], &["db1"]);
        role.name.clear();

        assert!(role.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_empty_databases() {
        assert!(role_with(&["CREATE"], &[]).validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unknown_grant() {
        assert!(role_with(&["SELECT"], &["db1"]).validate().is_err());
    }

    #[test]
    fn test_validate_rejects_empty_grants() {
        assert!(role_with(&[], &["db1"]).validate().is_err());
    }

    #[test]
    fn test_sql_injection_prevention() {
        let role = RoleDatabaseLevel {
            name: "test".to_string(),
            grants: vec!["CREATE".to_string()],
            databases: vec!["db1\"; DROP DATABASE postgres; --".to_string()],
        };

        let sql = role.to_sql("user\"; DROP USER postgres; --");
        // Embedded quotes must be doubled so the payload stays inside the identifier.
        assert!(sql.contains("\"db1\"\"; DROP DATABASE postgres; --\""));
        assert!(sql.contains("\"user\"\"; DROP USER postgres; --\""));
    }
}