1use crate::config::connection::ConnectionType;
2use crate::config::Config;
3use anyhow::{anyhow, Result};
4use log::{debug, info};
5use postgres::{row::Row, types::ToSql, Client, Config as ConnConfig, NoTls, ToStatement};
6
/// A live PostgreSQL connection together with the connection string it was opened with.
pub struct DbConnection {
    /// Connection string that was used to open `client`.
    /// NOTE(review): this may embed credentials (e.g. a URL password) — avoid logging it verbatim.
    pub connection_info: String,
    /// Client through which every query and statement in this module is sent.
    pub client: Client,
    /// Parsed form of `connection_info`; read by `get_current_database`.
    conn_config: ConnConfig,
}
16
/// A PostgreSQL user, as returned by `DbConnection::get_users`.
///
/// `Debug` is implemented by hand so that the password value never ends up in logs
/// (users are logged with `{:#?}`); it only reveals whether a password is set.
#[derive(Clone)]
pub struct User {
    /// Role name (`pg_user.usename`).
    pub name: String,
    /// Whether the user may create databases (`pg_user.usecreatedb`).
    pub user_createdb: bool,
    /// Whether the user is a superuser (`pg_user.usesuper`).
    pub user_super: bool,
    /// Password; an empty string means "none / unknown".
    pub password: String,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Mask the secret but keep the "is a password set?" information useful for debugging.
        let password = if self.password.is_empty() { "" } else { "***" };
        f.debug_struct("User")
            .field("name", &self.name)
            .field("user_createdb", &self.user_createdb)
            .field("user_super", &self.user_super)
            .field("password", &password)
            .finish()
    }
}
25
/// Database-level privileges that one user holds on one database.
#[derive(Debug)]
pub struct UserDatabaseRole {
    /// User (role) name.
    pub name: String,
    /// Database the privileges apply to.
    pub database_name: String,
    /// Whether the user holds `CREATE` on the database.
    pub has_create: bool,
    /// Whether the user holds `TEMP` on the database.
    pub has_temp: bool,
}

impl UserDatabaseRole {
    /// Renders the privileges as a compact code: `A` (CREATE and TEMP), `C` (CREATE only),
    /// `T` (TEMP only), or an empty string (neither).
    ///
    /// When `with_name` is true the code is wrapped as `database_name(code)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = match (self.has_create, self.has_temp) {
            (true, true) => "A",
            (true, false) => "C",
            (false, true) => "T",
            (false, false) => "",
        };

        if !with_name {
            return code.to_string();
        }
        format!("{}({})", self.database_name, code)
    }
}
50
/// Schema-level privileges that one user holds on one schema.
#[derive(Debug)]
pub struct UserSchemaRole {
    /// User (role) name.
    pub name: String,
    /// Schema the privileges apply to.
    pub schema_name: String,
    /// Whether the user holds `CREATE` on the schema.
    pub has_create: bool,
    /// Whether the user holds `USAGE` on the schema.
    pub has_usage: bool,
}

impl UserSchemaRole {
    /// Renders the privileges as a compact code: `A` (CREATE and USAGE), `C` (CREATE only),
    /// `U` (USAGE only), or an empty string (neither).
    ///
    /// When `with_name` is true the code is wrapped as `schema_name(code)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let code = if self.has_create && self.has_usage {
            "A"
        } else if self.has_create {
            "C"
        } else if self.has_usage {
            "U"
        } else {
            ""
        };

        if with_name {
            format!("{}({})", self.schema_name, code)
        } else {
            String::from(code)
        }
    }
}
75
/// Table-level privileges that one user holds on one table.
#[derive(Debug)]
pub struct UserTableRole {
    /// User (role) name.
    pub name: String,
    /// Schema containing the table.
    pub schema_name: String,
    /// Table the privileges apply to.
    pub table_name: String,
    /// Whether the user holds `SELECT`.
    pub has_select: bool,
    /// Whether the user holds `INSERT`.
    pub has_insert: bool,
    /// Whether the user holds `UPDATE`.
    pub has_update: bool,
    /// Whether the user holds `DELETE`.
    pub has_delete: bool,
    /// Whether the user holds `REFERENCES`.
    pub has_references: bool,
}

