grant/
connection.rs

1use crate::config::connection::ConnectionType;
2use crate::config::Config;
3use anyhow::{anyhow, Result};
4use log::{debug, info};
5use postgres::{row::Row, types::ToSql, Client, Config as ConnConfig, NoTls, ToStatement};
6
7// TODO: support multiple adapters
8
9/// Connection to the database, currently only Postgres and Redshift is supported
10/// TODO: support multiple adapters
11pub struct DbConnection {
12    pub connection_info: String,
13    pub client: Client,
14    conn_config: ConnConfig,
15}
16
/// A database user, e.g. one row of `pg_user`.
#[derive(Debug, Clone)]
pub struct User {
    /// Login name of the user.
    pub name: String,
    /// Whether the user may create databases.
    pub user_createdb: bool,
    /// Whether the user is a superuser.
    pub user_super: bool,
    /// Password of the user. When read back from `pg_user` this is the
    /// masked value the catalog exposes, not the real password.
    pub password: String,
}
25
/// The database-level privileges a user holds on one database:
/// whether the user has `CREATE` and/or `TEMP` on it.
#[derive(Debug)]
pub struct UserDatabaseRole {
    /// User name.
    pub name: String,
    /// Database the privileges apply to.
    pub database_name: String,
    /// Whether the user has `CREATE` on the database.
    pub has_create: bool,
    /// Whether the user has `TEMP` on the database.
    pub has_temp: bool,
}

impl UserDatabaseRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` when both privileges are held, `C` for `CREATE` only,
    /// `T` for `TEMP` only, and an empty string when neither is held.
    /// When `with_name` is true the code is wrapped as `database(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = match (self.has_create, self.has_temp) {
            (true, true) => "A",
            (true, false) => "C",
            (false, true) => "T",
            (false, false) => "",
        };

        if with_name {
            format!("{}({})", self.database_name, code)
        } else {
            code.to_string()
        }
    }
}
50
/// The schema-level privileges a user holds on one schema:
/// whether the user has `CREATE` and/or `USAGE` on it.
#[derive(Debug)]
pub struct UserSchemaRole {
    /// User name.
    pub name: String,
    /// Schema the privileges apply to.
    pub schema_name: String,
    /// Whether the user has `CREATE` on the schema.
    pub has_create: bool,
    /// Whether the user has `USAGE` on the schema.
    pub has_usage: bool,
}

impl UserSchemaRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` when both privileges are held, `C` for `CREATE` only,
    /// `U` for `USAGE` only, and an empty string when neither is held.
    /// When `with_name` is true the code is wrapped as `schema(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = match (self.has_create, self.has_usage) {
            (true, true) => "A",
            (true, false) => "C",
            (false, true) => "U",
            (false, false) => "",
        };

        if with_name {
            format!("{}({})", self.schema_name, code)
        } else {
            code.to_string()
        }
    }
}
75
/// The table-level privileges a user holds on one table:
/// `SELECT`, `INSERT`, `UPDATE`, `DELETE` and `REFERENCES`.
#[derive(Debug)]
pub struct UserTableRole {
    /// User name.
    pub name: String,
    /// Schema containing the table.
    pub schema_name: String,
    /// Table the privileges apply to.
    pub table_name: String,
    /// Whether the user has `SELECT` on the table.
    pub has_select: bool,
    /// Whether the user has `INSERT` on the table.
    pub has_insert: bool,
    /// Whether the user has `UPDATE` on the table.
    pub has_update: bool,
    /// Whether the user has `DELETE` on the table.
    pub has_delete: bool,
    /// Whether the user has `REFERENCES` on the table.
    pub has_references: bool,
}

