Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 37 additions & 33 deletions conformance/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use rmcp::{
model::*,
service::RequestContext,
transport::{
AuthClient, AuthorizationManager, StreamableHttpClientTransport, auth::OAuthState,
AuthClient, AuthorizationManager, StreamableHttpClientTransport,
auth::{AuthorizationCallback, OAuthState},
streamable_http_client::StreamableHttpClientTransportConfig,
},
};
Expand Down Expand Up @@ -221,20 +222,16 @@ async fn perform_oauth_flow(
.and_then(|v| v.to_str().ok())
.ok_or_else(|| anyhow::anyhow!("No Location header in auth redirect"))?;

let redirect_url = url::Url::parse(location)?;
let code = redirect_url
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("No code in redirect URL"))?;
let state = redirect_url
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("No state in redirect URL"))?;
let callback = AuthorizationCallback::from_redirect_url(location)?;

tracing::debug!("Got auth code, exchanging for token...");
oauth.handle_callback(&code, &state).await?;
oauth
.handle_callback_with_issuer(
&callback.code,
&callback.csrf_token,
callback.issuer.as_deref(),
)
.await?;

let am = oauth
.into_authorization_manager()
Expand Down Expand Up @@ -334,8 +331,14 @@ async fn run_auth_scope_step_up_client(
.await?;

let auth_url = oauth.get_authorization_url().await?;
let (code, state) = headless_authorize(&auth_url).await?;
oauth.handle_callback(&code, &state).await?;
let callback = headless_authorize(&auth_url).await?;
oauth
.handle_callback_with_issuer(
&callback.code,
&callback.csrf_token,
callback.issuer.as_deref(),
)
.await?;

let am = oauth
.into_authorization_manager()
Expand Down Expand Up @@ -380,8 +383,14 @@ async fn run_auth_scope_step_up_client(
)
.await?;
let auth_url2 = oauth2.get_authorization_url().await?;
let (code2, state2) = headless_authorize(&auth_url2).await?;
oauth2.handle_callback(&code2, &state2).await?;
let callback2 = headless_authorize(&auth_url2).await?;
oauth2
.handle_callback_with_issuer(
&callback2.code,
&callback2.csrf_token,
callback2.issuer.as_deref(),
)
.await?;

let am2 = oauth2.into_authorization_manager().unwrap();
let auth_client2 = AuthClient::new(reqwest::Client::default(), am2);
Expand Down Expand Up @@ -422,8 +431,14 @@ async fn run_auth_scope_retry_limit_client(
)
.await?;
let auth_url = oauth.get_authorization_url().await?;
let (code, state) = headless_authorize(&auth_url).await?;
oauth.handle_callback(&code, &state).await?;
let callback = headless_authorize(&auth_url).await?;
oauth
.handle_callback_with_issuer(
&callback.code,
&callback.csrf_token,
callback.issuer.as_deref(),
)
.await?;

let am = oauth.into_authorization_manager().unwrap();
let auth_client = AuthClient::new(reqwest::Client::default(), am);
Expand Down Expand Up @@ -696,8 +711,8 @@ async fn run_cross_app_access_client(

// ─── Helpers ────────────────────────────────────────────────────────────────

/// Fetch an authorization URL headlessly, returning (code, state).
async fn headless_authorize(auth_url: &str) -> anyhow::Result<(String, String)> {
/// Fetch an authorization URL headlessly, returning the callback parameters.
async fn headless_authorize(auth_url: &str) -> anyhow::Result<AuthorizationCallback> {
let http = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()?;
Expand All @@ -707,18 +722,7 @@ async fn headless_authorize(auth_url: &str) -> anyhow::Result<(String, String)>
.get("location")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| anyhow::anyhow!("No Location header in auth redirect"))?;
let redirect_url = url::Url::parse(location)?;
let code = redirect_url
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("No code in redirect URL"))?;
let state = redirect_url
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("No state in redirect URL"))?;
Ok((code, state))
AuthorizationCallback::from_redirect_url(location).map_err(Into::into)
}

/// Build a `CallToolRequestParams` for a tool, optionally with arguments.
Expand Down
Loading