Skip to main content

grant/
connection.rs

1use crate::config::connection::ConnectionType;
2use crate::config::Config;
3use anyhow::{anyhow, Result};
4use log::{debug, info};
5use postgres::{row::Row, types::ToSql, Client, Config as ConnConfig, NoTls, ToStatement};
6
7// TODO: support multiple adapters
8
9/// Connection to the database, currently only Postgres and Redshift is supported
10/// TODO: support multiple adapters
11pub struct DbConnection {
12    pub connection_info: String,
13    pub client: Client,
14    conn_config: ConnConfig,
15}
16
/// A database user as listed in `pg_user`.
#[derive(Debug, Clone)]
pub struct User {
    /// Login name (`usename`).
    pub name: String,
    /// Whether the user may create databases (`usecreatedb`).
    pub user_createdb: bool,
    /// Whether the user is a superuser (`usesuper`).
    pub user_super: bool,
    /// Password column as read from `pg_user` (masked by the server), or the
    /// password to set when the value is used to create a user.
    pub password: String,
}
25
/// A user's database-level privileges: whether `name` holds `CREATE` and/or
/// `TEMP` on `database_name`.
#[derive(Debug)]
pub struct UserDatabaseRole {
    /// User (role) name.
    pub name: String,
    /// Database the privileges apply to.
    pub database_name: String,
    /// Holds the `CREATE` privilege on the database.
    pub has_create: bool,
    /// Holds the `TEMP` privilege on the database.
    pub has_temp: bool,
}

impl UserDatabaseRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` (create + temp), `C` (create only), `T` (temp only), or
    /// an empty string. When `with_name` is true the code is wrapped as
    /// `database_name(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = match (self.has_create, self.has_temp) {
            (true, true) => "A",
            (true, false) => "C",
            (false, true) => "T",
            (false, false) => "",
        };

        if with_name {
            format!("{}({})", self.database_name, code)
        } else {
            code.to_string()
        }
    }
}
50
/// A user's schema-level privileges: whether `name` holds `CREATE` and/or
/// `USAGE` on `schema_name`.
#[derive(Debug)]
pub struct UserSchemaRole {
    /// User (role) name.
    pub name: String,
    /// Schema the privileges apply to.
    pub schema_name: String,
    /// Holds the `CREATE` privilege on the schema.
    pub has_create: bool,
    /// Holds the `USAGE` privilege on the schema.
    pub has_usage: bool,
}

impl UserSchemaRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` (create + usage), `C` (create only), `U` (usage only),
    /// or an empty string. When `with_name` is true the code is wrapped as
    /// `schema_name(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = if self.has_create && self.has_usage {
            "A"
        } else if self.has_create {
            "C"
        } else if self.has_usage {
            "U"
        } else {
            ""
        };

        match with_name {
            true => format!("{}({})", self.schema_name, code),
            false => String::from(code),
        }
    }
}
75
/// A user's table-level privileges on `schema_name.table_name`: which of
/// `SELECT`, `INSERT`, `UPDATE`, `DELETE` and `REFERENCES` `name` holds.
#[derive(Debug)]
pub struct UserTableRole {
    /// User (role) name.
    pub name: String,
    /// Schema containing the table.
    pub schema_name: String,
    /// Table name.
    pub table_name: String,
    /// Holds `SELECT` on the table.
    pub has_select: bool,
    /// Holds `INSERT` on the table.
    pub has_insert: bool,
    /// Holds `UPDATE` on the table.
    pub has_update: bool,
    /// Holds `DELETE` on the table.
    pub has_delete: bool,
    /// Holds `REFERENCES` on the table.
    pub has_references: bool,
}