impl UserTableRole {
    /// Renders the privileges as a compact code.
    ///
    /// Holding all five privileges yields `A`. Otherwise the code is the concatenation, in
    /// this order, of `S` (SELECT), `I` (INSERT), `U` (UPDATE), `D` (DELETE) and
    /// `R` (REFERENCES) for each privilege held; it is empty when none are held.
    /// When `with_name` is true the code is wrapped as `schema.table(code)`.
    pub fn perm_to_string(&self, with_name: bool) -> String {
        let flags = [
            (self.has_select, 'S'),
            (self.has_insert, 'I'),
            (self.has_update, 'U'),
            (self.has_delete, 'D'),
            (self.has_references, 'R'),
        ];

        let code: String = if flags.iter().all(|&(granted, _)| granted) {
            String::from("A")
        } else {
            flags
                .iter()
                .filter(|&&(granted, _)| granted)
                .map(|&(_, letter)| letter)
                .collect()
        };

        if with_name {
            format!("{}.{}({})", self.schema_name, self.table_name, code)
        } else {
            code
        }
    }
}
121
122impl DbConnection {
123 pub fn new(config: &Config) -> Result<Self> {
149 match config.connection.type_ {
150 ConnectionType::Postgres => {
151 let connection_info = config.connection.url.clone();
152 let mut client = Client::connect(&connection_info, NoTls).map_err(|e| {
153 anyhow!("Failed to connect to database '{}': {}", connection_info, e)
154 })?;
155
156 if let Err(e) = client.simple_query("SELECT 1") {
157 return Err(anyhow!(
158 "Database connection test failed for '{}': {}",
159 connection_info,
160 e
161 ));
162 } else {
163 info!("Connected to database: {}", connection_info);
164 }
165
166 let conn_config = connection_info.parse::<ConnConfig>().map_err(|e| {
167 anyhow!(
168 "Failed to parse connection string '{}': {}",
169 connection_info,
170 e
171 )
172 })?;
173
174 Ok(DbConnection {
175 connection_info,
176 client,
177 conn_config,
178 })
179 }
180 }
181 }
182
183 pub fn get_current_database(&self) -> Option<&str> {
185 self.conn_config.get_dbname()
186 }
187
188 pub fn connection_info(self) -> String {
199 self.connection_info
200 }
201
202 pub fn get_users(&mut self) -> Result<Vec<User>> {
204 let mut users = vec![];
205
206 let sql = "SELECT usename, usecreatedb, usesuper, passwd FROM pg_user";
208 let stmt = self
209 .client
210 .prepare(sql)
211 .map_err(|e| anyhow!("Failed to prepare query for users: {}", e))?;
212
213 debug!("executing: {}", sql);
214 let rows = self
215 .client
216 .query(&stmt, &[])
217 .map_err(|e| anyhow!("Failed to query users from pg_user: {}", e))?;
218
219 for row in rows {
220 match (row.get(0), row.get(1), row.get(2), row.get(3)) {
221 (Some(name), Some(user_createdb), Some(user_super), Some(password)) => {
222 users.push(User {
223 name,
224 user_createdb,
225 user_super,
226 password,
227 })
228 }
229 (Some(name), _, _, _) => users.push(User {
230 name,
231 user_createdb: false,
232 user_super: false,
233 password: String::from(""),
234 }),
235 (_, _, _, _) => (),
236 }
237 }
238
239 debug!("get_users: {:#?}", users);
240
241 Ok(users)
242 }
243
244 pub fn get_user_database_privileges(&mut self) -> Result<Vec<UserDatabaseRole>> {
247 let mut roles = vec![];
248
249 let sql = r#"
250 WITH db AS (
251 SELECT d.datname AS database_name
252 FROM pg_database d
253 ),
254 users AS (
255 SELECT usename as user_name FROM pg_user
256 )
257 SELECT
258 u.user_name,
259 db.database_name,
260 pg_catalog.has_database_privilege(u.user_name, database_name, 'CREATE') AS "create",
261 pg_catalog.has_database_privilege(u.user_name, database_name, 'TEMP') AS "temp"
262 FROM db CROSS JOIN users u;
263 "#;
264
265 let stmt = self
266 .client
267 .prepare(sql)
268 .map_err(|e| anyhow!("Failed to prepare query for database privileges: {}", e))?;
269
270 debug!("executing: {}", sql);
271 let rows = self
272 .client
273 .query(&stmt, &[])
274 .map_err(|e| anyhow!("Failed to query database privileges: {}", e))?;
275 for row in rows {
276 let name: &str = row.get(0);
277 let database_name: &str = row.get(1);
278 let has_create: bool = row.get(2);
279 let has_temp: bool = row.get(3);
280
281 roles.push(UserDatabaseRole {
282 name: name.to_string(),
283 database_name: database_name.to_string(),
284 has_create,
285 has_temp,
286 })
287 }
288
289 Ok(roles)
290 }
291
292 pub fn get_user_schema_privileges(&mut self) -> Result<Vec<UserSchemaRole>> {
294 let sql = "
296 SELECT
297 u.usename AS name,
298 s.schemaname AS schema_name,
299 has_schema_privilege(u.usename, s.schemaname, 'create') AS has_create,
300 has_schema_privilege(u.usename, s.schemaname, 'usage') AS has_usage
301 FROM
302 pg_user u
303 CROSS JOIN (SELECT DISTINCT schemaname FROM pg_tables) s
304 WHERE
305 1 = 1
306 AND s.schemaname != 'pg_catalog'
307 AND s.schemaname != 'information_schema';
308 ";
309
310 let stmt = self
311 .client
312 .prepare(sql)
313 .map_err(|e| anyhow!("Failed to prepare query for schema privileges: {}", e))?;
314
315 debug!("executing: {}", sql);
316 let rows = self
317 .client
318 .query(&stmt, &[])
319 .map_err(|e| anyhow!("Failed to query schema privileges: {}", e))?;
320 let mut roles = vec![];
321 for row in rows {
322 let name = row.get(0);
323 let schema_name = row.get(1);
324 let has_create = row.get(2);
325 let has_usage = row.get(3);
326 if let (Some(name), Some(schema_name), Some(has_create), Some(has_usage)) =
327 (name, schema_name, has_create, has_usage)
328 {
329 roles.push(UserSchemaRole {
330 name,
331 schema_name,
332 has_create,
333 has_usage,
334 })
335 }
336 }
337
338 Ok(roles)
339 }
340
341 pub fn get_user_table_privileges(&mut self) -> Result<Vec<UserTableRole>> {
343 let mut roles = vec![];
344 let sql = "
345 SELECT
346 u.usename AS name,
347 t.schemaname AS schema_name,
348 t.tablename AS table_name,
349 has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'select') AS has_select,
350 has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'insert') AS has_insert,
351 has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'update') AS has_update,
352 has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'delete') AS has_delete,
353 has_table_privilege(u.usename, t.schemaname || '.' || t.tablename, 'references') AS has_references
354 FROM
355 pg_user u
356 CROSS JOIN (SELECT DISTINCT schemaname, tablename FROM pg_tables) t
357 WHERE 1 = 1
358 AND t.schemaname NOT LIKE 'pg_%'
359 AND t.schemaname != 'information_schema';
360 ";
361
362 let stmt = self
363 .client
364 .prepare(sql)
365 .map_err(|e| anyhow!("Failed to prepare query for table privileges: {}", e))?;
366
367 debug!("executing: {}", sql);
368 let rows = self
369 .client
370 .query(&stmt, &[])
371 .map_err(|e| anyhow!("Failed to query table privileges: {}", e))?;
372 for row in rows {
373 let name = row.get(0);
374 let schema_name = row.get(1);
375 let table_name = row.get(2);
376 let has_select = row.get(3);
377 let has_insert = row.get(4);
378 let has_update = row.get(5);
379 let has_delete = row.get(6);
380 let has_references = row.get(7);
381
382 if let (
383 Some(name),
384 Some(schema_name),
385 Some(table_name),
386 Some(has_select),
387 Some(has_insert),
388 Some(has_update),
389 Some(has_delete),
390 Some(has_references),
391 ) = (
392 name,
393 schema_name,
394 table_name,
395 has_select,
396 has_insert,
397 has_update,
398 has_delete,
399 has_references,
400 ) {
401 roles.push(UserTableRole {
402 name,
403 schema_name,
404 table_name,
405 has_insert,
406 has_select,
407 has_update,
408 has_delete,
409 has_references,
410 })
411 }
412 }
413
414 Ok(roles)
415 }
416
417 pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>>
437 where
438 T: ?Sized + ToStatement,
439 {
440 let ri = self.client.query(query, params)?;
441 Ok(ri)
442 }
443
444 pub fn execute(&mut self, query: &str, params: &[&(dyn ToSql + Sync)]) -> Result<i64> {
460 let queries = query.split(';');
467 let mut rows_affected = 0;
468
469 for query in queries {
470 let query = query.trim();
471 if query.is_empty() {
472 continue;
473 }
474
475 let stmt = self.client.prepare(query)?;
476 let rows = self.client.execute(&stmt, params)?;
477 rows_affected += rows;
478 }
479
480 rows_affected
481 .try_into()
482 .map_err(|e| anyhow!("Row count {} exceeds i64::MAX: {}", rows_affected, e))
483 }
484}
485
486impl std::str::FromStr for DbConnection {
487 type Err = anyhow::Error;
488
489 fn from_str(connection_info: &str) -> Result<Self> {
500 let client = Client::connect(connection_info, NoTls)
501 .map_err(|e| anyhow!("Failed to connect to database: {}", e))?;
502 let conn_config = connection_info
503 .parse::<ConnConfig>()
504 .map_err(|e| anyhow!("Failed to parse connection string: {}", e))?;
505
506 Ok(Self {
507 connection_info: connection_info.to_owned(),
508 client,
509 conn_config,
510 })
511 }
512}
513
#[cfg(test)]
mod tests {
    // Integration tests: they need a PostgreSQL server reachable at `TEST_URL`.
    use super::*;
    use rand::{thread_rng, Rng};
    use std::str::FromStr;

