grant/config/
role_table.rs

1use super::role::RoleValidate;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// Role Table Level.
7///
8/// For example:
9///
10/// ```yaml
11/// - name: role_table
12///   grants:
13///     - SELECT
14///     - INSERT
15///     - UPDATE
16///     - DELETE
17///   schemas:
18///   - public
19///   tables:
20///     - ALL
21///     - +table1
22///     - -table2
23///     - -public.table2
24/// ```
25///
26/// The above example grants SELECT, INSERT, UPDATE, DELETE to all tables in the public schema
27/// except table2.
28/// The ALL is a special keyword that means all tables in the public schema.
29/// If the table does not have a schema, it is assumed to be in all schema.
30#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
31pub struct RoleTableLevel {
32    pub name: String,
33    pub grants: Vec<String>,
34    pub schemas: Vec<String>,
35    pub tables: Vec<String>,
36}
37
/// One parsed entry of [`RoleTableLevel::tables`].
///
/// A leading `-` marks the table for `REVOKE`; a leading `+`, or no prefix at
/// all, marks it for `GRANT`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Table {
    /// Table name without its sign: a bare name, a `schema.table` pair, or the
    /// special keyword `ALL`.
    name: String,
    /// `"+"` (grant) or `"-"` (revoke).
    sign: String,
}

