diff --git a/Cargo.lock b/Cargo.lock index 06cf5bab..353b0a8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,6 +427,32 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-dsql" +version = "1.55.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f914acf80007b4d0fc1e68d7f8045b39d58367bc3aaa8270c44368e8b8dd3ee1" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-observability", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", + "url", +] + [[package]] name = "aws-sdk-sesv2" version = "1.118.0" @@ -5986,6 +6012,7 @@ dependencies = [ "argon2", "aws-config", "aws-credential-types", + "aws-sdk-dsql", "aws-sdk-sesv2", "aws-smithy-runtime-api", "bigdecimal", diff --git a/Cargo.toml b/Cargo.toml index a4d26875..32a91bad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,8 @@ vendored_openssl = ["openssl/vendored"] # Enable MiMalloc memory allocator to replace the default malloc # This can improve performance for Alpine builds enable_mimalloc = ["dep:mimalloc"] -aws = ["s3", "ses"] +aws = ["dsql", "s3", "ses"] +dsql = ["postgresql", "dep:aws-config", "dep:aws-sdk-dsql", "dep:aws-smithy-runtime-api"] ses = ["dep:aws-config", "dep:aws-sdk-sesv2", "dep:aws-smithy-runtime-api"] s3 = ["opendal/services-s3", "opendal/reqwest-rustls-tls", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:http", "dep:reqsign-aws-v4", "dep:reqsign-core"] @@ -203,6 +204,7 @@ opendal = { version = "0.56.0", features = ["services-fs"], default-features = f # For retrieving AWS credentials, including temporary SSO credentials aws-config = { version = "1.8.16", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true } aws-credential-types = { version = "1.2.14", optional = true } +aws-sdk-dsql = { version = "1.55.0", features = ["behavior-version-latest", "rt-tokio"], default-features = false, optional = true } aws-sdk-sesv2 = { version = "1.118.0", features = ["behavior-version-latest", "rt-tokio"], default-features = false, optional = true } aws-smithy-runtime-api = { version = "1.12.0", optional = true } http = { version = "1.4.0", optional = true } diff --git a/build.rs b/build.rs index 0870134d..c4db7d8a 100644 --- a/build.rs +++ b/build.rs @@ -9,7 +9,9 @@ fn main() { println!("cargo:rustc-cfg=mysql"); #[cfg(feature = "postgresql")] println!("cargo:rustc-cfg=postgresql"); - #[cfg(not(any(feature = "sqlite_system", feature = "mysql", feature = "postgresql")))] + #[cfg(feature = "dsql")] + println!("cargo:rustc-cfg=dsql"); + #[cfg(not(any(feature = "sqlite_system", feature = "mysql", feature = "postgresql", feature = "dsql")))] compile_error!( "You need to enable one DB backend. To build with previous defaults do: cargo build --features sqlite" ); @@ -26,6 +28,7 @@ fn main() { println!("cargo::rustc-check-cfg=cfg(sqlite)"); println!("cargo::rustc-check-cfg=cfg(mysql)"); println!("cargo::rustc-check-cfg=cfg(postgresql)"); + println!("cargo::rustc-check-cfg=cfg(dsql)"); println!("cargo::rustc-check-cfg=cfg(s3)"); println!("cargo::rustc-check-cfg=cfg(ses)"); println!("cargo::rustc-check-cfg=cfg(aws)"); diff --git a/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml b/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml new file mode 100644 index 00000000..16153bc0 --- /dev/null +++ b/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false \ No newline at end of file diff --git a/migrations/dsql/2024-12-30-100000_create_tables/up.sql b/migrations/dsql/2024-12-30-100000_create_tables/up.sql new file mode 100644 index 00000000..3eefed1e --- /dev/null +++ b/migrations/dsql/2024-12-30-100000_create_tables/up.sql @@ -0,0 +1,281 @@ +CREATE TABLE attachments ( + id text NOT NULL PRIMARY KEY, + cipher_uuid character varying(40) NOT NULL, + file_name text NOT NULL, + file_size bigint NOT NULL, + akey text +); + +CREATE TABLE auth_requests ( + uuid character(36) NOT NULL PRIMARY KEY, + user_uuid character(36) NOT NULL, + organization_uuid character(36), + request_device_identifier character(36) NOT NULL, + device_type integer NOT NULL, + request_ip text NOT NULL, + response_device_id character(36), + access_code text NOT NULL, + public_key text NOT NULL, + enc_key text, + master_password_hash text, + approved boolean, + creation_date timestamp without time zone NOT NULL, + response_date timestamp without time zone, + authentication_date timestamp without time zone +); + +CREATE TABLE ciphers ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40), + organization_uuid character varying(40), + atype integer NOT NULL, + name text NOT NULL, + notes text, + fields text, + data text NOT NULL, + password_history text, + deleted_at timestamp without time zone, + reprompt integer, + key text +); + +CREATE TABLE ciphers_collections ( + cipher_uuid character varying(40) NOT NULL, + collection_uuid character varying(40) NOT NULL, + PRIMARY KEY (cipher_uuid, collection_uuid) +); + +CREATE TABLE collections ( + uuid character varying(40) NOT NULL PRIMARY KEY, + org_uuid character varying(40) NOT NULL, + name text NOT NULL, + external_id text +); + +CREATE TABLE collections_groups ( + collections_uuid character varying(40) NOT NULL, + groups_uuid character(36) NOT NULL, + read_only boolean NOT NULL, + hide_passwords boolean NOT NULL, + PRIMARY KEY (collections_uuid, groups_uuid) +); + +CREATE TABLE devices ( + uuid character varying(40) NOT NULL, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40) NOT NULL, + name text NOT NULL, + atype integer NOT NULL, + push_token text, + refresh_token text NOT NULL, + twofactor_remember text, + push_uuid text, + PRIMARY KEY (uuid, user_uuid) +); + +CREATE TABLE emergency_access ( + uuid character(36) NOT NULL PRIMARY KEY, + grantor_uuid character(36), + grantee_uuid character(36), + email character varying(255), + key_encrypted text, + atype integer NOT NULL, + status integer NOT NULL, + wait_time_days integer NOT NULL, + recovery_initiated_at timestamp without time zone, + last_notification_at timestamp without time zone, + updated_at timestamp without time zone NOT NULL, + created_at timestamp without time zone NOT NULL +); + +CREATE TABLE event ( + uuid character(36) NOT NULL PRIMARY KEY, + event_type integer NOT NULL, + user_uuid character(36), + org_uuid character(36), + cipher_uuid character(36), + collection_uuid character(36), + group_uuid character(36), + org_user_uuid character(36), + act_user_uuid character(36), + device_type integer, + ip_address text, + event_date timestamp without time zone NOT NULL, + policy_uuid character(36), + provider_uuid character(36), + provider_user_uuid character(36), + provider_org_uuid character(36) +); + +CREATE TABLE favorites ( + user_uuid character varying(40) NOT NULL, + cipher_uuid character varying(40) NOT NULL, + PRIMARY KEY (user_uuid, cipher_uuid) +); + +CREATE TABLE folders ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40) NOT NULL, + name text NOT NULL +); + +CREATE TABLE folders_ciphers ( + cipher_uuid character varying(40) NOT NULL, + folder_uuid character varying(40) NOT NULL, + PRIMARY KEY (cipher_uuid, folder_uuid) +); + +CREATE TABLE groups ( + uuid character(36) NOT NULL PRIMARY KEY, + organizations_uuid character varying(40) NOT NULL, + name character varying(100) NOT NULL, + access_all boolean NOT NULL, + external_id character varying(300), + creation_date timestamp without time zone NOT NULL, + revision_date timestamp without time zone NOT NULL +); + +CREATE TABLE groups_users ( + groups_uuid character(36) NOT NULL, + users_organizations_uuid character varying(36) NOT NULL, + PRIMARY KEY (groups_uuid, users_organizations_uuid) +); + +CREATE TABLE invitations ( + email text NOT NULL PRIMARY KEY +); + +CREATE TABLE org_policies ( + uuid character(36) NOT NULL PRIMARY KEY, + org_uuid character(36) NOT NULL, + atype integer NOT NULL, + enabled boolean NOT NULL, + data text NOT NULL, + UNIQUE (org_uuid, atype) +); + +CREATE TABLE organization_api_key ( + uuid character(36) NOT NULL, + org_uuid character(36) NOT NULL, + atype integer NOT NULL, + api_key character varying(255), + revision_date timestamp without time zone NOT NULL, + PRIMARY KEY (uuid, org_uuid) +); + +CREATE TABLE organizations ( + uuid character varying(40) NOT NULL PRIMARY KEY, + name text NOT NULL, + billing_email text NOT NULL, + private_key text, + public_key text +); + +CREATE TABLE sends ( + uuid character(36) NOT NULL PRIMARY KEY, + user_uuid character(36), + organization_uuid character(36), + name text NOT NULL, + notes text, + atype integer NOT NULL, + data text NOT NULL, + akey text NOT NULL, + password_hash bytea, + password_salt bytea, + password_iter integer, + max_access_count integer, + access_count integer NOT NULL, + creation_date timestamp without time zone NOT NULL, + revision_date timestamp without time zone NOT NULL, + expiration_date timestamp without time zone, + deletion_date timestamp without time zone NOT NULL, + disabled boolean NOT NULL, + hide_email boolean +); + +CREATE TABLE twofactor ( + uuid character varying(40) NOT NULL PRIMARY KEY, + user_uuid character varying(40) NOT NULL, + atype integer NOT NULL, + enabled boolean NOT NULL, + data text NOT NULL, + last_used bigint DEFAULT 0 NOT NULL, + UNIQUE (user_uuid, atype) +); + +CREATE TABLE twofactor_duo_ctx ( + state character varying(64) NOT NULL PRIMARY KEY, + user_email character varying(255) NOT NULL, + nonce character varying(64) NOT NULL, + exp bigint NOT NULL +); + +CREATE TABLE twofactor_incomplete ( + user_uuid character varying(40) NOT NULL, + device_uuid character varying(40) NOT NULL, + device_name text NOT NULL, + login_time timestamp without time zone NOT NULL, + ip_address text NOT NULL, + device_type integer DEFAULT 14 NOT NULL, + PRIMARY KEY (user_uuid, device_uuid) +); + +CREATE TABLE users ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + email text NOT NULL UNIQUE, + name text NOT NULL, + password_hash bytea NOT NULL, + salt bytea NOT NULL, + password_iterations integer NOT NULL, + password_hint text, + akey text NOT NULL, + private_key text, + public_key text, + totp_secret text, + totp_recover text, + security_stamp text NOT NULL, + equivalent_domains text NOT NULL, + excluded_globals text NOT NULL, + client_kdf_type integer DEFAULT 0 NOT NULL, + client_kdf_iter integer DEFAULT 100000 NOT NULL, + verified_at timestamp without time zone, + last_verifying_at timestamp without time zone, + login_verify_count integer DEFAULT 0 NOT NULL, + email_new character varying(255) DEFAULT NULL::character varying, + email_new_token character varying(16) DEFAULT NULL::character varying, + enabled boolean DEFAULT true NOT NULL, + stamp_exception text, + api_key text, + avatar_color text, + client_kdf_memory integer, + client_kdf_parallelism integer, + external_id text +); + +CREATE TABLE users_collections ( + user_uuid character varying(40) NOT NULL, + collection_uuid character varying(40) NOT NULL, + read_only boolean DEFAULT false NOT NULL, + hide_passwords boolean DEFAULT false NOT NULL, + PRIMARY KEY (user_uuid, collection_uuid) +); + +CREATE TABLE users_organizations ( + uuid character varying(40) NOT NULL PRIMARY KEY, + user_uuid character varying(40) NOT NULL, + org_uuid character varying(40) NOT NULL, + access_all boolean NOT NULL, + akey text NOT NULL, + status integer NOT NULL, + atype integer NOT NULL, + reset_password_key text, + external_id text, + UNIQUE (user_uuid, org_uuid) +); \ No newline at end of file diff --git a/migrations/dsql/2025-01-09-172300_add_manage/down.sql b/migrations/dsql/2025-01-09-172300_add_manage/down.sql new file mode 100644 index 00000000..e69de29b diff --git a/migrations/dsql/2025-01-09-172300_add_manage/metadata.toml b/migrations/dsql/2025-01-09-172300_add_manage/metadata.toml new file mode 100644 index 00000000..16153bc0 --- /dev/null +++ b/migrations/dsql/2025-01-09-172300_add_manage/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false \ No newline at end of file diff --git a/migrations/dsql/2025-01-09-172300_add_manage/up.sql b/migrations/dsql/2025-01-09-172300_add_manage/up.sql new file mode 100644 index 00000000..3565446c --- /dev/null +++ b/migrations/dsql/2025-01-09-172300_add_manage/up.sql @@ -0,0 +1,8 @@ +-- DSQL preview can't add columns with constraints, dropping `NOT NULL DEFAULT FALSE` constraint +-- It appears Diesel will ensure the column has appropriate values when saving records. + +ALTER TABLE users_collections +ADD COLUMN manage BOOLEAN; + +ALTER TABLE collections_groups +ADD COLUMN manage BOOLEAN; diff --git a/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/down.sql b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/down.sql new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/down.sql @@ -0,0 +1 @@ + diff --git a/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/metadata.toml b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/metadata.toml new file mode 100644 index 00000000..79e9221c --- /dev/null +++ b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false diff --git a/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/up.sql b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/up.sql new file mode 100644 index 00000000..eb70d7f1 --- /dev/null +++ b/migrations/dsql/2025-08-20-120000_add_users_organizations_invited_by_email/up.sql @@ -0,0 +1,2 @@ +ALTER TABLE users_organizations +ADD COLUMN invited_by_email TEXT; diff --git a/migrations/dsql/2025-08-20-120100_add_sso_users/down.sql b/migrations/dsql/2025-08-20-120100_add_sso_users/down.sql new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/migrations/dsql/2025-08-20-120100_add_sso_users/down.sql @@ -0,0 +1 @@ + diff --git a/migrations/dsql/2025-08-20-120100_add_sso_users/metadata.toml b/migrations/dsql/2025-08-20-120100_add_sso_users/metadata.toml new file mode 100644 index 00000000..79e9221c --- /dev/null +++ b/migrations/dsql/2025-08-20-120100_add_sso_users/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false diff --git a/migrations/dsql/2025-08-20-120100_add_sso_users/up.sql b/migrations/dsql/2025-08-20-120100_add_sso_users/up.sql new file mode 100644 index 00000000..cd541660 --- /dev/null +++ b/migrations/dsql/2025-08-20-120100_add_sso_users/up.sql @@ -0,0 +1,4 @@ +CREATE TABLE sso_users ( + user_uuid character(36) NOT NULL PRIMARY KEY, + identifier text NOT NULL UNIQUE +); diff --git a/migrations/dsql/2025-08-20-120200_add_sso_auth/down.sql b/migrations/dsql/2025-08-20-120200_add_sso_auth/down.sql new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/migrations/dsql/2025-08-20-120200_add_sso_auth/down.sql @@ -0,0 +1 @@ + diff --git a/migrations/dsql/2025-08-20-120200_add_sso_auth/metadata.toml b/migrations/dsql/2025-08-20-120200_add_sso_auth/metadata.toml new file mode 100644 index 00000000..79e9221c --- /dev/null +++ b/migrations/dsql/2025-08-20-120200_add_sso_auth/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false diff --git a/migrations/dsql/2025-08-20-120200_add_sso_auth/up.sql b/migrations/dsql/2025-08-20-120200_add_sso_auth/up.sql new file mode 100644 index 00000000..220e0cf0 --- /dev/null +++ b/migrations/dsql/2025-08-20-120200_add_sso_auth/up.sql @@ -0,0 +1,10 @@ +CREATE TABLE sso_auth ( + state text NOT NULL PRIMARY KEY, + client_challenge text NOT NULL, + nonce text NOT NULL, + redirect_uri text NOT NULL, + code_response text, + auth_response text, + created_at timestamp without time zone NOT NULL DEFAULT now(), + updated_at timestamp without time zone NOT NULL DEFAULT now() +); diff --git a/migrations/dsql/2026-03-09-005927_add_archives/down.sql b/migrations/dsql/2026-03-09-005927_add_archives/down.sql new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/migrations/dsql/2026-03-09-005927_add_archives/down.sql @@ -0,0 +1 @@ + diff --git a/migrations/dsql/2026-03-09-005927_add_archives/metadata.toml b/migrations/dsql/2026-03-09-005927_add_archives/metadata.toml new file mode 100644 index 00000000..79e9221c --- /dev/null +++ b/migrations/dsql/2026-03-09-005927_add_archives/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false diff --git a/migrations/dsql/2026-03-09-005927_add_archives/up.sql b/migrations/dsql/2026-03-09-005927_add_archives/up.sql new file mode 100644 index 00000000..9dd95083 --- /dev/null +++ b/migrations/dsql/2026-03-09-005927_add_archives/up.sql @@ -0,0 +1,6 @@ +CREATE TABLE archives ( + user_uuid character(36) NOT NULL, + cipher_uuid character(36) NOT NULL, + archived_at timestamp without time zone NOT NULL DEFAULT now(), + PRIMARY KEY (user_uuid, cipher_uuid) +); diff --git a/migrations/dsql/2026-04-25-120000_sso_auth_binding/down.sql b/migrations/dsql/2026-04-25-120000_sso_auth_binding/down.sql new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/migrations/dsql/2026-04-25-120000_sso_auth_binding/down.sql @@ -0,0 +1 @@ + diff --git a/migrations/dsql/2026-04-25-120000_sso_auth_binding/metadata.toml b/migrations/dsql/2026-04-25-120000_sso_auth_binding/metadata.toml new file mode 100644 index 00000000..79e9221c --- /dev/null +++ b/migrations/dsql/2026-04-25-120000_sso_auth_binding/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false diff --git a/migrations/dsql/2026-04-25-120000_sso_auth_binding/up.sql b/migrations/dsql/2026-04-25-120000_sso_auth_binding/up.sql new file mode 100644 index 00000000..5272a7ca --- /dev/null +++ b/migrations/dsql/2026-04-25-120000_sso_auth_binding/up.sql @@ -0,0 +1,2 @@ +ALTER TABLE sso_auth +ADD COLUMN binding_hash TEXT; diff --git a/src/api/admin.rs b/src/api/admin.rs index 02c976cc..5e4f8247 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -83,6 +83,8 @@ pub fn catchers() -> Vec { } static DB_TYPE: LazyLock<&str> = LazyLock::new(|| match ACTIVE_DB_TYPE.get() { + #[cfg(dsql)] + Some(DbConnType::Dsql) => "Aurora DSQL", #[cfg(mysql)] Some(DbConnType::Mysql) => "MySQL", #[cfg(postgresql)] diff --git a/src/aws.rs b/src/aws.rs index 0a4f7dff..fabf8dbc 100644 --- a/src/aws.rs +++ b/src/aws.rs @@ -1,3 +1,6 @@ +#[cfg(dsql)] +use std::io::Error; + use aws_config::{AppName, BehaviorVersion}; use tokio::sync::OnceCell; @@ -24,3 +27,13 @@ pub(crate) async fn aws_sdk_config() -> &'static aws_config::SdkConfig { }) .await } + +#[cfg(dsql)] +pub(crate) fn aws_sdk_config_blocking() -> std::io::Result<&'static aws_config::SdkConfig> { + std::thread::spawn(|| { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?; + std::io::Result::Ok(rt.block_on(aws_sdk_config())) + }) + .join() + .map_err(|e| Error::other(format!("Failed to load AWS SDK config: {e:?}")))? +} diff --git a/src/db/dsql.rs b/src/db/dsql.rs new file mode 100644 index 00000000..7d3172a3 --- /dev/null +++ b/src/db/dsql.rs @@ -0,0 +1,110 @@ +use std::{ + collections::HashMap, + sync::{Arc, LazyLock, Mutex}, + time::Duration, +}; + +use diesel::ConnectionError; +use url::Url; + +// Generate a Postgres libpq connection string. The input connection string has +// the following format: +// +// dsql://.dsql..on.aws +// +// The generated connection string has the form: +// +// postgresql://.dsql..on.aws/postgres?sslmode=require&user=admin&password= +// +// The auth token is generated by the AWS SDK for DSQL and is valid for up to +// 15 minutes. Cache each unique DSQL URL for 14 minutes to avoid regenerating +// a token for every pooled connection. +pub(crate) fn psql_url(url: &str) -> Result { + struct PsqlUrl { + timestamp: std::time::Instant, + url: String, + } + + static PSQL_URLS: LazyLock>>>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + + let mut psql_urls = + PSQL_URLS.lock().map_err(|e| ConnectionError::BadConnection(format!("Failed to lock DSQL URLs: {e}")))?; + + let psql_url_lock = if let Some(existing_psql_url_lock) = psql_urls.get(url) { + existing_psql_url_lock.clone() + } else { + let psql_url_lock = Arc::new(Mutex::new(None)); + psql_urls.insert(url.to_string(), psql_url_lock.clone()); + psql_url_lock + }; + + let mut psql_url_lock_guard = + psql_url_lock.lock().map_err(|e| ConnectionError::BadConnection(format!("Failed to lock DSQL URL: {e}")))?; + + drop(psql_urls); + + if let Some(ref psql_url) = *psql_url_lock_guard { + if psql_url.timestamp.elapsed() < Duration::from_secs(14 * 60) { + debug!("Reusing DSQL auth token for connection '{url}'"); + return Ok(psql_url.url.clone()); + } + + info!("Refreshing DSQL auth token for connection '{url}'"); + } else { + info!("Generating new DSQL auth token for connection '{url}'"); + } + + let mut psql_url = Url::parse(url).map_err(|e| ConnectionError::InvalidConnectionUrl(e.to_string()))?; + + let host = psql_url + .host_str() + .ok_or_else(|| ConnectionError::InvalidConnectionUrl("Missing hostname in DSQL URL".to_string()))? + .to_string(); + + static DSQL_REGION_FROM_HOST_RE: LazyLock = LazyLock::new(|| { + regex::Regex::new(r"^[a-z0-9]+\.dsql\.(?P[a-z0-9-]+)\.on\.aws$") + .expect("Failed to compile DSQL region regex") + }); + + let region = DSQL_REGION_FROM_HOST_RE + .captures(&host) + .and_then(|captures| captures.name("region")) + .ok_or_else(|| ConnectionError::InvalidConnectionUrl("Failed to find AWS region in DSQL hostname".to_string()))? + .as_str() + .to_string(); + + let auth_config = aws_sdk_dsql::auth_token::Config::builder() + .hostname(host) + .region(aws_config::Region::new(region)) + .build() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to build DSQL auth token signer config: {e}")))?; + + let signer = aws_sdk_dsql::auth_token::AuthTokenGenerator::new(auth_config); + let sdk_config = crate::aws::aws_sdk_config_blocking() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to load AWS SDK config: {e}")))?; + let now = std::time::Instant::now(); + + let auth_token = std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?; + rt.block_on(signer.db_connect_admin_auth_token(sdk_config)) + }) + .join() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to generate DSQL auth token: {e:?}")))? + .map_err(|e| ConnectionError::BadConnection(format!("Failed to generate DSQL auth token: {e}")))?; + + psql_url.set_scheme("postgresql").expect("Failed to set 'postgresql' as scheme for DSQL connection URL"); + psql_url.set_path("postgres"); + psql_url + .query_pairs_mut() + .append_pair("sslmode", "require") + .append_pair("user", "admin") + .append_pair("password", auth_token.as_str()); + + psql_url_lock_guard.replace(PsqlUrl { + timestamp: now, + url: psql_url.to_string(), + }); + + Ok(psql_url.to_string()) +} diff --git a/src/db/mod.rs b/src/db/mod.rs index d2ed9479..dc577386 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,3 +1,5 @@ +#[cfg(dsql)] +mod dsql; mod query_logger; use std::{ @@ -68,6 +70,12 @@ impl DbConnManager { fn establish_connection(&self) -> Result { match DbConnType::from_url(&self.database_url) { + #[cfg(dsql)] + Ok(DbConnType::Dsql) => { + let db_url = dsql::psql_url(&self.database_url).map_err(diesel::r2d2::Error::ConnectionError)?; + let conn = diesel::pg::PgConnection::establish(&db_url)?; + Ok(DbConnInner::Postgresql(conn)) + } #[cfg(mysql)] Ok(DbConnType::Mysql) => { let conn = diesel::mysql::MysqlConnection::establish(&self.database_url)?; @@ -110,8 +118,10 @@ impl diesel::r2d2::ManageConnection for DbConnManager { } } -#[derive(Eq, PartialEq)] +#[derive(Clone, Copy, Eq, PartialEq)] pub enum DbConnType { + #[cfg(dsql)] + Dsql, #[cfg(mysql)] Mysql, #[cfg(postgresql)] @@ -195,6 +205,10 @@ impl DbPool { } match conn_type { + #[cfg(dsql)] + DbConnType::Dsql => { + dsql_migrations::run_migrations(&db_url)?; + } #[cfg(mysql)] DbConnType::Mysql => { mysql_migrations::run_migrations(&db_url)?; @@ -272,6 +286,14 @@ impl DbConnType { #[cfg(not(postgresql))] err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled") + // Amazon Aurora DSQL + } else if url.starts_with("dsql:") { + #[cfg(dsql)] + return Ok(DbConnType::Dsql); + + #[cfg(not(dsql))] + err!("`DATABASE_URL` is a DSQL URL, but the 'dsql' feature is not enabled") + //Sqlite } else { #[cfg(sqlite)] @@ -293,6 +315,8 @@ impl DbConnType { pub fn default_init_stmts(&self) -> String { match self { + #[cfg(dsql)] + Self::Dsql => String::new(), #[cfg(mysql)] Self::Mysql => String::new(), #[cfg(postgresql)] @@ -517,3 +541,19 @@ mod postgresql_migrations { Ok(()) } } + +#[cfg(dsql)] +mod dsql_migrations { + use diesel::Connection; + use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; + + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/dsql"); + + pub fn run_migrations(db_url: &str) -> Result<(), super::Error> { + let db_url = super::dsql::psql_url(db_url)?; + let mut connection = diesel::pg::PgConnection::establish(&db_url)?; + + connection.run_pending_migrations(MIGRATIONS).expect("Error running DSQL migrations"); + Ok(()) + } +} diff --git a/src/http_client.rs b/src/http_client.rs index f5f5dff2..0097384d 100644 --- a/src/http_client.rs +++ b/src/http_client.rs @@ -295,7 +295,7 @@ impl Resolve for CustomDnsResolver { } } -#[cfg(any(s3, ses))] +#[cfg(any(dsql, s3, ses))] pub(crate) mod aws { use aws_smithy_runtime_api::client::{ http::{HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector}, diff --git a/src/main.rs b/src/main.rs index 67c59d43..beefcae5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,7 +48,7 @@ use tokio::signal::unix::SignalKind; mod error; mod api; mod auth; -#[cfg(any(s3, ses))] +#[cfg(any(dsql, s3, ses))] mod aws; mod config; mod crypto;