    /// Server the integration tests run against.
    const TEST_URL: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Opens a connection to the test server, panicking with a hint if it is unavailable.
    fn connect() -> DbConnection {
        DbConnection::from_str(TEST_URL)
            .expect("these tests need PostgreSQL at localhost:5432 (postgres/postgres)")
    }

    /// A regular user (no CREATEDB, no SUPERUSER) with a random name and the given password.
    fn random_user(password: String) -> User {
        User {
            name: random_str(),
            user_createdb: false,
            user_super: false,
            password,
        }
    }

    fn drop_user(db: &mut DbConnection, name: &str) {
        let sql = format!("DROP USER IF EXISTS {}", name);
        db.execute(&sql, &[]).unwrap();
    }

    /// Creates `user`, honouring its CREATEDB / SUPERUSER flags and password.
    fn create_user(db: &mut DbConnection, user: &User) {
        let mut sql = format!("CREATE USER {}", user.name);
        if user.user_createdb {
            sql.push_str(" CREATEDB");
        }
        if user.user_super {
            sql.push_str(" SUPERUSER");
        }
        if !user.password.is_empty() {
            sql.push_str(&format!(" PASSWORD '{}'", user.password));
        }

        db.execute(&sql, &[]).unwrap();
    }

    // Query results are unwrapped (never `unwrap_or_default`): a failing query must fail the
    // test instead of looking like an empty result that satisfies an "absent" assertion.