impl Table {
    /// Split a raw entry such as `table1`, `+table1` or `-public.table2` into sign and name.
    ///
    /// Exactly one leading sign character is consumed, so `--x` revokes a table
    /// literally named `-x` instead of silently collapsing every leading sign.
    fn new(name: &str) -> Self {
        let (sign, rest) = match name.strip_prefix('-') {
            Some(rest) => ("-", rest),
            None => ("+", name.strip_prefix('+').unwrap_or(name)),
        };
        Self {
            name: rest.to_string(),
            sign: sign.to_string(),
        }
    }
}
56
57impl RoleTableLevel {
58    /// Escape and quote a PostgreSQL identifier to prevent SQL injection
59    fn escape_identifier(ident: &str) -> String {
60        // PostgreSQL identifiers are quoted with double quotes
61        // Escape double quotes by doubling them
62        format!("\"{}\"", ident.replace("\"", "\"\""))
63    }
64
65    /// Generate role table to sql.
66    ///
67    /// ```sql
68    /// {GRANT | REVOKE} { { SELECT | INSERT | UPDATE | DELETE | DROP | REFERENCES } [,...] | ALL [ PRIVILEGES ] }
69    /// ON { [ TABLE ] table_name [, ...] | ALL TABLES IN SCHEMA schema_name [, ...] }
70    /// TO { username [ WITH GRANT OPTION ] | GROUP group_name | PUBLIC } [, ...]
71    /// ```
72    pub fn to_sql(&self, user: &str) -> String {
73        let mut sqls = vec![];
74        let mut tables = self
75            .tables
76            .iter()
77            .map(|t| Table::new(t))
78            .collect::<Vec<Table>>();
79
80        // grant all privileges if grants contains "ALL"
81        let grants = if self.grants.contains(&"ALL".to_string()) {
82            "ALL PRIVILEGES".to_string()
83        } else {
84            self.grants.join(", ")
85        };
86
87        // escape schemas and user identifiers to prevent SQL injection
88        let escaped_schemas = self
89            .schemas
90            .iter()
91            .map(|s| Self::escape_identifier(s))
92            .collect::<Vec<_>>();
93        let escaped_user = Self::escape_identifier(user);
94
95        // if `tables` only contains `ALL`
96        if let Some(table_named_all) = tables.iter().find(|t| t.name == "ALL") {
97            let schema_list = escaped_schemas.join(", ");
98            let sql = match table_named_all.sign.as_str() {
99                "+" => format!(
100                    "GRANT {} ON ALL TABLES IN SCHEMA {} TO {};",
101                    grants, schema_list, escaped_user
102                ),
103                "-" => format!(
104                    "REVOKE {} ON ALL TABLES IN SCHEMA {} FROM {};",
105                    grants, schema_list, escaped_user
106                ),
107                _ => "".to_string(),
108            };
109            sqls.push(sql);
110
111            // remove name `ALL` and all tables start with `+`
112            for table in tables.clone() {
113                if table.name == "ALL" || table.sign == "+" {
114                    tables.retain(|x| x != &table);
115                }
116            }
117        }
118
119        // grant on tables sign `+`
120        let grant_tables = tables.iter().filter(|x| x.sign == "+").collect::<Vec<_>>();
121        if !grant_tables.is_empty() {
122            let _with_schema = grant_tables
123                .iter()
124                .flat_map(|t| {
125                    if t.name.contains('.') {
126                        // For schema-qualified names, escape each part separately
127                        let parts: Vec<&str> = t.name.split('.').collect();
128                        if parts.len() == 2 {
129                            vec![format!(
130                                "{}.{}",
131                                Self::escape_identifier(parts[0]),
132                                Self::escape_identifier(parts[1])
133                            )]
134                        } else {
135                            vec![Self::escape_identifier(&t.name)]
136                        }
137                    } else {
138                        self.schemas
139                            .iter()
140                            .map(|s| {
141                                format!(
142                                    "{}.{}",
143                                    Self::escape_identifier(s),
144                                    Self::escape_identifier(&t.name)
145                                )
146                            })
147                            .collect::<Vec<_>>()
148                    }
149                })
150                .collect::<Vec<String>>()
151                .join(", ");
152
153            let sql = format!("GRANT {} ON {} TO {};", grants, _with_schema, escaped_user);
154            sqls.push(sql);
155
156            // remove all tables start with `+`
157            for table in tables.clone() {
158                if table.sign == "+" {
159                    tables.retain(|x| x != &table);
160                }
161            }
162        }
163
164        // revoke on tables start with `-`
165        let revoke_tables = tables.iter().filter(|x| x.sign == "-").collect::<Vec<_>>();
166        if !revoke_tables.is_empty() {
167            let _with_schema = revoke_tables
168                .iter()
169                .flat_map(|t| {
170                    if t.name.contains('.') {
171                        // For schema-qualified names, escape each part separately
172                        let parts: Vec<&str> = t.name.split('.').collect();
173                        if parts.len() == 2 {
174                            vec![format!(
175                                "{}.{}",
176                                Self::escape_identifier(parts[0]),
177                                Self::escape_identifier(parts[1])
178                            )]
179                        } else {
180                            vec![Self::escape_identifier(&t.name)]
181                        }
182                    } else {
183                        self.schemas
184                            .iter()
185                            .map(|s| {
186                                format!(
187                                    "{}.{}",
188                                    Self::escape_identifier(s),
189                                    Self::escape_identifier(&t.name)
190                                )
191                            })
192                            .collect::<Vec<_>>()
193                    }
194                })
195                .collect::<Vec<String>>()
196                .join(", ");
197
198            let sql = format!(
199                "REVOKE {} ON {} FROM {};",
200                grants, _with_schema, escaped_user
201            );
202            sqls.push(sql);
203        }
204
205        sqls.join(" ")
206    }
207}
208
209impl RoleValidate for RoleTableLevel {
210    fn validate(&self) -> Result<()> {
211        if self.name.is_empty() {
212            return Err(anyhow!("role.name is empty"));
213        }
214
215        if self.schemas.is_empty() {
216            return Err(anyhow!("role.schemas is empty"));
217        }
218
219        // TODO: support schemas=[ALL]
220        if self.schemas.contains(&"ALL".to_string()) {
221            return Err(anyhow!("role.schemas is not supported yet: ALL"));
222        }
223
224        if self.tables.is_empty() {
225            return Err(anyhow!("role.tables is empty"));
226        }
227
228        if self.grants.is_empty() {
229            return Err(anyhow!("role.grants is empty"));
230        }
231
232        // Check valid grants: SELECT, INSERT, UPDATE, DELETE, DROP, REFERENCES, ALL
233        let valid_grants = vec![
234            "SELECT",
235            "INSERT",
236            "UPDATE",
237            "DELETE",
238            "DROP",
239            "REFERENCES",
240            "ALL",
241        ];
242        let mut grants = HashSet::new();
243        for grant in &self.grants {
244            if !valid_grants.contains(&&grant[..]) {
245                return Err(anyhow!(
246                    "role.grants invalid: {}, expected: {:?}",
247                    grant,
248                    valid_grants
249                ));
250            }
251            grants.insert(grant.to_string());
252        }
253
254        Ok(())
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a slice of string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Build a role named `test` from the given grants, schemas and tables.
    fn role(grants: &[&str], schemas: &[&str], tables: &[&str]) -> RoleTableLevel {
        RoleTableLevel {
            name: "test".to_string(),
            grants: strings(grants),
            schemas: strings(schemas),
            tables: strings(tables),
        }
    }

    /// Assert that the role built from the lists renders to `expected` for user `test`.
    fn assert_sql(grants: &[&str], schemas: &[&str], tables: &[&str], expected: &str) {
        assert_eq!(role(grants, schemas, tables).to_sql("test"), expected);
    }

    #[test]
    fn test_role_table_level() {
        assert_sql(
            &["SELECT"],
            &["public"],
            &["test"],
            "GRANT SELECT ON \"public\".\"test\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public"],
            &["test"],
            "GRANT SELECT, INSERT ON \"public\".\"test\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\";",
        );
        assert_sql(
            &["ALL"],
            &["public"],
            &["test"],
            "GRANT ALL PRIVILEGES ON \"public\".\"test\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public"],
            &["ALL"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\" TO \"test\";",
        );
        assert_sql(
            &["ALL"],
            &["public", "test"],
            &["ALL"],
            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["ALL"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "test.test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\", \"test\".\"test2\" TO \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "-test.test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "-test2"],
            "GRANT SELECT, INSERT ON \"public\".\"test\", \"test\".\"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"public\".\"test2\", \"test\".\"test2\" FROM \"test\";",
        );
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["ALL", "-test.test2"],
            "GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA \"public\", \"test\" TO \"test\"; REVOKE SELECT, INSERT ON \"test\".\"test2\" FROM \"test\";",
        );
    }

    #[test]
    fn test_role_table_level_revoke_all() {
        assert_sql(
            &["SELECT"],
            &["public"],
            &["-ALL", "-x"],
            "REVOKE SELECT ON ALL TABLES IN SCHEMA \"public\" FROM \"test\"; REVOKE SELECT ON \"public\".\"x\" FROM \"test\";",
        );
    }

    #[test]
    fn test_role_table_level_escapes_identifiers() {
        // Embedded double quotes are doubled inside quoted identifiers.
        assert_eq!(
            role(&["SELECT"], &["public"], &["t\"x"]).to_sql("a\"b"),
            "GRANT SELECT ON \"public\".\"t\"\"x\" TO \"a\"\"b\";"
        );
        // Names with more than one dot are quoted as a single identifier.
        assert_sql(
            &["SELECT"],
            &["public"],
            &["a.b.c"],
            "GRANT SELECT ON \"a.b.c\" TO \"test\";",
        );
    }

    #[test]
    fn test_role_table_level_validate() {
        assert!(role(&["SELECT"], &["public"], &["ALL"]).validate().is_ok());

        let mut unnamed = role(&["SELECT"], &["public"], &["ALL"]);
        unnamed.name.clear();
        assert!(unnamed.validate().is_err());

        assert!(role(&["SELECT"], &[], &["ALL"]).validate().is_err());
        assert!(role(&["SELECT"], &["ALL"], &["ALL"]).validate().is_err());
        assert!(role(&["SELECT"], &["public"], &[]).validate().is_err());
        assert!(role(&[], &["public"], &["ALL"]).validate().is_err());

        let err = role(&["TRUNCATE"], &["public"], &["ALL"])
            .validate()
            .unwrap_err();
        assert!(err.to_string().contains("role.grants invalid: TRUNCATE"));
    }
}