1use super::role::RoleValidate;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
/// Table-level privileges for a single role.
///
/// Rendered into GRANT/REVOKE statements by `RoleTableLevel::to_sql` and
/// checked by `RoleValidate::validate`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct RoleTableLevel {
    /// Role name; `validate` rejects an empty value.
    pub name: String,
    /// Privilege keywords such as `SELECT` or `INSERT`; an `"ALL"` entry is
    /// rendered as `ALL PRIVILEGES` by `to_sql`.
    pub grants: Vec<String>,
    /// Schemas that bare table names (and the `ALL` table entry) are expanded against.
    pub schemas: Vec<String>,
    /// Table entries: an optional `+` (grant, the default) or `-` (revoke) prefix
    /// followed by `table`, `schema.table`, or `ALL` (every table in `schemas`).
    pub tables: Vec<String>,
}
37
/// One parsed entry of `RoleTableLevel::tables`, produced by `Table::new`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
struct Table {
    /// Table reference with the sign prefix removed (may be `schema.table` or `ALL`).
    name: String,
    /// `"+"` for a grant, `"-"` for a revoke.
    sign: String,
}
43
44impl Table {
45 fn new(name: &str) -> Self {
46 let sign = match name.chars().next() {
47 Some('+') => "+".to_string(),
48 Some('-') => "-".to_string(),
49 _ => "+".to_string(),
50 };
51 let name = name.trim_start_matches(&sign).to_string();
52
53 Self { name, sign }
54 }
55}
56
57impl RoleTableLevel {
58 fn escape_identifier(ident: &str) -> String {
60 format!("\"{}\"", ident.replace("\"", "\"\""))
63 }
64
65 pub fn to_sql(&self, user: &str) -> String {
73 let mut sqls = vec![];
74 let mut tables = self
75 .tables
76 .iter()
77 .map(|t| Table::new(t))
78 .collect::<Vec<Table>>();
79
80 let grants = if self.grants.contains(&"ALL".to_string()) {
82 "ALL PRIVILEGES".to_string()
83 } else {
84 self.grants.join(", ")
85 };
86
87 let escaped_schemas = self
89 .schemas
90 .iter()
91 .map(|s| Self::escape_identifier(s))
92 .collect::<Vec<_>>();
93 let escaped_user = Self::escape_identifier(user);
94
95 if let Some(table_named_all) = tables.iter().find(|t| t.name == "ALL") {
97 let schema_list = escaped_schemas.join(", ");
98 let sql = match table_named_all.sign.as_str() {
99 "+" => format!(
100 "GRANT {} ON ALL TABLES IN SCHEMA {} TO {};",
101 grants, schema_list, escaped_user
102 ),
103 "-" => format!(
104 "REVOKE {} ON ALL TABLES IN SCHEMA {} FROM {};",
105 grants, schema_list, escaped_user
106 ),
107 _ => "".to_string(),
108 };
109 sqls.push(sql);
110
111 for table in tables.clone() {
113 if table.name == "ALL" || table.sign == "+" {
114 tables.retain(|x| x != &table);
115 }
116 }
117 }
118
119 let grant_tables = tables.iter().filter(|x| x.sign == "+").collect::<Vec<_>>();
121 if !grant_tables.is_empty() {
122 let _with_schema = grant_tables
123 .iter()
124 .flat_map(|t| {
125 if t.name.contains('.') {
126 let parts: Vec<&str> = t.name.split('.').collect();
128 if parts.len() == 2 {
129 vec![format!(
130 "{}.{}",
131 Self::escape_identifier(parts[0]),
132 Self::escape_identifier(parts[1])
133 )]
134 } else {
135 vec![Self::escape_identifier(&t.name)]
136 }
137 } else {
138 self.schemas
139 .iter()
140 .map(|s| {
141 format!(
142 "{}.{}",
143 Self::escape_identifier(s),
144 Self::escape_identifier(&t.name)
145 )
146 })
147 .collect::<Vec<_>>()
148 }
149 })
150 .collect::<Vec<String>>()
151 .join(", ");
152
153 let sql = format!("GRANT {} ON {} TO {};", grants, _with_schema, escaped_user);
154 sqls.push(sql);
155
156 for table in tables.clone() {
158 if table.sign == "+" {
159 tables.retain(|x| x != &table);
160 }
161 }
162 }
163
164 let revoke_tables = tables.iter().filter(|x| x.sign == "-").collect::<Vec<_>>();
166 if !revoke_tables.is_empty() {
167 let _with_schema = revoke_tables
168 .iter()
169 .flat_map(|t| {
170 if t.name.contains('.') {
171 let parts: Vec<&str> = t.name.split('.').collect();
173 if parts.len() == 2 {
174 vec![format!(
175 "{}.{}",
176 Self::escape_identifier(parts[0]),
177 Self::escape_identifier(parts[1])
178 )]
179 } else {
180 vec![Self::escape_identifier(&t.name)]
181 }
182 } else {
183 self.schemas
184 .iter()
185 .map(|s| {
186 format!(
187 "{}.{}",
188 Self::escape_identifier(s),
189 Self::escape_identifier(&t.name)
190 )
191 })
192 .collect::<Vec<_>>()
193 }
194 })
195 .collect::<Vec<String>>()
196 .join(", ");
197
198 let sql = format!(
199 "REVOKE {} ON {} FROM {};",
200 grants, _with_schema, escaped_user
201 );
202 sqls.push(sql);
203 }
204
205 sqls.join(" ")
206 }
207}
208
209impl RoleValidate for RoleTableLevel {
210 fn validate(&self) -> Result<()> {
211 if self.name.is_empty() {
212 return Err(anyhow!("role.name is empty"));
213 }
214
215 if self.schemas.is_empty() {
216 return Err(anyhow!("role.schemas is empty"));
217 }
218
219 if self.schemas.contains(&"ALL".to_string()) {
221 return Err(anyhow!("role.schemas is not supported yet: ALL"));
222 }
223
224 if self.tables.is_empty() {
225 return Err(anyhow!("role.tables is empty"));
226 }
227
228 if self.grants.is_empty() {
229 return Err(anyhow!("role.grants is empty"));
230 }
231
232 let valid_grants = vec![
234 "SELECT",
235 "INSERT",
236 "UPDATE",
237 "DELETE",
238 "DROP",
239 "REFERENCES",
240 "ALL",
241 ];
242 let mut grants = HashSet::new();
243 for grant in &self.grants {
244 if !valid_grants.contains(&&grant[..]) {
245 return Err(anyhow!(
246 "role.grants invalid: {}, expected: {:?}",
247 grant,
248 valid_grants
249 ));
250 }
251 grants.insert(grant.to_string());
252 }
253
254 Ok(())
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a slice of string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a role named "test" from literal grant/schema/table lists.
    fn role(grants: &[&str], schemas: &[&str], tables: &[&str]) -> RoleTableLevel {
        RoleTableLevel {
            name: "test".to_string(),
            grants: strings(grants),
            schemas: strings(schemas),
            tables: strings(tables),
        }
    }

    /// Asserts the SQL rendered for user "test" matches `expected`.
    fn check(grants: &[&str], schemas: &[&str], tables: &[&str], expected: &str) {
        assert_eq!(
            role(grants, schemas, tables).to_sql("test"),
            expected,
            "grants={:?} schemas={:?} tables={:?}",
            grants,
            schemas,
            tables
        );
    }

    #[test]
    fn test_role_table_level() {
        const SEL_INS: &[&str] = &["SELECT", "INSERT"];
        const PUBLIC_TEST: &[&str] = &["public", "test"];

        check(
            &["SELECT"],
            &["public"],
            &["test"],
            "GRANT SELECT ON \"public\".\"test\" TO \"test\";",
        );
        check(
            SEL_INS,
            &["public"],
            &["test"],
            "GRANT SELECT, INSERT ON \"public\".\"test\" TO \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["test"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\";",
        );
        check(
            &["ALL"],
            &["public"],
            &["test"],
            "GRANT ALL PRIVILEGES ON \"public\".\"test\" TO \"test\";",
        );
        check(
            SEL_INS,
            &["public"],
            &["ALL"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\" TO \"test\";",
        );
        check(
            &["ALL"],
            PUBLIC_TEST,
            &["ALL"],
            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["ALL"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["test", "test.test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\", \"test\".\"test2\" TO \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["test", "-test.test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["test", "-test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"public\".\"test2\", \"test\".\"test2\" FROM \"test\";",
        );
        check(
            SEL_INS,
            PUBLIC_TEST,
            &["ALL", "-test.test2"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";",
        );
    }
}