1use anyhow::{anyhow, bail, Context, Result};
27use aws_sdk_athena::{
28 operation::get_query_execution::GetQueryExecutionOutput,
29 types::{QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultSet},
30 Client,
31};
32use devtimer::DevTime;
33use log::{error, info};
34use once_cell::sync::Lazy;
35use regex::Regex;
36use std::{collections::HashMap, env, path::PathBuf};
37use tokio::time::{sleep, Duration};
38
39use crate::utils::pretty_print;
40
/// Number of seconds to wait between two polls of a query's execution state.
const QUERY_POLL_INTERVAL_SECS: u64 = 5;
/// Character on which the built SQL is split into individual statements.
const SQL_STATEMENT_SEPARATOR: char = ';';
44
45static DATABASE_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
47 vec![
48 Regex::new(r"(?i)--\s+Database:\s(.*)").expect("invalid regex pattern"),
50 Regex::new(r"(?i)/*\s+Database:\s([^\s]+)\s\*/").expect("invalid regex pattern"),
52 ]
53});
54
// Command-line arguments consumed by `call`: build the SQL for `file` and submit it to Athena.
//
// The notes below are plain `//` comments on purpose: clap's derive turns `///` doc comments
// into `--help` text, so using them here would change the program's output.
#[derive(clap::Args, Debug, Clone)]
pub struct Apply {
    // Input path handed to `crate::build::build` to produce the SQL.
    pub file: PathBuf,

    // Optional path forwarded as `context` to the build step.
    #[arg(long, short)]
    pub context: Option<PathBuf>,

    // NOTE(review): this flag is not read anywhere in this file — confirm whether dry-run
    // is handled elsewhere or still needs to be implemented.
    #[arg(global = true, long, short)]
    pub dry_run: Option<bool>,

    // When set, exported as the `AWS_PROFILE` environment variable before the AWS
    // configuration is loaded.
    #[arg(global = true, long, short)]
    pub profile: Option<String>,

    // When set, exported as the `AWS_DEFAULT_REGION` environment variable before the AWS
    // configuration is loaded.
    #[arg(global = true, long, short)]
    pub region: Option<String>,

    // Athena workgroup passed to every query submission.
    #[arg(global = true, long, short)]
    pub workgroup: Option<String>,

    // Query result output location; when absent, the `AWS_OUTPUT_LOCATION` environment
    // variable is used instead.
    #[arg(global = true, long, short)]
    pub output_location: Option<String>,

