// grant/config/role_database.rs

use std::collections::HashSet;

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};

use super::role::RoleValidate;
use super::sql_utils::escape_identifier;

7#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
22pub struct RoleDatabaseLevel {
23 pub name: String,
24 pub grants: Vec<String>,
25 pub databases: Vec<String>,
26}
27
28impl RoleDatabaseLevel {
29 pub fn to_sql(&self, user: &str) -> String {
37 let grants = if self.grants.is_empty() || self.grants.contains(&"ALL".to_string()) {
39 "ALL PRIVILEGES".to_string()
40 } else {
41 self.grants.join(", ")
42 };
43
44 let escaped_databases = self
46 .databases
47 .iter()
48 .map(|db| escape_identifier(db))
49 .collect::<Vec<_>>()
50 .join(", ");
51 let escaped_user = escape_identifier(user);
52
53 let sql = format!(
55 "GRANT {} ON DATABASE {} TO {};",
56 grants, escaped_databases, escaped_user
57 );
58
59 sql
60 }
61}
62
63impl RoleValidate for RoleDatabaseLevel {
64 fn validate(&self) -> Result<()> {
65 if self.name.is_empty() {
66 return Err(anyhow!("role name is empty"));
67 }
68
69 if self.databases.is_empty() {
70 return Err(anyhow!("role databases is empty"));
71 }
72
73 let valid_grants = vec!["CREATE", "TEMP", "TEMPORARY", "ALL"];
75 let mut grants = HashSet::new();
76 for grant in &self.grants {
77 if !valid_grants.contains(&&grant[..]) {
78 return Err(anyhow!(
79 "invalid grant: {}, expected: {:?}",
80 grant,
81 valid_grants
82 ));
83 }
84 grants.insert(grant.to_string());
85 }
86
87 if self.grants.is_empty() {
88 return Err(anyhow!("role grants is empty"));
89 }
90
91 Ok(())
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a role from borrowed string slices to keep the fixtures terse.
    fn role(name: &str, grants: &[&str], databases: &[&str]) -> RoleDatabaseLevel {
        RoleDatabaseLevel {
            name: name.to_string(),
            grants: grants.iter().map(|g| g.to_string()).collect(),
            databases: databases.iter().map(|d| d.to_string()).collect(),
        }
    }

    #[test]
    fn test_role_database_level() {
        let r = role("role_database_level", &["CREATE", "TEMP"], &["db1", "db2"]);

        assert!(r.validate().is_ok());

        let expected = r#"GRANT CREATE, TEMP ON DATABASE "db1", "db2" TO "user";"#;
        assert_eq!(r.to_sql("user"), expected);
    }

    #[test]
    fn test_sql_injection_prevention() {
        let r = role("test", &["CREATE"], &[r#"db1"; DROP DATABASE postgres; --"#]);
        let sql = r.to_sql(r#"user"; DROP USER postgres; --"#);

        // Embedded double quotes must be doubled so each payload stays inside
        // its quoted identifier.
        assert!(sql.contains(r#""db1""; DROP DATABASE postgres; --""#));
        assert!(sql.contains(r#""user""; DROP USER postgres; --""#));
    }
}