impl UserTableRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` when all five privileges are held. Otherwise it is the
    /// concatenation, in this order, of `S` (select), `I` (insert), `U`
    /// (update), `D` (delete) and `R` (references) for each privilege held,
    /// which is empty when none are. When `with_name` is true the code is
    /// wrapped as `schema_name.table_name(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let flags = [
            (self.has_select, 'S'),
            (self.has_insert, 'I'),
            (self.has_update, 'U'),
            (self.has_delete, 'D'),
            (self.has_references, 'R'),
        ];

        let code: String = if flags.iter().all(|&(granted, _)| granted) {
            "A".to_string()
        } else {
            flags
                .iter()
                .filter(|&&(granted, _)| granted)
                .map(|&(_, letter)| letter)
                .collect()
        };

        if !with_name {
            return code;
        }
        format!("{}.{}({})", self.schema_name, self.table_name, code)
    }
}
121
122impl DbConnection {
123    /// A convenience function which store the connection string into `connection_info` and then connects to the database.
124    ///
125    /// Refer to <https://rust-lang-nursery.github.io/rust-cookbook/database/postgres.html>
126    /// for more information.
127    ///
128    /// ```rust
129    /// use grant::{config::Config, connection::DbConnection};
130    /// use std::str::FromStr;
131    ///
132    /// # fn main() -> anyhow::Result<()> {
133    /// let config = Config::from_str(
134    ///     r#"
135    ///       connection:
136    ///         type: postgres
137    ///         url: "postgresql://postgres:postgres@localhost:5432/postgres"
138    ///       roles: []
139    ///       users: []
140    ///     "#,
141    ///    )
142    ///    .unwrap();
143    ///    let mut db = DbConnection::new(&config)?;
144    ///    db.query("SELECT 1", &[]).unwrap();
145    /// # Ok(())
146    /// # }
147    /// ```
148    pub fn new(config: &Config) -> Result<Self> {
149        match config.connection.type_ {
150            ConnectionType::Postgres => {
151                let connection_info = config.connection.url.clone();
152                let mut client = Client::connect(&connection_info, NoTls).map_err(|e| {
153                    anyhow!("Failed to connect to database '{}': {}", connection_info, e)
154                })?;
155
156                if let Err(e) = client.simple_query("SELECT 1") {
157                    return Err(anyhow!(
158                        "Database connection test failed for '{}': {}",
159                        connection_info,
160                        e
161                    ));
162                } else {
163                    info!("Connected to database: {}", connection_info);
164                }
165
166                let conn_config = connection_info.parse::<ConnConfig>().map_err(|e| {
167                    anyhow!(
168                        "Failed to parse connection string '{}': {}",
169                        connection_info,
170                        e
171                    )
172                })?;
173
174                Ok(DbConnection {
175                    connection_info,
176                    client,
177                    conn_config,
178                })
179            }
180        }
181    }
182
183    /// Get current database name.
184    pub fn get_current_database(&self) -> Option<&str> {
185        self.conn_config.get_dbname()
186    }
187
188    /// Returns the connection_info
189    ///
190    /// ```rust
191    /// use grant::connection::DbConnection;
192    /// use std::str::FromStr;
193    ///
194    /// let connection_info = "postgres://postgres:postgres@localhost:5432/postgres";
195    /// let mut client = DbConnection::from_str(connection_info).unwrap();
196    /// assert_eq!(client.connection_info(), "postgres://postgres:postgres@localhost:5432/postgres");
197    /// ```
198    pub fn connection_info(self) -> String {
199        self.connection_info
200    }
201
202    /// Get the list of users
203    pub fn get_users(&mut self) -> Result<Vec<User>> {
204        let mut users = vec![];
205
206        // TODO: Get the password from database, currently it only returns *****
207        let sql = "SELECT usename, usecreatedb, usesuper, passwd FROM pg_user";
208        let stmt = self
209            .client
210            .prepare(sql)
211            .map_err(|e| anyhow!("Failed to prepare query for users: {}", e))?;
212
213        debug!("executing: {}", sql);
214        let rows = self
215            .client
216            .query(&stmt, &[])
217            .map_err(|e| anyhow!("Failed to query users from pg_user: {}", e))?;
218
219        for row in rows {
220            match (row.get(0), row.get(1), row.get(2), row.get(3)) {
221                (Some(name), Some(user_createdb), Some(user_super), Some(password)) => {
222                    users.push(User {
223                        name,
224                        user_createdb,
225                        user_super,
226                        password,
227                    })
228                }
229                (Some(name), _, _, _) => users.push(User {
230                    name,
231                    user_createdb: false,
232                    user_super: false,
233                    password: String::from(""),
234                }),
235                (_, _, _, _) => (),
236            }
237        }
238
239        debug!("get_users: {:#?}", users);
240
241        Ok(users)
242    }
243
244    /// Get the current database roles for user `user_name` in current database
245    /// Returns a list of `RoleDatabaseLevel`
246    pub fn get_user_database_privileges(&mut self) -> Result<Vec<UserDatabaseRole>> {
247        let mut roles = vec![];
248
249        let sql = r#"
250            WITH db AS (
251                SELECT d.datname AS database_name
252                FROM pg_database d
253            ),
254            users AS (
255                SELECT usename as user_name FROM pg_user
256            )
257            SELECT
258                u.user_name,
259                db.database_name,
260                pg_catalog.has_database_privilege(u.user_name, database_name, 'CREATE') AS "create",
261                pg_catalog.has_database_privilege(u.user_name, database_name, 'TEMP') AS "temp"
262            FROM db CROSS JOIN users u;
263        "#;
264
265        let stmt = self
266            .client
267            .prepare(sql)
268            .map_err(|e| anyhow!("Failed to prepare query for database privileges: {}", e))?;
269
270        debug!("executing: {}", sql);
271        let rows = self
272            .client
273            .query(&stmt, &[])
274            .map_err(|e| anyhow!("Failed to query database privileges: {}", e))?;
275        for row in rows {
276            let name: &str = row.get(0);
277            let database_name: &str = row.get(1);
278            let has_create: bool = row.get(2);
279            let has_temp: bool = row.get(3);
280
281            roles.push(UserDatabaseRole {
282                name: name.to_string(),
283                database_name: database_name.to_string(),
284                has_create,
285                has_temp,
286            })
287        }
288
289        Ok(roles)
290    }
291
292    /// Get the user schema privileges for current database
293    pub fn get_user_schema_privileges(&mut self) -> Result<Vec<UserSchemaRole>> {
294        // FIXME it will be empty if the schema doesn't have any tables
295        let sql = "
296            SELECT
297              u.usename AS name,
298              s.schemaname AS schema_name,
299              has_schema_privilege(u.usename, s.schemaname, 'create') AS has_create,
300              has_schema_privilege(u.usename, s.schemaname, 'usage') AS has_usage
301            FROM
302              pg_user u
303              CROSS JOIN (SELECT DISTINCT schemaname FROM pg_tables) s
304            WHERE
305              1 = 1
306              AND s.schemaname != 'pg_catalog'
307              AND s.schemaname != 'information_schema';
308        ";
309
310        let stmt = self
311            .client
312            .prepare(sql)
313            .map_err(|e| anyhow!("Failed to prepare query for schema privileges: {}", e))?;
314
315        debug!("executing: {}", sql);
316        let rows = self
317            .client
318            .query(&stmt, &[])
319            .map_err(|e| anyhow!("Failed to query schema privileges: {}", e))?;
320        let mut roles = vec![];
321        for row in rows {
322            let name = row.get(0);
323            let schema_name = row.get(1);
324            let has_create = row.get(2);
325            let has_usage = row.get(3);
326            if let (Some(name), Some(schema_name), Some(has_create), Some(has_usage)) =
327                (name, schema_name, has_create, has_usage)
328            {
329                roles.push(UserSchemaRole {
330                    name,
331                    schema_name,
332                    has_create,
333                    has_usage,
334                })
335            }
336        }
337
338        Ok(roles)
339    }
340
341    /// Get the user table privileges for current database
342    pub fn get_user_table_privileges(&mut self) -> Result<Vec<UserTableRole>> {
343        let mut roles = vec![];
344        let sql = "
345            SELECT
346              u.usename AS name,
347              t.schemaname AS schema_name,
348              t.tablename AS table_name,
349              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'select') AS has_select,
350              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'insert') AS has_insert,
351              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'update') AS has_update,
352              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'delete') AS has_delete,
353              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'references') AS has_references
354            FROM
355              pg_user u
356              CROSS JOIN (SELECT DISTINCT schemaname, tablename FROM pg_tables) t
357              WHERE 1 = 1
358                AND t.schemaname NOT LIKE 'pg_%'
359                AND t.schemaname != 'information_schema';
360        ";
361
362        let stmt = self
363            .client
364            .prepare(sql)
365            .map_err(|e| anyhow!("Failed to prepare query for table privileges: {}", e))?;
366
367        debug!("executing: {}", sql);
368        let rows = self
369            .client
370            .query(&stmt, &[])
371            .map_err(|e| anyhow!("Failed to query table privileges: {}", e))?;
372        for row in rows {
373            let name = row.get(0);
374            let schema_name = row.get(1);
375            let table_name = row.get(2);
376            let has_select = row.get(3);
377            let has_insert = row.get(4);
378            let has_update = row.get(5);
379            let has_delete = row.get(6);
380            let has_references = row.get(7);
381
382            if let (
383                Some(name),
384                Some(schema_name),
385                Some(table_name),
386                Some(has_select),
387                Some(has_insert),
388                Some(has_update),
389                Some(has_delete),
390                Some(has_references),
391            ) = (
392                name,
393                schema_name,
394                table_name,
395                has_select,
396                has_insert,
397                has_update,
398                has_delete,
399                has_references,
400            ) {
401                roles.push(UserTableRole {
402                    name,
403                    schema_name,
404                    table_name,
405                    has_insert,
406                    has_select,
407                    has_update,
408                    has_delete,
409                    has_references,
410                })
411            }
412        }
413
414        Ok(roles)
415    }
416
417    /// Executes a statement, returning the resulting rows
418    /// A statement may contain parameters, specified by `$n` where `n` is the
419    /// index of the parameter in the list provided, 1-indexed.
420    ///
421    /// ```rust
422    /// use grant::connection::DbConnection;
423    /// use std::str::FromStr;
424    ///
425    /// let url = "postgresql://postgres:postgres@localhost:5432/postgres";
426    /// let mut db = DbConnection::from_str(url).unwrap();
427    /// let rows = db.query("SELECT 1 as t", &[]).unwrap();
428    /// println!("test_query: {:?}", rows);
429    ///
430    /// assert_eq!(rows.len(), 1);
431    /// assert_eq!(rows.get(0).unwrap().len(), 1);
432    ///
433    /// let t: i32 = rows.get(0).unwrap().get("t");
434    /// assert_eq!(t, 1);
435    /// ```
436    pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>>
437    where
438        T: ?Sized + ToStatement,
439    {
440        let ri = self.client.query(query, params)?;
441        Ok(ri)
442    }
443
444    /// Executes a statement, returning the number of rows modified.
445    ///
446    /// If the statement does not modify any rows (e.g. SELECT), 0 is returned.
447    ///
448    /// ```rust
449    /// use grant::connection::DbConnection;
450    /// use std::str::FromStr;
451    ///
452    /// let url = "postgresql://postgres:postgres@localhost:5432/postgres";
453    /// let mut db = DbConnection::from_str(url).unwrap();
454    /// let nrows = db.execute("SELECT 1 as t", &[]).unwrap();
455    ///
456    /// println!("test_execute: {:?}", nrows);
457    /// assert_eq!(nrows, 1);
458    /// ```
459    pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<i64> {
460        // Support multiple query statements by splitting on semicolons
461        // and executing each one separately (if any)
462        // This is a bit of a hack, but it's the only way to support
463        // multiple statements in the execute method without having
464        // to rewrite the entire method
465        // should split params into multiple slices as well
466        let queries = query.split(';');
467        let mut rows_affected = 0;
468
469        for query in queries {
470            let query = query.trim();
471            if query.is_empty() {
472                continue;
473            }
474
475            let stmt = self.client.prepare(query)?;
476            let rows = self.client.execute(&stmt, params)?;
477            rows_affected += rows;
478        }
479
480        rows_affected
481            .try_into()
482            .map_err(|e| anyhow!("Row count {} exceeds i64::MAX: {}", rows_affected, e))
483    }
484}
485
486impl std::str::FromStr for DbConnection {
487    type Err = anyhow::Error;
488
489    /// Connection by a connection string.
490    ///
491    /// ```
492    /// use grant::connection::DbConnection;
493    /// use std::str::FromStr;
494    ///
495    /// let connection_info = "postgres://postgres:postgres@localhost:5432/postgres";
496    /// let mut client = DbConnection::from_str(connection_info).unwrap();
497    /// client.query("SELECT 1", &[]).unwrap();
498    /// ```
499    fn from_str(connection_info: &str) -> Result<Self> {
500        let client = Client::connect(connection_info, NoTls)
501            .map_err(|e| anyhow!("Failed to connect to database: {}", e))?;
502        let conn_config = connection_info
503            .parse::<ConnConfig>()
504            .map_err(|e| anyhow!("Failed to parse connection string: {}", e))?;
505
506        Ok(Self {
507            connection_info: connection_info.to_owned(),
508            client,
509            conn_config,
510        })
511    }
512}
513
// Tests for DbConnection. They need a Postgres server reachable at `TEST_URL`,
// and they create and drop throwaway users with random names.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::sql_utils::{escape_identifier, escape_sql_string};
    use rand::{thread_rng, Rng};
    use std::str::FromStr;

    /// Connection string of the database the tests run against.
    const TEST_URL: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Opens a fresh connection to the test database.
    fn connect() -> DbConnection {
        DbConnection::from_str(TEST_URL).unwrap()
    }

    /// Builds an ordinary (no CREATEDB, non-superuser) user with a random name.
    fn test_user(password: String) -> User {
        User {
            name: random_str(),
            user_createdb: false,
            user_super: false,
            password,
        }
    }

    /// Drops `name` if it exists; panics if the statement fails.
    fn drop_user(db: &mut DbConnection, name: &str) {
        let sql = format!("DROP USER IF EXISTS {}", escape_identifier(name));
        db.execute(&sql, &[]).unwrap();
    }

    /// Creates `user` with its CREATEDB flag and (if non-empty) password.
    fn create_user(db: &mut DbConnection, user: &User) {
        let mut sql = format!("CREATE USER {} ", escape_identifier(&user.name));
        if user.user_createdb {
            sql += "CREATEDB"
        }
        if !user.password.is_empty() {
            let escaped_password = escape_sql_string(&user.password);
            sql += &format!(" PASSWORD '{}'", escaped_password)
        }

        db.execute(&sql, &[]).unwrap();
    }

    // Note: the privilege queries are unwrapped (not `unwrap_or_default`)
    // throughout. A failing query would otherwise yield an empty list and make
    // the negative assertions below pass vacuously.

    #[test]
    fn test_drop_user() {
        let mut db = connect();
        let user = test_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);
        drop_user(&mut db, &user.name);

        let users = db.get_users().unwrap();
        assert!(!users.iter().any(|u| u.name == user.name));

        // Clean up
        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_drop_create_user() {
        let mut db = connect();
        let user = test_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let users = db.get_users().unwrap();
        assert!(users.iter().any(|u| u.name == user.name));

        // Clean up
        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_schema_roles() {
        let mut db = connect();
        let user = test_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_schema_privileges = db.get_user_schema_privileges().unwrap();

        // FIXME it will be empty if the schema doesn't have any tables
        if !user_schema_privileges.is_empty() {
            // A freshly created user is expected to hold no privilege on at
            // least one schema.
            assert!(user_schema_privileges
                .iter()
                .any(|u| u.name == user.name && !u.has_usage && !u.has_create));
        }

        // Clean up
        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_database_privileges() {
        let mut db = connect();
        let user = test_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_database_privileges = db.get_user_database_privileges().unwrap();

        // A freshly created user should not hold CREATE on any database.
        assert!(!user_database_privileges
            .iter()
            .any(|u| u.name == user.name && u.has_create));

        // FIXME seriously test this function

        // Clean up
        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_schema_privileges() {
        let mut db = connect();
        let user = test_user(random_str());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_schema_privileges = db.get_user_schema_privileges().unwrap();
        println!("{:?}", user_schema_privileges);

        // FIXME seriously test this function: no assertion is made yet.

        // Clean up
        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_table_privileges() {
        let mut db = connect();
        let user = test_user(random_str());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_table_privileges = db.get_user_table_privileges().unwrap();

        // A freshly created user should not hold SELECT on any table.
        assert!(!user_table_privileges
            .iter()
            .any(|u| u.name == user.name && u.has_select));

        // FIXME seriously test this function

        // Clean up
        drop_user(&mut db, &user.name);
    }

    /// Returns a random 10-character lowercase ASCII string.
    fn random_str() -> String {
        const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
        let mut rng = thread_rng();

        (0..10)
            .map(|_| CHARSET[rng.gen_range(0..CHARSET.len())] as char)
            .collect()
    }
}