Skip to main content

athena/
apply.rs

1//! AWS Athena query execution functionality
2//!
3//! This module handles building SQL templates and executing them in AWS Athena.
4//! It provides functionality to:
5//! - Submit queries to Athena
6//! - Poll for query completion
7//! - Retrieve query results
8//! - Extract database context from SQL comments
9//!
10//! # Database Context
11//!
12//! You can specify the target database using SQL comments:
13//!
14//! ```sql
15//! -- Database: my_database
16//! CREATE TABLE example (id INT);
17//! ```
18//!
19//! or
20//!
21//! ```sql
22//! /* Database: my_database */
23//! CREATE TABLE example (id INT);
24//! ```
25
26use anyhow::{anyhow, bail, Context, Result};
27use aws_sdk_athena::{
28    operation::get_query_execution::GetQueryExecutionOutput,
29    types::{QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultSet},
30    Client,
31};
32use devtimer::DevTime;
33use log::{error, info};
34use once_cell::sync::Lazy;
35use regex::Regex;
36use std::{collections::HashMap, env, path::PathBuf};
37use tokio::time::{sleep, Duration};
38
39use crate::utils::pretty_print;
40
41// Constants
42const QUERY_POLL_INTERVAL_SECS: u64 = 5;
43const SQL_STATEMENT_SEPARATOR: char = ';';
44
45// Compile regex patterns once and reuse them for extracting database names from SQL
46static DATABASE_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
47    vec![
48        // Matches: -- Database: db_name
49        Regex::new(r"(?i)--\s+Database:\s(.*)").expect("invalid regex pattern"),
50        // Matches: /* Database: db_name */
51        Regex::new(r"(?i)/*\s+Database:\s([^\s]+)\s\*/").expect("invalid regex pattern"),
52    ]
53});
54
55#[derive(clap::Args, Debug, Clone)]
56pub struct Apply {
57    /// Target path to render. If the target path is a directory,
58    /// the root folder must contains the index.sql file
59    pub file: PathBuf,
60
61    /// Change the context current working dir
62    #[arg(long, short)]
63    pub context: Option<PathBuf>,
64
65    /// Dry-run
66    #[arg(global = true, long, short)]
67    pub dry_run: Option<bool>,
68
69    /// AWS Profile
70    /// Set this option via environment variable: export AWS_PROFILE=default
71    #[arg(global = true, long, short)]
72    pub profile: Option<String>,
73
74    /// AWS Region
75    #[arg(global = true, long, short)]
76    /// Set this option via environment variable: export AWS_DEFAULT_REGION=us-east-1
77    pub region: Option<String>,
78
79    /// AWS Athena Workgroup
80    /// Set this option via environment variable: export AWS_WORKGROUP=primary
81    #[arg(global = true, long, short)]
82    pub workgroup: Option<String>,
83
84    /// AWS Athena output location
85    /// The location in Amazon S3 where your query results are stored
86    /// such as `s3://path/to/query/bucket/`
87    /// Set this option via environment variable: export AWS_OUTPUT_LOCATION=s3://bucket/
88    #[arg(global = true, long, short)]
89    pub output_location: Option<String>,
90
91    /// No pretty print for SQL
92    #[arg(long)]
93    pub no_pretty: Option<bool>,
94}
95
96pub async fn call(args: Apply) -> Result<()> {
97    let build_args = crate::build::Build {
98        file: args.file.clone(),
99        out: None,
100        context: args.context.clone(),
101        no_pretty: None,
102    };
103
104    let sql = crate::build::build(build_args)?;
105    if args.no_pretty.unwrap_or_default() {
106        print!("{}", sql);
107    } else {
108        pretty_print(sql.as_bytes());
109    }
110
111    // Set AWS_PROFILE
112    if let Some(ref profile) = args.profile {
113        std::env::set_var("AWS_PROFILE", profile);
114    }
115
116    // Set AWS_DEFAULT_REGION
117    if let Some(ref region) = args.region {
118        std::env::set_var("AWS_DEFAULT_REGION", region);
119    }
120
121    let shared_config = aws_config::load_from_env().await;
122    let client = Client::new(&shared_config);
123
124    // Healthcheck
125    submit_and_wait(client.clone(), Some("SELECT 1".to_string()), args.clone()).await?;
126
127    // Submit SQL
128    let sql = sql
129        .split(SQL_STATEMENT_SEPARATOR)
130        .into_iter()
131        .map(|s| s.trim())
132        .filter(|s| !s.is_empty())
133        .collect::<Vec<_>>();
134    info!("Submitting {} queries to Athena", sql.len());
135
136    let mut stats: HashMap<QueryExecutionState, i32> = HashMap::new();
137
138    // Timer
139    let mut timer = DevTime::new_simple();
140    timer.start();
141
142    for s in sql {
143        let state = submit_and_wait(client.clone(), Some(s.to_string()), args.clone()).await?;
144
145        // Update stats
146        stats
147            .entry(state.clone())
148            .and_modify(|c| *c += 1)
149            .or_insert(0);
150    }
151
152    timer.stop();
153
154    info!("");
155    info!("Statistics:");
156    info!("  ==> {:?}", stats);
157    if let Some(secs) = timer.time_in_secs() {
158        info!("  ==> Took: {:?} seconds", secs);
159    }
160
161    Ok(())
162}
163
164fn get_result_configuration(args: Apply) -> ResultConfiguration {
165    let output_location = args
166        .output_location
167        .or_else(|| env::var("AWS_OUTPUT_LOCATION").ok());
168
169    ResultConfiguration::builder()
170        .set_output_location(output_location)
171        .build()
172}
173
174fn get_query_execution_context(query: Option<String>) -> Option<QueryExecutionContext> {
175    let query = query.as_ref()?;
176
177    let database = get_database_from_sql(query);
178    database.as_ref()?;
179
180    let ctx = QueryExecutionContext::builder()
181        .set_database(database)
182        .build();
183
184    Some(ctx)
185}
186
187async fn submit_and_wait(
188    client: Client,
189    query: Option<String>,
190    args: Apply,
191) -> Result<QueryExecutionState> {
192    if query.clone().is_none() {
193        bail!("Empty query");
194    }
195
196    // Timer
197    let mut timer = DevTime::new_simple();
198    timer.start();
199
200    let workgroup = args.workgroup.clone();
201    let result_configuration = get_result_configuration(args.clone());
202    let query_execution_context = get_query_execution_context(query.clone());
203    let query = query.unwrap();
204
205    match &query_execution_context {
206        Some(ctx) => match ctx.database() {
207            Some(database) => info!("\nSubmitting to database `{}`: ", database),
208            _ => info!("\nSubmitting ..."),
209        },
210        _ => info!("\nSubmitting ..."),
211    }
212
213    if args.no_pretty.unwrap_or_default() {
214        print!("{}", query);
215    } else {
216        pretty_print(query.as_bytes());
217    }
218
219    let resp = client
220        .start_query_execution()
221        .set_query_string(Some(query))
222        .set_work_group(workgroup)
223        .set_result_configuration(Some(result_configuration.clone()))
224        .set_query_execution_context(query_execution_context)
225        .send()
226        .await?;
227
228    let query_execution_id = resp
229        .query_execution_id()
230        .ok_or_else(|| anyhow!("query execution id not found in response"))?;
231    info!("Query execution id: {}", &query_execution_id);
232
233    let mut state: QueryExecutionState;
234
235    loop {
236        let resp = client
237            .get_query_execution()
238            .set_query_execution_id(Some(query_execution_id.to_string()))
239            .send()
240            .await?;
241
242        state = status(&resp)
243            .ok_or_else(|| anyhow!("could not get query execution status from response"))?
244            .clone();
245
246        match state {
247            QueryExecutionState::Queued | QueryExecutionState::Running => {
248                sleep(Duration::from_secs(QUERY_POLL_INTERVAL_SECS)).await;
249                info!(
250                    "State: {:?}, sleeping {} secs ...",
251                    state, QUERY_POLL_INTERVAL_SECS
252                );
253            }
254            QueryExecutionState::Cancelled | QueryExecutionState::Failed => {
255                error!("State: {:?}", state);
256
257                match get_query_result(&client, query_execution_id.to_string()).await {
258                    Ok(result) => info!("Result: {:?}", result),
259                    Err(e) => error!("Result error: {:?}", e),
260                }
261
262                break;
263            }
264            _ => {
265                info!("State: {:?}", state);
266                if let Some(millis) = total_execution_time(&resp) {
267                    info!("Total execution time: {} millis", millis);
268                }
269
270                match get_query_result(&client, query_execution_id.to_string()).await {
271                    Ok(result) => info!("Result: {:?}", result),
272                    Err(e) => error!("Result error: {:?}", e),
273                }
274
275                break;
276            }
277        }
278    }
279
280    timer.stop();
281    if let Some(secs) = timer.time_in_secs() {
282        info!("Took: {} secs", secs);
283    }
284
285    Ok(state.clone())
286}
287
288fn status(resp: &GetQueryExecutionOutput) -> Option<&QueryExecutionState> {
289    resp.query_execution()
290        .and_then(|qe| qe.status())
291        .and_then(|s| s.state())
292}
293
294fn total_execution_time(resp: &GetQueryExecutionOutput) -> Option<i64> {
295    resp.query_execution()
296        .and_then(|qe| qe.statistics())
297        .and_then(|s| s.total_execution_time_in_millis())
298}
299
300async fn get_query_result(client: &Client, query_execution_id: String) -> Result<ResultSet> {
301    let resp = client
302        .get_query_results()
303        .set_query_execution_id(Some(query_execution_id.clone()))
304        .send()
305        .await
306        .with_context(|| {
307            format!(
308                "could not get query results for query id {}",
309                query_execution_id
310            )
311        })?;
312
313    Ok(resp
314        .result_set()
315        .ok_or_else(|| anyhow!("could not get query result"))?
316        .clone())
317}
318
319fn get_database_from_sql<S: AsRef<str>>(sql: S) -> Option<String> {
320    for r in DATABASE_PATTERNS.iter() {
321        if let Some(caps) = r.captures(sql.as_ref()) {
322            let name = caps.get(1).map_or("", |m| m.as_str());
323            return Some(name.trim().to_string());
324        }
325    }
326
327    None
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(sql, expected)` pair resolves to `expected`.
    fn assert_databases(cases: &[(&str, &str)]) {
        for &(sql, expected) in cases {
            let actual = get_database_from_sql(sql);
            assert_eq!(actual.as_deref(), Some(expected), "sql: {:?}", sql);
        }
    }

    #[test]
    fn test_get_database_from_sql() {
        assert_databases(&[
            ("-- database: db0", "db0"),
            ("-- database: db1\nSELECT * FROM ...;", "db1"),
            ("-- Database: db2\nSELECT * FROM ...;", "db2"),
            ("-- Database: db3 \nSELECT * FROM ...;", "db3"),
            ("-- Database: db4    \nSELECT * FROM ...;", "db4"),
            ("--   Database: db4    \nSELECT * FROM ...;", "db4"),
            ("/* Database: db5 */\nSELECT * FROM ...;", "db5"),
            ("/* database: db6 */\nSELECT * FROM ...;", "db6"),
            ("/*        database: db7 */\nSELECT * FROM ...;", "db7"),
            // With several comments, the first one wins.
            ("-- database: db0 \n-- database: db1", "db0"),
            ("/* database: db0 */\n/* database: db1 */", "db0"),
        ]);

        assert!(get_database_from_sql("SELECT * FROM ...;").is_none());
    }

    #[test]
    fn test_get_database_from_sql_with_comment() {
        assert_databases(&[
            ("-- database: db0\n-- comment\nSELECT * FROM ...;", "db0"),
            ("-- database: db0\n-- comment\n-- comment\nSELECT * FROM ...;", "db0"),
            (
                "-- database: db0\n-- comment\n-- comment\n-- comment\nSELECT * FROM ...;",
                "db0",
            ),
        ]);
    }
}