    // When true, SQL is echoed with `print!` instead of `pretty_print`.
    #[arg(long)]
    pub no_pretty: Option<bool>,
}
95
96pub async fn call(args: Apply) -> Result<()> {
97 let build_args = crate::build::Build {
98 file: args.file.clone(),
99 out: None,
100 context: args.context.clone(),
101 no_pretty: None,
102 };
103
104 let sql = crate::build::build(build_args)?;
105 if args.no_pretty.unwrap_or_default() {
106 print!("{}", sql);
107 } else {
108 pretty_print(sql.as_bytes());
109 }
110
111 if let Some(ref profile) = args.profile {
113 std::env::set_var("AWS_PROFILE", profile);
114 }
115
116 if let Some(ref region) = args.region {
118 std::env::set_var("AWS_DEFAULT_REGION", region);
119 }
120
121 let shared_config = aws_config::load_from_env().await;
122 let client = Client::new(&shared_config);
123
124 submit_and_wait(client.clone(), Some("SELECT 1".to_string()), args.clone()).await?;
126
127 let sql = sql
129 .split(SQL_STATEMENT_SEPARATOR)
130 .into_iter()
131 .map(|s| s.trim())
132 .filter(|s| !s.is_empty())
133 .collect::<Vec<_>>();
134 info!("Submitting {} queries to Athena", sql.len());
135
136 let mut stats: HashMap<QueryExecutionState, i32> = HashMap::new();
137
138 let mut timer = DevTime::new_simple();
140 timer.start();
141
142 for s in sql {
143 let state = submit_and_wait(client.clone(), Some(s.to_string()), args.clone()).await?;
144
145 stats
147 .entry(state.clone())
148 .and_modify(|c| *c += 1)
149 .or_insert(0);
150 }
151
152 timer.stop();
153
154 info!("");
155 info!("Statistics:");
156 info!(" ==> {:?}", stats);
157 if let Some(secs) = timer.time_in_secs() {
158 info!(" ==> Took: {:?} seconds", secs);
159 }
160
161 Ok(())
162}
163
164fn get_result_configuration(args: Apply) -> ResultConfiguration {
165 let output_location = args
166 .output_location
167 .or_else(|| env::var("AWS_OUTPUT_LOCATION").ok());
168
169 ResultConfiguration::builder()
170 .set_output_location(output_location)
171 .build()
172}
173
174fn get_query_execution_context(query: Option<String>) -> Option<QueryExecutionContext> {
175 let query = query.as_ref()?;
176
177 let database = get_database_from_sql(query);
178 database.as_ref()?;
179
180 let ctx = QueryExecutionContext::builder()
181 .set_database(database)
182 .build();
183
184 Some(ctx)
185}
186
187async fn submit_and_wait(
188 client: Client,
189 query: Option<String>,
190 args: Apply,
191) -> Result<QueryExecutionState> {
192 if query.clone().is_none() {
193 bail!("Empty query");
194 }
195
196 let mut timer = DevTime::new_simple();
198 timer.start();
199
200 let workgroup = args.workgroup.clone();
201 let result_configuration = get_result_configuration(args.clone());
202 let query_execution_context = get_query_execution_context(query.clone());
203 let query = query.unwrap();
204
205 match &query_execution_context {
206 Some(ctx) => match ctx.database() {
207 Some(database) => info!("\nSubmitting to database `{}`: ", database),
208 _ => info!("\nSubmitting ..."),
209 },
210 _ => info!("\nSubmitting ..."),
211 }
212
213 if args.no_pretty.unwrap_or_default() {
214 print!("{}", query);
215 } else {
216 pretty_print(query.as_bytes());
217 }
218
219 let resp = client
220 .start_query_execution()
221 .set_query_string(Some(query))
222 .set_work_group(workgroup)
223 .set_result_configuration(Some(result_configuration.clone()))
224 .set_query_execution_context(query_execution_context)
225 .send()
226 .await?;
227
228 let query_execution_id = resp
229 .query_execution_id()
230 .ok_or_else(|| anyhow!("query execution id not found in response"))?;
231 info!("Query execution id: {}", &query_execution_id);
232
233 let mut state: QueryExecutionState;
234
235 loop {
236 let resp = client
237 .get_query_execution()
238 .set_query_execution_id(Some(query_execution_id.to_string()))
239 .send()
240 .await?;
241
242 state = status(&resp)
243 .ok_or_else(|| anyhow!("could not get query execution status from response"))?
244 .clone();
245
246 match state {
247 QueryExecutionState::Queued | QueryExecutionState::Running => {
248 sleep(Duration::from_secs(QUERY_POLL_INTERVAL_SECS)).await;
249 info!(
250 "State: {:?}, sleeping {} secs ...",
251 state, QUERY_POLL_INTERVAL_SECS
252 );
253 }
254 QueryExecutionState::Cancelled | QueryExecutionState::Failed => {
255 error!("State: {:?}", state);
256
257 match get_query_result(&client, query_execution_id.to_string()).await {
258 Ok(result) => info!("Result: {:?}", result),
259 Err(e) => error!("Result error: {:?}", e),
260 }
261
262 break;
263 }
264 _ => {
265 info!("State: {:?}", state);
266 if let Some(millis) = total_execution_time(&resp) {
267 info!("Total execution time: {} millis", millis);
268 }
269
270 match get_query_result(&client, query_execution_id.to_string()).await {
271 Ok(result) => info!("Result: {:?}", result),
272 Err(e) => error!("Result error: {:?}", e),
273 }
274
275 break;
276 }
277 }
278 }
279
280 timer.stop();
281 if let Some(secs) = timer.time_in_secs() {
282 info!("Took: {} secs", secs);
283 }
284
285 Ok(state.clone())
286}
287
288fn status(resp: &GetQueryExecutionOutput) -> Option<&QueryExecutionState> {
289 resp.query_execution()
290 .and_then(|qe| qe.status())
291 .and_then(|s| s.state())
292}
293
294fn total_execution_time(resp: &GetQueryExecutionOutput) -> Option<i64> {
295 resp.query_execution()
296 .and_then(|qe| qe.statistics())
297 .and_then(|s| s.total_execution_time_in_millis())
298}
299
300async fn get_query_result(client: &Client, query_execution_id: String) -> Result<ResultSet> {
301 let resp = client
302 .get_query_results()
303 .set_query_execution_id(Some(query_execution_id.clone()))
304 .send()
305 .await
306 .with_context(|| {
307 format!(
308 "could not get query results for query id {}",
309 query_execution_id
310 )
311 })?;
312
313 Ok(resp
314 .result_set()
315 .ok_or_else(|| anyhow!("could not get query result"))?
316 .clone())
317}
318
319fn get_database_from_sql<S: AsRef<str>>(sql: S) -> Option<String> {
320 for r in DATABASE_PATTERNS.iter() {
321 if let Some(caps) = r.captures(sql.as_ref()) {
322 let name = caps.get(1).map_or("", |m| m.as_str());
323 return Some(name.trim().to_string());
324 }
325 }
326
327 None
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each `(sql, expected)` pair resolves to the expected database name.
    fn assert_databases(cases: &[(&str, &str)]) {
        for (sql, expected) in cases {
            assert_eq!(
                get_database_from_sql(*sql).as_deref(),
                Some(*expected),
                "sql: {:?}",
                sql
            );
        }
    }

    #[test]
    fn test_get_database_from_sql() {
        // Line comments and block comments, in various spellings.
        assert_databases(&[
            ("-- database: db0", "db0"),
            ("-- database: db1\nSELECT * FROM ...;", "db1"),
            ("-- Database: db2\nSELECT * FROM ...;", "db2"),
            ("-- Database: db3 \nSELECT * FROM ...;", "db3"),
            ("-- Database: db4 \nSELECT * FROM ...;", "db4"),
            ("-- Database: db4 \nSELECT * FROM ...;", "db4"),
            ("/* Database: db5 */\nSELECT * FROM ...;", "db5"),
            ("/* database: db6 */\nSELECT * FROM ...;", "db6"),
            ("/* database: db7 */\nSELECT * FROM ...;", "db7"),
        ]);

        // No database comment at all.
        assert!(get_database_from_sql("SELECT * FROM ...;").is_none());

        // With several database comments, the first one wins.
        assert_databases(&[
            ("-- database: db0 \n-- database: db1", "db0"),
            ("/* database: db0 */\n/* database: db1 */", "db0"),
        ]);
    }

    #[test]
    fn test_get_database_from_sql_with_comment() {
        // Ordinary comments after the database comment do not affect the result.
        assert_databases(&[
            ("-- database: db0\n-- comment\nSELECT * FROM ...;", "db0"),
            (
                "-- database: db0\n-- comment\n-- comment\nSELECT * FROM ...;",
                "db0",
            ),
            (
                "-- database: db0\n-- comment\n-- comment\n-- comment\nSELECT * FROM ...;",
                "db0",
            ),
        ]);
    }
}