    #[test]
    fn test_drop_user() {
        let mut db = connect();
        let user = random_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);
        drop_user(&mut db, &user.name);

        let users = db.get_users().unwrap();
        assert!(!users.iter().any(|u| u.name == user.name));
    }

    #[test]
    fn test_drop_create_user() {
        let mut db = connect();
        let user = random_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let users = db.get_users().unwrap();
        assert!(users.iter().any(|u| u.name == user.name));

        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_schema_roles() {
        let mut db = connect();
        let user = random_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_schema_privileges = db.get_user_schema_privileges().unwrap();
        if !user_schema_privileges.is_empty() {
            assert!(user_schema_privileges
                .iter()
                .any(|u| u.name == user.name && !u.has_usage && !u.has_create));
        }

        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_database_privileges() {
        let mut db = connect();
        let user = random_user("duyet".to_string());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_database_privileges = db.get_user_database_privileges().unwrap();
        assert!(!user_database_privileges
            .iter()
            .any(|u| u.name == user.name && u.has_create));

        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_schema_privileges() {
        let mut db = connect();
        let user = random_user(random_str());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_schema_privileges = db.get_user_schema_privileges().unwrap();
        println!("{:?}", user_schema_privileges);

        drop_user(&mut db, &user.name);
    }

    #[test]
    fn test_get_user_table_privileges() {
        let mut db = connect();
        let user = random_user(random_str());

        drop_user(&mut db, &user.name);
        create_user(&mut db, &user);

        let user_table_privileges = db.get_user_table_privileges().unwrap();
        assert!(!user_table_privileges
            .iter()
            .any(|u| u.name == user.name && u.has_select));

        drop_user(&mut db, &user.name);
    }

    /// Ten random lowercase ASCII letters — a valid unquoted identifier and password.
    fn random_str() -> String {
        const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
        let mut rng = thread_rng();

        (0..10)
            .map(|_| CHARSET[rng.gen_range(0..CHARSET.len())] as char)
            .collect()
    }
}