1use super::role::RoleValidate;
2use super::sql_utils::escape_identifier;
3use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
32pub struct RoleTableLevel {
33 pub name: String,
34 pub grants: Vec<String>,
35 pub schemas: Vec<String>,
36 pub tables: Vec<String>,
37}
38
39#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
40struct Table {
41 name: String,
42 sign: String,
43}
44
45impl Table {
46 fn new(name: &str) -> Self {
47 let sign = match name.chars().next() {
48 Some('+') => "+".to_string(),
49 Some('-') => "-".to_string(),
50 _ => "+".to_string(),
51 };
52 let name = name.trim_start_matches(&sign).to_string();
53
54 Self { name, sign }
55 }
56}
57
58impl RoleTableLevel {
59 pub fn to_sql(&self, user: &str) -> String {
67 let mut sqls = vec![];
68 let mut tables = self
69 .tables
70 .iter()
71 .map(|t| Table::new(t))
72 .collect::<Vec<Table>>();
73
74 let grants = if self.grants.contains(&"ALL".to_string()) {
76 "ALL PRIVILEGES".to_string()
77 } else {
78 self.grants.join(", ")
79 };
80
81 let escaped_schemas = self
83 .schemas
84 .iter()
85 .map(|s| escape_identifier(s))
86 .collect::<Vec<_>>();
87 let escaped_user = escape_identifier(user);
88
89 let has_all_grant = tables.iter().any(|t| t.name == "ALL" && t.sign == "+");
91 let has_all_revoke = tables.iter().any(|t| t.name == "ALL" && t.sign == "-");
92
93 if let Some(table_named_all) = tables.iter().find(|t| t.name == "ALL") {
95 let schema_list = escaped_schemas.join(", ");
96 let sql = match table_named_all.sign.as_str() {
97 "+" => format!(
98 "GRANT {} ON ALL TABLES IN SCHEMA {} TO {};",
99 grants, schema_list, escaped_user
100 ),
101 "-" => format!(
102 "REVOKE {} ON ALL TABLES IN SCHEMA {} FROM {};",
103 grants, schema_list, escaped_user
104 ),
105 _ => "".to_string(),
106 };
107 sqls.push(sql);
108 }
109
110 if !has_all_grant {
113 let grant_tables = tables
114 .iter()
115 .filter(|x| x.sign == "+" && x.name != "ALL")
116 .collect::<Vec<_>>();
117 if !grant_tables.is_empty() {
118 let _with_schema = grant_tables
119 .iter()
120 .flat_map(|t| {
121 if t.name.contains('.') {
122 let parts: Vec<&str> = t.name.split('.').collect();
124 if parts.len() == 2 {
125 vec![format!(
126 "{}.{}",
127 escape_identifier(parts[0]),
128 escape_identifier(parts[1])
129 )]
130 } else {
131 vec![escape_identifier(&t.name)]
132 }
133 } else {
134 self.schemas
135 .iter()
136 .map(|s| {
137 format!(
138 "{}.{}",
139 escape_identifier(s),
140 escape_identifier(&t.name)
141 )
142 })
143 .collect::<Vec<_>>()
144 }
145 })
146 .collect::<Vec<String>>()
147 .join(", ");
148
149 let sql = format!("GRANT {} ON {} TO {};", grants, _with_schema, escaped_user);
150 sqls.push(sql);
151 }
152 }
153
154 let revoke_tables = tables
156 .iter()
157 .filter(|x| x.sign == "-" && x.name != "ALL")
158 .collect::<Vec<_>>();
159 if !revoke_tables.is_empty() {
160 let _with_schema = revoke_tables
161 .iter()
162 .flat_map(|t| {
163 if t.name.contains('.') {
164 let parts: Vec<&str> = t.name.split('.').collect();
166 if parts.len() == 2 {
167 vec![format!(
168 "{}.{}",
169 escape_identifier(parts[0]),
170 escape_identifier(parts[1])
171 )]
172 } else {
173 vec![escape_identifier(&t.name)]
174 }
175 } else {
176 self.schemas
177 .iter()
178 .map(|s| {
179 format!("{}.{}", escape_identifier(s), escape_identifier(&t.name))
180 })
181 .collect::<Vec<_>>()
182 }
183 })
184 .collect::<Vec<String>>()
185 .join(", ");
186
187 let sql = format!(
188 "REVOKE {} ON {} FROM {};",
189 grants, _with_schema, escaped_user
190 );
191 sqls.push(sql);
192 }
193
194 sqls.join(" ")
195 }
196}
197
198impl RoleValidate for RoleTableLevel {
199 fn validate(&self) -> Result<()> {
200 if self.name.is_empty() {
201 return Err(anyhow!("role.name is empty"));
202 }
203
204 if self.schemas.is_empty() {
205 return Err(anyhow!("role.schemas is empty"));
206 }
207
208 if self.schemas.contains(&"ALL".to_string()) {
210 return Err(anyhow!("role.schemas is not supported yet: ALL"));
211 }
212
213 if self.tables.is_empty() {
214 return Err(anyhow!("role.tables is empty"));
215 }
216
217 if self.grants.is_empty() {
218 return Err(anyhow!("role.grants is empty"));
219 }
220
221 let valid_grants = vec![
223 "SELECT",
224 "INSERT",
225 "UPDATE",
226 "DELETE",
227 "DROP",
228 "REFERENCES",
229 "ALL",
230 ];
231 let mut grants = HashSet::new();
232 for grant in &self.grants {
233 if !valid_grants.contains(&&grant[..]) {
234 return Err(anyhow!(
235 "role.grants invalid: {}, expected: {:?}",
236 grant,
237 valid_grants
238 ));
239 }
240 grants.insert(grant.to_string());
241 }
242
243 Ok(())
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a role named `test` from string slices to keep each case short.
    fn make_role(grants: &[&str], schemas: &[&str], tables: &[&str]) -> RoleTableLevel {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|s| s.to_string()).collect()
        }
        RoleTableLevel {
            name: "test".to_string(),
            grants: owned(grants),
            schemas: owned(schemas),
            tables: owned(tables),
        }
    }

    #[test]
    fn test_role_table_level() {
        assert_eq!(
            make_role(&["SELECT"], &["public"], &["test"]).to_sql("test"),
            "GRANT SELECT ON \"public\".\"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public"], &["test"]).to_sql("test"),
            "GRANT SELECT, INSERT ON \"public\".\"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["test"]).to_sql("test"),
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["ALL"], &["public"], &["test"]).to_sql("test"),
            "GRANT ALL PRIVILEGES ON \"public\".\"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public"], &["ALL"]).to_sql("test"),
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["ALL"], &["public", "test"], &["ALL"]).to_sql("test"),
            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["ALL"]).to_sql("test"),
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["test", "test.test2"])
                .to_sql("test"),
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\", \"test\".\"test2\" TO \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["test", "-test.test2"])
                .to_sql("test"),
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["test", "-test2"])
                .to_sql("test"),
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"public\".\"test2\", \"test\".\"test2\" FROM \"test\";"
        );
        assert_eq!(
            make_role(&["SELECT", "INSERT"], &["public", "test"], &["ALL", "-test.test2"])
                .to_sql("test"),
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";"
        );
    }

    #[test]
    fn test_revoke_all_tables() {
        assert_eq!(
            make_role(&["SELECT"], &["public"], &["-ALL"]).to_sql("test"),
            "REVOKE SELECT ON ALL TABLES IN SCHEMA \"public\" FROM \"test\";"
        );
        // Revoking everything and then granting a specific table keeps both statements, in order.
        assert_eq!(
            make_role(&["SELECT"], &["public"], &["-ALL", "test"]).to_sql("test"),
            "REVOKE SELECT ON ALL TABLES IN SCHEMA \"public\" FROM \"test\"; GRANT SELECT ON \"public\".\"test\" TO \"test\";"
        );
    }

    #[test]
    fn test_grant_all_tables_skips_per_table_grants() {
        assert_eq!(
            make_role(&["SELECT"], &["public"], &["ALL", "test"]).to_sql("test"),
            "GRANT SELECT ON ALL TABLES IN SCHEMA \"public\" TO \"test\";"
        );
    }

    #[test]
    fn test_explicit_plus_prefix() {
        assert_eq!(
            make_role(&["SELECT"], &["public"], &["+test"]).to_sql("test"),
            make_role(&["SELECT"], &["public"], &["test"]).to_sql("test")
        );
    }

    #[test]
    fn test_validate() {
        assert!(make_role(&["SELECT", "ALL"], &["public"], &["ALL", "-x"])
            .validate()
            .is_ok());

        let err = |role: RoleTableLevel| role.validate().unwrap_err().to_string();

        let mut unnamed = make_role(&["SELECT"], &["public"], &["t"]);
        unnamed.name.clear();
        assert_eq!(err(unnamed), "role.name is empty");
        assert_eq!(
            err(make_role(&["SELECT"], &[], &["t"])),
            "role.schemas is empty"
        );
        assert_eq!(
            err(make_role(&["SELECT"], &["ALL"], &["t"])),
            "role.schemas is not supported yet: ALL"
        );
        assert_eq!(
            err(make_role(&["SELECT"], &["public"], &[])),
            "role.tables is empty"
        );
        assert_eq!(
            err(make_role(&[], &["public"], &["t"])),
            "role.grants is empty"
        );
        assert!(err(make_role(&["EXECUTE"], &["public"], &["t"]))
            .starts_with("role.grants invalid: EXECUTE"));
    }
}