Skip to main content

grant/config/
role_table.rs

1use super::role::RoleValidate;
2use super::sql_utils::escape_identifier;
3use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
/// Table-level role: a set of privileges applied to tables in one or more schemas.
///
/// For example:
///
/// ```yaml
/// - name: role_table
///   grants:
///     - SELECT
///     - INSERT
///     - UPDATE
///     - DELETE
///   schemas:
///     - public
///   tables:
///     - ALL
///     - +table1
///     - -table2
///     - -public.table2
/// ```
///
/// The above example grants SELECT, INSERT, UPDATE, DELETE on all tables in the `public`
/// schema and then revokes them on `table2`.
/// `ALL` is a special keyword that means every table in the listed schemas.
/// A table entry may start with `+` (grant, the default when no sign is given) or `-` (revoke).
/// A table name without a schema prefix is applied to every schema listed in `schemas`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct RoleTableLevel {
    /// Role name; validation rejects an empty name.
    pub name: String,
    /// Privileges to apply (SELECT, INSERT, UPDATE, DELETE, DROP, REFERENCES or ALL).
    /// If `ALL` is present the generated SQL uses `ALL PRIVILEGES`.
    pub grants: Vec<String>,
    /// Schemas the role applies to; validation rejects an empty list and the value `ALL`.
    pub schemas: Vec<String>,
    /// Table entries: `ALL`, `table`, `schema.table`, each optionally prefixed with `+` or `-`.
    pub tables: Vec<String>,
}
38
/// A single entry of `RoleTableLevel::tables`, split into its bare name and its sign.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
struct Table {
    /// Table name with the leading sign removed; may be `schema.table` or the keyword `ALL`.
    name: String,
    /// `"+"` when the entry is to be granted, `"-"` when it is to be revoked.
    sign: String,
}
44
45impl Table {
46    fn new(name: &str) -> Self {
47        let sign = match name.chars().next() {
48            Some('+') => "+".to_string(),
49            Some('-') => "-".to_string(),
50            _ => "+".to_string(),
51        };
52        let name = name.trim_start_matches(&sign).to_string();
53
54        Self { name, sign }
55    }
56}
57
58impl RoleTableLevel {
59    /// Generate role table to sql.
60    ///
61    /// ```sql
62    /// {GRANT | REVOKE} { { SELECT | INSERT | UPDATE | DELETE | DROP | REFERENCES } [,...] | ALL [ PRIVILEGES ] }
63    /// ON { [ TABLE ] table_name [, ...] | ALL TABLES IN SCHEMA schema_name [, ...] }
64    /// TO { username [ WITH GRANT OPTION ] | GROUP group_name | PUBLIC } [, ...]
65    /// ```
66    pub fn to_sql(&self, user: &str) -> String {
67        let mut sqls = vec![];
68        let mut tables = self
69            .tables
70            .iter()
71            .map(|t| Table::new(t))
72            .collect::<Vec<Table>>();
73
74        // grant all privileges if grants contains "ALL"
75        let grants = if self.grants.contains(&"ALL".to_string()) {
76            "ALL PRIVILEGES".to_string()
77        } else {
78            self.grants.join(", ")
79        };
80
81        // escape schemas and user identifiers to prevent SQL injection
82        let escaped_schemas = self
83            .schemas
84            .iter()
85            .map(|s| escape_identifier(s))
86            .collect::<Vec<_>>();
87        let escaped_user = escape_identifier(user);
88
89        // Check if `ALL` is present and track whether we should skip individual table processing
90        let has_all_grant = tables.iter().any(|t| t.name == "ALL" && t.sign == "+");
91        let has_all_revoke = tables.iter().any(|t| t.name == "ALL" && t.sign == "-");
92
93        // If `tables` contains `ALL`, process it
94        if let Some(table_named_all) = tables.iter().find(|t| t.name == "ALL") {
95            let schema_list = escaped_schemas.join(", ");
96            let sql = match table_named_all.sign.as_str() {
97                "+" => format!(
98                    "GRANT {} ON ALL TABLES IN SCHEMA {} TO {};",
99                    grants, schema_list, escaped_user
100                ),
101                "-" => format!(
102                    "REVOKE {} ON ALL TABLES IN SCHEMA {} FROM {};",
103                    grants, schema_list, escaped_user
104                ),
105                _ => "".to_string(),
106            };
107            sqls.push(sql);
108        }
109
110        // Grant on specific tables with sign `+` (only if ALL+ wasn't specified)
111        // When ALL+ exists, individual + tables are already included
112        if !has_all_grant {
113            let grant_tables = tables
114                .iter()
115                .filter(|x| x.sign == "+" && x.name != "ALL")
116                .collect::<Vec<_>>();
117            if !grant_tables.is_empty() {
118                let _with_schema = grant_tables
119                    .iter()
120                    .flat_map(|t| {
121                        if t.name.contains('.') {
122                            // For schema-qualified names, escape each part separately
123                            let parts: Vec<&str> = t.name.split('.').collect();
124                            if parts.len() == 2 {
125                                vec![format!(
126                                    "{}.{}",
127                                    escape_identifier(parts[0]),
128                                    escape_identifier(parts[1])
129                                )]
130                            } else {
131                                vec![escape_identifier(&t.name)]
132                            }
133                        } else {
134                            self.schemas
135                                .iter()
136                                .map(|s| {
137                                    format!(
138                                        "{}.{}",
139                                        escape_identifier(s),
140                                        escape_identifier(&t.name)
141                                    )
142                                })
143                                .collect::<Vec<_>>()
144                        }
145                    })
146                    .collect::<Vec<String>>()
147                    .join(", ");
148
149                let sql = format!("GRANT {} ON {} TO {};", grants, _with_schema, escaped_user);
150                sqls.push(sql);
151            }
152        }
153
154        // Revoke on specific tables with sign `-` (excluding `ALL`)
155        let revoke_tables = tables
156            .iter()
157            .filter(|x| x.sign == "-" && x.name != "ALL")
158            .collect::<Vec<_>>();
159        if !revoke_tables.is_empty() {
160            let _with_schema = revoke_tables
161                .iter()
162                .flat_map(|t| {
163                    if t.name.contains('.') {
164                        // For schema-qualified names, escape each part separately
165                        let parts: Vec<&str> = t.name.split('.').collect();
166                        if parts.len() == 2 {
167                            vec![format!(
168                                "{}.{}",
169                                escape_identifier(parts[0]),
170                                escape_identifier(parts[1])
171                            )]
172                        } else {
173                            vec![escape_identifier(&t.name)]
174                        }
175                    } else {
176                        self.schemas
177                            .iter()
178                            .map(|s| {
179                                format!("{}.{}", escape_identifier(s), escape_identifier(&t.name))
180                            })
181                            .collect::<Vec<_>>()
182                    }
183                })
184                .collect::<Vec<String>>()
185                .join(", ");
186
187            let sql = format!(
188                "REVOKE {} ON {} FROM {};",
189                grants, _with_schema, escaped_user
190            );
191            sqls.push(sql);
192        }
193
194        sqls.join(" ")
195    }
196}
197
198impl RoleValidate for RoleTableLevel {
199    fn validate(&self) -> Result<()> {
200        if self.name.is_empty() {
201            return Err(anyhow!("role.name is empty"));
202        }
203
204        if self.schemas.is_empty() {
205            return Err(anyhow!("role.schemas is empty"));
206        }
207
208        // TODO: support schemas=[ALL]
209        if self.schemas.contains(&"ALL".to_string()) {
210            return Err(anyhow!("role.schemas is not supported yet: ALL"));
211        }
212
213        if self.tables.is_empty() {
214            return Err(anyhow!("role.tables is empty"));
215        }
216
217        if self.grants.is_empty() {
218            return Err(anyhow!("role.grants is empty"));
219        }
220
221        // Check valid grants: SELECT, INSERT, UPDATE, DELETE, DROP, REFERENCES, ALL
222        let valid_grants = vec![
223            "SELECT",
224            "INSERT",
225            "UPDATE",
226            "DELETE",
227            "DROP",
228            "REFERENCES",
229            "ALL",
230        ];
231        let mut grants = HashSet::new();
232        for grant in &self.grants {
233            if !valid_grants.contains(&&grant[..]) {
234                return Err(anyhow!(
235                    "role.grants invalid: {}, expected: {:?}",
236                    grant,
237                    valid_grants
238                ));
239            }
240            grants.insert(grant.to_string());
241        }
242
243        Ok(())
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn a slice of string literals into owned `String`s.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Build a role named `test` from the given parts and assert that the SQL
    /// generated for user `test` equals `expected`.
    fn assert_sql(grants: &[&str], schemas: &[&str], tables: &[&str], expected: &str) {
        let role = RoleTableLevel {
            name: "test".to_string(),
            grants: owned(grants),
            schemas: owned(schemas),
            tables: owned(tables),
        };
        assert_eq!(role.to_sql("test"), expected);
    }

    #[test]
    fn test_role_table_level() {
        // Single grant, single schema, single table.
        assert_sql(
            &["SELECT"],
            &["public"],
            &["test"],
            r#"GRANT SELECT ON "public"."test" TO "test";"#,
        );

        // Several grants are comma-joined.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public"],
            &["test"],
            r#"GRANT SELECT, INSERT ON "public"."test" TO "test";"#,
        );

        // An unqualified table is expanded to every schema.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test"],
            r#"GRANT SELECT, INSERT ON "public"."test", "test"."test" TO "test";"#,
        );

        // The `ALL` grant becomes `ALL PRIVILEGES`.
        assert_sql(
            &["ALL"],
            &["public"],
            &["test"],
            r#"GRANT ALL PRIVILEGES ON "public"."test" TO "test";"#,
        );

        // The `ALL` table keyword targets every table in the schema.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public"],
            &["ALL"],
            r#"GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA "public" TO "test";"#,
        );

        assert_sql(
            &["ALL"],
            &["public", "test"],
            &["ALL"],
            r#"GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA "public", "test" TO "test";"#,
        );

        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["ALL"],
            r#"GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA "public", "test" TO "test";"#,
        );

        // A schema-qualified table is not expanded.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "test.test2"],
            r#"GRANT SELECT, INSERT ON "public"."test", "test"."test", "test"."test2" TO "test";"#,
        );

        // A `-` entry produces a trailing REVOKE.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "-test.test2"],
            r#"GRANT SELECT, INSERT ON "public"."test", "test"."test" TO "test"; REVOKE SELECT, INSERT ON "test"."test2" FROM "test";"#,
        );

        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["test", "-test2"],
            r#"GRANT SELECT, INSERT ON "public"."test", "test"."test" TO "test"; REVOKE SELECT, INSERT ON "public"."test2", "test"."test2" FROM "test";"#,
        );

        // `ALL` combined with an exclusion.
        assert_sql(
            &["SELECT", "INSERT"],
            &["public", "test"],
            &["ALL", "-test.test2"],
            r#"GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA "public", "test" TO "test"; REVOKE SELECT, INSERT ON "test"."test2" FROM "test";"#,
        );
    }
}