1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt;

pub use super::role_database::RoleDatabaseLevel;
pub use super::role_schema::RoleSchemaLevel;
pub use super::role_table::RoleTableLevel;

/// Level at which a [`Role`] applies: database, schema, or table.
///
/// Serialized as an internally tagged enum under the `"level"` key. Variant names
/// are not renamed, so the serialized tag is the variant name as written
/// (e.g. `{"level": "Database"}`), whereas the [`fmt::Display`] impl renders the
/// level in lowercase.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
#[serde(tag = "level")]
pub enum RoleLevelType {
    /// Level of a [`Role::Database`].
    Database,
    /// Level of a [`Role::Schema`].
    Schema,
    /// Level of a [`Role::Table`].
    Table,
}

impl fmt::Display for RoleLevelType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            RoleLevelType::Database => write!(f, "database"),
            RoleLevelType::Schema => write!(f, "schema"),
            RoleLevelType::Table => write!(f, "table"),
        }
    }
}

/// Configuration for a role, discriminated by the level it applies to.
///
/// Serialized as an internally tagged enum: the `"type"` key carries the lowercase
/// variant name (`"database"`, `"schema"` or `"table"`), and the remaining keys
/// come from the wrapped level-specific configuration.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Role {
    /// Role configured at database level.
    Database(RoleDatabaseLevel),
    /// Role configured at schema level.
    Schema(RoleSchemaLevel),
    /// Role configured at table level.
    Table(RoleTableLevel),
}

/// Validation hook implemented by level-specific role configurations.
pub trait RoleValidate {
    /// Checks the configuration, returning `Ok(())` when it is acceptable.
    ///
    /// # Errors
    ///
    /// Implementations return an error describing why the configuration is invalid.
    fn validate(&self) -> Result<()>;
}

impl Role {
    /// Returns the SQL produced by the wrapped level-specific role for `user`.
    pub fn to_sql(&self, user: &str) -> String {
        match self {
            Role::Database(role) => role.to_sql(user),
            Role::Schema(role) => role.to_sql(user),
            Role::Table(role) => role.to_sql(user),
        }
    }

    /// Validates the wrapped level-specific role.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped role's `validate`.
    pub fn validate(&self) -> Result<()> {
        match self {
            Role::Database(role) => role.validate(),
            Role::Schema(role) => role.validate(),
            Role::Table(role) => role.validate(),
        }
    }

    /// Borrows the wrapped role's name without allocating.
    ///
    /// Keeps the per-variant name dispatch in one place for `get_name` and `find`.
    fn name_ref(&self) -> &str {
        match self {
            Role::Database(role) => role.name.as_str(),
            Role::Schema(role) => role.name.as_str(),
            Role::Table(role) => role.name.as_str(),
        }
    }

    /// Returns an owned copy of the role's name.
    pub fn get_name(&self) -> String {
        self.name_ref().to_owned()
    }

    /// Reports whether this role's name matches `name`, ignoring every `-`.
    ///
    /// Role names can contain `-`, so dashes are stripped from *both* the query
    /// and the stored name before comparing. Stripping only the query would make
    /// a stored name that contains `-` impossible to match.
    pub fn find(&self, name: &str) -> bool {
        /// Yields the characters of `s` with every `-` removed.
        fn without_dashes(s: &str) -> impl Iterator<Item = char> + '_ {
            s.chars().filter(|&c| c != '-')
        }

        // Compare lazily, char by char, instead of allocating two stripped strings.
        without_dashes(self.name_ref()).eq(without_dashes(name))
    }

    /// Returns the level this role applies to.
    pub fn get_level(&self) -> RoleLevelType {
        match self {
            Role::Database(_) => RoleLevelType::Database,
            Role::Schema(_) => RoleLevelType::Schema,
            Role::Table(_) => RoleLevelType::Table,
        }
    }

    /// Returns a copy of the role's grants.
    pub fn get_grants(&self) -> Vec<String> {
        match self {
            Role::Database(role) => role.grants.clone(),
            Role::Schema(role) => role.grants.clone(),
            Role::Table(role) => role.grants.clone(),
        }
    }

    /// Returns a copy of the role's databases.
    ///
    /// Only database-level roles carry databases; other levels return an empty list.
    pub fn get_databases(&self) -> Vec<String> {
        match self {
            Role::Database(role) => role.databases.clone(),
            Role::Schema(_) => vec![],
            Role::Table(_) => vec![],
        }
    }

    /// Returns a copy of the role's schemas.
    ///
    /// Schema- and table-level roles carry schemas; database-level roles return an
    /// empty list.
    pub fn get_schemas(&self) -> Vec<String> {
        match self {
            Role::Database(_) => vec![],
            Role::Schema(role) => role.schemas.clone(),
            Role::Table(role) => role.schemas.clone(),
        }
    }

    /// Returns a copy of the role's tables.
    ///
    /// Only table-level roles carry tables; other levels return an empty list.
    pub fn get_tables(&self) -> Vec<String> {
        match self {
            Role::Database(_) => vec![],
            Role::Schema(_) => vec![],
            Role::Table(role) => role.tables.clone(),
        }
    }
}