impl UserTableRole {
    /// Renders the privileges as a compact code.
    ///
    /// The code is `A` when all five privileges are held. Otherwise it is the
    /// concatenation, in this order, of `S` (select), `I` (insert), `U` (update),
    /// `D` (delete) and `R` (references) for each privilege held, possibly empty.
    /// When `with_name` is true the code is wrapped as `schema.table(CODE)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let flags = [
            (self.has_select, 'S'),
            (self.has_insert, 'I'),
            (self.has_update, 'U'),
            (self.has_delete, 'D'),
            (self.has_references, 'R'),
        ];

        let code: String = if flags.iter().all(|&(granted, _)| granted) {
            String::from("A")
        } else {
            flags
                .iter()
                .filter(|&&(granted, _)| granted)
                .map(|&(_, letter)| letter)
                .collect()
        };

        if with_name {
            format!("{}.{}({})", self.schema_name, self.table_name, code)
        } else {
            code
        }
    }
}
121
122impl DbConnection {
123    /// A convenience function which store the connection string into `connection_info` and then connects to the database.
124    ///
125    /// Refer to <https://rust-lang-nursery.github.io/rust-cookbook/database/postgres.html>
126    /// for more information.
127    ///
128    /// ```rust
129    /// use grant::{config::Config, connection::DbConnection};
130    /// use std::str::FromStr;
131    ///
132    /// # fn main() -> anyhow::Result<()> {
133    /// let config = Config::from_str(
134    ///     r#"
135    ///       connection:
136    ///         type: postgres
137    ///         url: "postgresql://postgres:postgres@localhost:5432/postgres"
138    ///       roles: []
139    ///       users: []
140    ///     "#,
141    ///    )
142    ///    .unwrap();
143    ///    let mut db = DbConnection::new(&config)?;
144    ///    db.query("SELECT 1", &[]).unwrap();
145    /// # Ok(())
146    /// # }
147    /// ```
148    pub fn new(config: &Config) -> Result<Self> {
149        match config.connection.type_ {
150            ConnectionType::Postgres => {
151                let connection_info = config.connection.url.clone();
152                let mut client = Client::connect(&connection_info, NoTls).map_err(|e| {
153                    anyhow!("Failed to connect to database '{}': {}", connection_info, e)
154                })?;
155
156                if let Err(e) = client.simple_query("SELECT 1") {
157                    return Err(anyhow!(
158                        "Database connection test failed for '{}': {}",
159                        connection_info,
160                        e
161                    ));
162                } else {
163                    info!("Connected to database: {}", connection_info);
164                }
165
166                let conn_config = connection_info.parse::<ConnConfig>().map_err(|e| {
167                    anyhow!(
168                        "Failed to parse connection string '{}': {}",
169                        connection_info,
170                        e
171                    )
172                })?;
173
174                Ok(DbConnection {
175                    connection_info,
176                    client,
177                    conn_config,
178                })
179            }
180        }
181    }
182
183    /// Get current database name.
184    pub fn get_current_database(&self) -> Option<&str> {
185        self.conn_config.get_dbname()
186    }
187
188    /// Returns the connection_info
189    ///
190    /// ```rust
191    /// use grant::connection::DbConnection;
192    /// use std::str::FromStr;
193    ///
194    /// let connection_info = "postgres://postgres:postgres@localhost:5432/postgres";
195    /// let mut client = DbConnection::from_str(connection_info).unwrap();
196    /// assert_eq!(client.connection_info(), "postgres://postgres:postgres@localhost:5432/postgres");
197    /// ```
198    pub fn connection_info(self) -> String {
199        self.connection_info
200    }
201
202    /// Get the list of users
203    pub fn get_users(&mut self) -> Result<Vec<User>> {
204        let mut users = vec![];
205
206        // TODO: Get the password from database, currently it only returns *****
207        let sql = "SELECT usename, usecreatedb, usesuper, passwd FROM pg_user";
208        let stmt = self
209            .client
210            .prepare(sql)
211            .map_err(|e| anyhow!("Failed to prepare query for users: {}", e))?;
212
213        debug!("executing: {}", sql);
214        let rows = self
215            .client
216            .query(&stmt, &[])
217            .map_err(|e| anyhow!("Failed to query users from pg_user: {}", e))?;
218
219        for row in rows {
220            match (row.get(0), row.get(1), row.get(2), row.get(3)) {
221                (Some(name), Some(user_createdb), Some(user_super), Some(password)) => {
222                    users.push(User {
223                        name,
224                        user_createdb,
225                        user_super,
226                        password,
227                    })
228                }
229                (Some(name), _, _, _) => users.push(User {
230                    name,
231                    user_createdb: false,
232                    user_super: false,
233                    password: String::from(""),
234                }),
235                (_, _, _, _) => (),
236            }
237        }
238
239        debug!("get_users: {:#?}", users);
240
241        Ok(users)
242    }
243
244    /// Get the current database roles for user `user_name` in current database
245    /// Returns a list of `RoleDatabaseLevel`
246    pub fn get_user_database_privileges(&mut self) -> Result<Vec<UserDatabaseRole>> {
247        let mut roles = vec![];
248
249        let sql = r#"
250            WITH db AS (
251                SELECT d.datname AS database_name
252                FROM pg_database d
253            ),
254            users AS (
255                SELECT usename as user_name FROM pg_user
256            )
257            SELECT
258                u.user_name,
259                db.database_name,
260                pg_catalog.has_database_privilege(u.user_name, database_name, 'CREATE') AS "create",
261                pg_catalog.has_database_privilege(u.user_name, database_name, 'TEMP') AS "temp"
262            FROM db CROSS JOIN users u;
263        "#;
264
265        let stmt = self
266            .client
267            .prepare(sql)
268            .map_err(|e| anyhow!("Failed to prepare query for database privileges: {}", e))?;
269
270        debug!("executing: {}", sql);
271        let rows = self
272            .client
273            .query(&stmt, &[])
274            .map_err(|e| anyhow!("Failed to query database privileges: {}", e))?;
275        for row in rows {
276            let name: &str = row.get(0);
277            let database_name: &str = row.get(1);
278            let has_create: bool = row.get(2);
279            let has_temp: bool = row.get(3);
280
281            roles.push(UserDatabaseRole {
282                name: name.to_string(),
283                database_name: database_name.to_string(),
284                has_create,
285                has_temp,
286            })
287        }
288
289        Ok(roles)
290    }
291
292    /// Get the user schema privileges for current database
293    pub fn get_user_schema_privileges(&mut self) -> Result<Vec<UserSchemaRole>> {
294        // FIXME it will be empty if the schema doesn't have any tables
295        let sql = "
296            SELECT
297              u.usename AS name,
298              s.schemaname AS schema_name,
299              has_schema_privilege(u.usename, s.schemaname, 'create') AS has_create,
300              has_schema_privilege(u.usename, s.schemaname, 'usage') AS has_usage
301            FROM
302              pg_user u
303              CROSS JOIN (SELECT DISTINCT schemaname FROM pg_tables) s
304            WHERE
305              1 = 1
306              AND s.schemaname != 'pg_catalog'
307              AND s.schemaname != 'information_schema';
308        ";
309
310        let stmt = self
311            .client
312            .prepare(sql)
313            .map_err(|e| anyhow!("Failed to prepare query for schema privileges: {}", e))?;
314
315        debug!("executing: {}", sql);
316        let rows = self
317            .client
318            .query(&stmt, &[])
319            .map_err(|e| anyhow!("Failed to query schema privileges: {}", e))?;
320        let mut roles = vec![];
321        for row in rows {
322            let name = row.get(0);
323            let schema_name = row.get(1);
324            let has_create = row.get(2);
325            let has_usage = row.get(3);
326            if let (Some(name), Some(schema_name), Some(has_create), Some(has_usage)) =
327                (name, schema_name, has_create, has_usage)
328            {
329                roles.push(UserSchemaRole {
330                    name,
331                    schema_name,
332                    has_create,
333                    has_usage,
334                })
335            }
336        }
337
338        Ok(roles)
339    }
340
341    /// Get the user table privileges for current database
342    pub fn get_user_table_privileges(&mut self) -> Result<Vec<UserTableRole>> {
343        let mut roles = vec![];
344        let sql = "
345            SELECT
346              u.usename AS name,
347              t.schemaname AS schema_name,
348              t.tablename AS table_name,
349              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'select') AS has_select,
350              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'insert') AS has_insert,
351              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'update') AS has_update,
352              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'delete') AS has_delete,
353              has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'references') AS has_references
354            FROM
355              pg_user u
356              CROSS JOIN (SELECT DISTINCT schemaname, tablename FROM pg_tables) t
357              WHERE 1 = 1
358                AND t.schemaname NOT LIKE 'pg_%'
359                AND t.schemaname != 'information_schema';
360        ";
361
362        let stmt = self
363            .client
364            .prepare(sql)
365            .map_err(|e| anyhow!("Failed to prepare query for table privileges: {}", e))?;
366
367        debug!("executing: {}", sql);
368        let rows = self
369            .client
370            .query(&stmt, &[])
371            .map_err(|e| anyhow!("Failed to query table privileges: {}", e))?;
372        for row in rows {
373            let name = row.get(0);
374            let schema_name = row.get(1);
375            let table_name = row.get(2);
376            let has_select = row.get(3);
377            let has_insert = row.get(4);
378            let has_update = row.get(5);
379            let has_delete = row.get(6);
380            let has_references = row.get(7);
381
382            if let (
383                Some(name),
384                Some(schema_name),
385                Some(table_name),
386                Some(has_select),
387                Some(has_insert),
388                Some(has_update),
389                Some(has_delete),
390                Some(has_references),
391            ) = (
392                name,
393                schema_name,
394                table_name,
395                has_select,
396                has_insert,
397                has_update,
398                has_delete,
399                has_references,
400            ) {
401                roles.push(UserTableRole {
402                    name,
403                    schema_name,
404                    table_name,
405                    has_insert,
406                    has_select,
407                    has_update,
408                    has_delete,
409                    has_references,
410                })
411            }
412        }
413
414        Ok(roles)
415    }
416
417    /// Executes a statement, returning the resulting rows
418    /// A statement may contain parameters, specified by `$n` where `n` is the
419    /// index of the parameter in the list provided, 1-indexed.
420    ///
421    /// ```rust
422    /// use grant::connection::DbConnection;
423    /// use std::str::FromStr;
424    ///
425    /// let url = "postgresql://postgres:postgres@localhost:5432/postgres";
426    /// let mut db = DbConnection::from_str(url).unwrap();
427    /// let rows = db.query("SELECT 1 as t", &[]).unwrap();
428    /// println!("test_query: {:?}", rows);
429    ///
430    /// assert_eq!(rows.len(), 1);
431    /// assert_eq!(rows.get(0).unwrap().len(), 1);
432    ///
433    /// let t: i32 = rows.get(0).unwrap().get("t");
434    /// assert_eq!(t, 1);
435    /// ```
436    pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>>
437    where
438        T: ?Sized + ToStatement,
439    {
440        let ri = self.client.query(query, params)?;
441        Ok(ri)
442    }
443
444    /// Executes a statement, returning the number of rows modified.
445    ///
446    /// If the statement does not modify any rows (e.g. SELECT), 0 is returned.
447    ///
448    /// ```rust
449    /// use grant::connection::DbConnection;
450    /// use std::str::FromStr;
451    ///
452    /// let url = "postgresql://postgres:postgres@localhost:5432/postgres";
453    /// let mut db = DbConnection::from_str(url).unwrap();
454    /// let nrows = db.execute("SELECT 1 as t", &[]).unwrap();
455    ///
456    /// println!("test_execute: {:?}", nrows);
457    /// assert_eq!(nrows, 1);
458    /// ```
459    pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<i64> {
460        // Support multiple query statements by splitting on semicolons
461        // and executing each one separately (if any)
462        // This is a bit of a hack, but it's the only way to support
463        // multiple statements in the execute method without having
464        // to rewrite the entire method
465        // should split params into multiple slices as well
466        let queries = query.split(';');
467        let mut rows_affected = 0;
468
469        for query in queries {
470            let query = query.trim();
471            if query.is_empty() {
472                continue;
473            }
474
475            let stmt = self.client.prepare(query)?;
476            let rows = self.client.execute(&stmt, params)?;
477            rows_affected += rows;
478        }
479
480        rows_affected
481            .try_into()
482            .map_err(|e| anyhow!("Row count {} exceeds i64::MAX: {}", rows_affected, e))
483    }
484}
485
486impl std::str::FromStr for DbConnection {
487    type Err = anyhow::Error;
488
489    /// Connection by a connection string.
490    ///
491    /// ```
492    /// use grant::connection::DbConnection;
493    /// use std::str::FromStr;
494    ///
495    /// let connection_info = "postgres://postgres:postgres@localhost:5432/postgres";
496    /// let mut client = DbConnection::from_str(connection_info).unwrap();
497    /// client.query("SELECT 1", &[]).unwrap();
498    /// ```
499    fn from_str(connection_info: &str) -> Result<Self> {
500        let client = Client::connect(connection_info, NoTls)
501            .map_err(|e| anyhow!("Failed to connect to database: {}", e))?;
502        let conn_config = connection_info
503            .parse::<ConnConfig>()
504            .map_err(|e| anyhow!("Failed to parse connection string: {}", e))?;
505
506        Ok(Self {
507            connection_info: connection_info.to_owned(),
508            client,
509            conn_config,
510        })
511    }
512}
513
// Test DbConnection
//
// These tests need a PostgreSQL server reachable at `TEST_DATABASE_URL`, whose
// `postgres` user is allowed to create and drop roles.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};
    use std::str::FromStr;

    const TEST_DATABASE_URL: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Opens a connection to the test database, failing loudly if it is unreachable.
    fn connect() -> DbConnection {
        DbConnection::from_str(TEST_DATABASE_URL).expect("test database should be reachable")
    }

    /// Builds a plain (non-createdb, non-superuser) `User`.
    fn plain_user(name: &str, password: String) -> User {
        User {
            name: name.to_owned(),
            user_createdb: false,
            user_super: false,
            password,
        }
    }

    fn drop_user(db: &mut DbConnection, name: &str) {
        let sql = format!("DROP USER IF EXISTS {}", name);
        db.execute(&sql, &[]).unwrap();
    }

    /// Creates `user`, honouring its `user_createdb`, `user_super` and `password` fields.
    fn create_user(db: &mut DbConnection, user: &User) {
        let mut sql = format!("CREATE USER {}", user.name);
        if user.user_createdb {
            sql.push_str(" CREATEDB");
        }
        if user.user_super {
            sql.push_str(" SUPERUSER");
        }
        if !user.password.is_empty() {
            // Double embedded single quotes so the literal stays well-formed.
            sql.push_str(&format!(
                " PASSWORD '{}'",
                user.password.replace('\'', "''")
            ));
        }

        db.execute(&sql, &[]).unwrap();
    }

    #[test]
    fn test_drop_user() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, "duyet".to_string());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);
        drop_user(&mut db, &name);

        let users = db.get_users().unwrap();
        assert!(!users.iter().any(|u| u.name == name));
    }

    #[test]
    fn test_drop_create_user() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, "duyet".to_string());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);

        let users = db.get_users().unwrap();
        let exists = users.iter().any(|u| u.name == name);

        // Clean up before asserting so a failure does not leak the user.
        drop_user(&mut db, &name);
        assert!(exists);
    }

    #[test]
    fn test_get_schema_roles() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, "duyet".to_string());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);

        let user_schema_privileges = db.get_user_schema_privileges().unwrap();

        // FIXME it will be empty if the schema doesn't have any tables.
        // Otherwise a brand-new user is expected to lack both USAGE and
        // CREATE on at least one listed schema.
        let ok = user_schema_privileges.is_empty()
            || user_schema_privileges
                .iter()
                .any(|u| u.name == name && !u.has_usage && !u.has_create);

        drop_user(&mut db, &name);
        assert!(ok);
    }

    #[test]
    fn test_get_user_database_privileges() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, "duyet".to_string());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);

        let privileges = db.get_user_database_privileges().unwrap();

        // Every user is paired with every database, so the new user is listed...
        let listed = privileges.iter().any(|u| u.name == name);
        // ...but without CREATE on any database.
        let has_create = privileges.iter().any(|u| u.name == name && u.has_create);

        drop_user(&mut db, &name);
        assert!(listed);
        assert!(!has_create);
    }

    #[test]
    fn test_get_user_schema_privileges() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, random_str());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);

        let privileges = db.get_user_schema_privileges().unwrap();

        // System schemas are filtered out by the query.
        let no_system_schemas = privileges
            .iter()
            .all(|p| p.schema_name != "pg_catalog" && p.schema_name != "information_schema");
        // Every user is paired with every listed schema.
        let user_listed = privileges.is_empty() || privileges.iter().any(|p| p.name == name);

        drop_user(&mut db, &name);
        assert!(no_system_schemas);
        assert!(user_listed);
    }

    #[test]
    fn test_get_user_table_privileges() {
        let mut db = connect();
        let name = random_str();
        let user = plain_user(&name, random_str());

        drop_user(&mut db, &name);
        create_user(&mut db, &user);

        let privileges = db.get_user_table_privileges().unwrap();

        // A brand-new user has not been granted SELECT on any table.
        let has_select = privileges.iter().any(|u| u.name == name && u.has_select);
        // System schemas are filtered out by the query.
        let no_system_schemas = privileges
            .iter()
            .all(|p| !p.schema_name.starts_with("pg_") && p.schema_name != "information_schema");
        // Every user is paired with every listed table.
        let user_listed = privileges.is_empty() || privileges.iter().any(|p| p.name == name);

        drop_user(&mut db, &name);
        assert!(!has_select);
        assert!(no_system_schemas);
        assert!(user_listed);
    }

    /// Returns a random 10-letter lowercase name, safe to use unquoted as a role name.
    fn random_str() -> String {
        const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
        let mut rng = thread_rng();

        (0..10)
            .map(|_| {
                let idx = rng.gen_range(0..CHARSET.len());
                CHARSET[idx] as char
            })
            .collect()
    }
}