Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ tilt_modules/

# VSCode
.vscode/
.claude/worktrees/
75 changes: 72 additions & 3 deletions src/aws/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ pub enum AwsAuthentication {
/// [aws_region]: https://docs.aws.amazon.com/general/latest/gr/rande.html#regional-endpoints
#[configurable(metadata(docs::examples = "us-west-2"))]
region: Option<String>,

/// The optional custom endpoint URL for STS `AssumeRole` calls.
///
/// When set, overrides the default STS endpoint (e.g. `sts.amazonaws.com`).
/// Useful for GovCloud, private-link setups, or pointing at a mock STS in tests.
/// When unset, the AWS SDK default is used — no behaviour change for existing configs.
#[configurable(metadata(docs::examples = "http://localhost:4566"))]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not use localhost in docs string

sts_endpoint: Option<String>,
},

/// Authenticate using credentials stored in a file.
Expand Down Expand Up @@ -152,6 +160,14 @@ pub enum AwsAuthentication {
/// [aws_region]: https://docs.aws.amazon.com/general/latest/gr/rande.html#regional-endpoints
#[configurable(metadata(docs::examples = "us-west-2"))]
region: Option<String>,

/// The optional custom endpoint URL for STS `AssumeRole` calls.
///
/// When set, overrides the default STS endpoint (e.g. `sts.amazonaws.com`).
/// Useful for GovCloud, private-link setups, or pointing at a mock STS in tests.
/// When unset, the AWS SDK default is used — no behaviour change for existing configs.
#[configurable(metadata(docs::examples = "http://localhost:4566"))]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

sts_endpoint: Option<String>,
},

/// Default authentication strategy which tries a variety of substrategies in sequential order.
Expand Down Expand Up @@ -216,13 +232,17 @@ impl AwsAuthentication {
region: &Region,
assume_role: &str,
external_id: Option<&str>,
sts_endpoint: Option<&str>,
) -> crate::Result<AssumeRoleProviderBuilder> {
let connector = super::connector(proxy, tls_options)?;
let config = SdkConfig::builder()
let mut config_builder = SdkConfig::builder()
.http_client(connector)
.region(region.clone())
.time_source(SystemTimeSource::new())
.build();
.time_source(SystemTimeSource::new());
if let Some(endpoint) = sts_endpoint {
config_builder = config_builder.endpoint_url(endpoint);
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use [set_endpoint_url](https://docs.rs/aws-sdk-s3/latest/aws_sdk_s3/config/struct.Builder.html#method.set_endpoint_url) here

let config = config_builder.build();

let mut builder = AssumeRoleProviderBuilder::new(assume_role)
.region(region.clone())
Expand All @@ -249,6 +269,7 @@ impl AwsAuthentication {
assume_role,
external_id,
region,
sts_endpoint,
} => {
let provider = SharedCredentialsProvider::new(Credentials::from_keys(
access_key_id.inner(),
Expand All @@ -263,6 +284,7 @@ impl AwsAuthentication {
&auth_region,
assume_role,
external_id.as_deref(),
sts_endpoint.as_deref(),
)?;

let provider = builder.build_from_provider(provider).await;
Expand Down Expand Up @@ -297,6 +319,7 @@ impl AwsAuthentication {
external_id,
imds,
region,
sts_endpoint,
..
} => {
let auth_region = region.clone().map(Region::new).unwrap_or(service_region);
Expand All @@ -306,6 +329,7 @@ impl AwsAuthentication {
&auth_region,
assume_role,
external_id.as_deref(),
sts_endpoint.as_deref(),
)?;

let provider = builder
Expand Down Expand Up @@ -338,6 +362,7 @@ impl AwsAuthentication {
assume_role: None,
external_id: None,
region: None,
sts_endpoint: None,
}
}
}
Expand Down Expand Up @@ -513,6 +538,8 @@ mod tests {
load_timeout_secs,
imds,
region,
sts_endpoint,
..
} => {
assert_eq!(&assume_role, "root");
assert_eq!(external_id, None);
Expand All @@ -526,6 +553,7 @@ mod tests {
}
));
assert_eq!(region, None);
assert_eq!(sts_endpoint, None);
}
_ => panic!(),
}
Expand All @@ -550,6 +578,7 @@ mod tests {
load_timeout_secs,
imds,
region,
..
} => {
assert_eq!(&assume_role, "auth.root");
assert_eq!(external_id, None);
Expand Down Expand Up @@ -674,4 +703,44 @@ mod tests {
_ => panic!(),
}
}

#[test]
fn parsing_role_with_sts_endpoint() {
let config = toml::from_str::<ComponentConfig>(
r#"
auth.assume_role = "arn:aws:iam::123456789098:role/my_role"
auth.sts_endpoint = "http://localhost:4566"
"#,
)
.unwrap();

match config.auth {
AwsAuthentication::Role {
assume_role,
sts_endpoint,
..
} => {
assert_eq!(&assume_role, "arn:aws:iam::123456789098:role/my_role");
assert_eq!(sts_endpoint, Some("http://localhost:4566".to_string()));
}
_ => panic!("expected Role variant"),
}
}

#[test]
fn parsing_role_without_sts_endpoint_defaults_to_none() {
let config = toml::from_str::<ComponentConfig>(
r#"
auth.assume_role = "arn:aws:iam::123456789098:role/my_role"
"#,
)
.unwrap();

match config.auth {
AwsAuthentication::Role { sts_endpoint, .. } => {
assert_eq!(sts_endpoint, None);
}
_ => panic!("expected Role variant"),
}
}
}