1use std::collections::HashMap;
9use std::time::Duration;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpListener;
16
17use crate::credentials::{CredentialError, CredentialKind, CredentialStore};
18
/// Errors returned by the OAuth helpers in this module.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// A transport or local I/O style failure, carried as rendered text.
    ///
    /// In this file it is used for HTTP client/request failures, loopback
    /// bind/accept failures, and the redirect wait timing out.
    #[error("oauth http error: {0}")]
    Http(String),
    /// A rejection described by an HTTP-style status code and a message.
    ///
    /// Produced for non-success token endpoint responses, and (with status
    /// 400) for a denied authorization or a failed `state` check on the
    /// loopback callback.
    #[error("oauth token endpoint returned status {status}: {body}")]
    Token { status: u16, body: String },
    /// No refresh token is available.
    ///
    /// NOTE(review): not constructed anywhere in the visible source —
    /// presumably raised by callers before invoking `refresh`; confirm.
    #[error("no refresh token stored; reconnect the service")]
    NoRefreshToken,
    /// An error from the credential store, forwarded transparently.
    #[error(transparent)]
    Credential(#[from] CredentialError),
    /// JSON serialization or deserialization of a token failed.
    #[error("token store decode error: {0}")]
    Decode(String),
}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct OAuthToken {
36 pub access_token: String,
37 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub refresh_token: Option<String>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub expires_at: Option<DateTime<Utc>>,
41 #[serde(default)]
42 pub token_type: String,
43 #[serde(default)]
44 pub scope: String,
45}
46
47impl OAuthToken {
48 pub fn is_expired(&self) -> bool {
50 match self.expires_at {
51 Some(exp) => Utc::now() + chrono::Duration::seconds(60) >= exp,
52 None => false,
53 }
54 }
55}
56
57#[derive(Debug, Deserialize)]
59struct TokenResponse {
60 access_token: String,
61 #[serde(default)]
62 refresh_token: Option<String>,
63 #[serde(default)]
64 expires_in: Option<i64>,
65 #[serde(default)]
66 token_type: String,
67 #[serde(default)]
68 scope: String,
69}
70
71impl TokenResponse {
72 fn into_token(self, prior_refresh: Option<String>) -> OAuthToken {
73 OAuthToken {
74 refresh_token: self.refresh_token.or(prior_refresh),
76 expires_at: self
77 .expires_in
78 .map(|s| Utc::now() + chrono::Duration::seconds(s)),
79 access_token: self.access_token,
80 token_type: self.token_type,
81 scope: self.scope,
82 }
83 }
84}
85
86pub fn build_authorize_url(
88 auth_url: &str,
89 client_id: &str,
90 redirect_uri: &str,
91 scopes: &[String],
92 state: &str,
93) -> String {
94 let scope = scopes.join(" ");
95 let pairs = [
96 ("response_type", "code"),
97 ("client_id", client_id),
98 ("redirect_uri", redirect_uri),
99 ("scope", scope.as_str()),
100 ("state", state),
101 ("access_type", "offline"),
104 ("prompt", "consent"),
105 ];
106 let query = pairs
107 .iter()
108 .map(|(k, v)| format!("{k}={}", urlencode(v)))
109 .collect::<Vec<_>>()
110 .join("&");
111 let sep = if auth_url.contains('?') { "&" } else { "?" };
112 format!("{auth_url}{sep}{query}")
113}
114
115pub async fn exchange_code(
117 token_url: &str,
118 client_id: &str,
119 client_secret: &str,
120 code: &str,
121 redirect_uri: &str,
122) -> Result<OAuthToken, OAuthError> {
123 let params = [
124 ("grant_type", "authorization_code"),
125 ("code", code),
126 ("client_id", client_id),
127 ("client_secret", client_secret),
128 ("redirect_uri", redirect_uri),
129 ];
130 post_token(token_url, ¶ms, None).await
131}
132
133pub async fn refresh(
135 token_url: &str,
136 client_id: &str,
137 client_secret: &str,
138 refresh_token: &str,
139) -> Result<OAuthToken, OAuthError> {
140 let params = [
141 ("grant_type", "refresh_token"),
142 ("refresh_token", refresh_token),
143 ("client_id", client_id),
144 ("client_secret", client_secret),
145 ];
146 post_token(token_url, ¶ms, Some(refresh_token.to_string())).await
147}
148
149async fn post_token(
150 token_url: &str,
151 params: &[(&str, &str)],
152 prior_refresh: Option<String>,
153) -> Result<OAuthToken, OAuthError> {
154 let client = reqwest::Client::builder()
155 .timeout(Duration::from_secs(30))
156 .build()
157 .map_err(|e| OAuthError::Http(e.to_string()))?;
158 let resp = client
159 .post(token_url)
160 .form(params)
161 .send()
162 .await
163 .map_err(|e| OAuthError::Http(e.to_string()))?;
164 let status = resp.status();
165 let body = resp
166 .text()
167 .await
168 .map_err(|e| OAuthError::Http(e.to_string()))?;
169 if !status.is_success() {
170 return Err(OAuthError::Token {
171 status: status.as_u16(),
172 body,
173 });
174 }
175 let parsed: TokenResponse =
176 serde_json::from_str(&body).map_err(|e| OAuthError::Decode(e.to_string()))?;
177 Ok(parsed.into_token(prior_refresh))
178}
179
180pub fn store_token(
182 store: &dyn CredentialStore,
183 slug: &str,
184 token: &OAuthToken,
185) -> Result<(), OAuthError> {
186 let account = CredentialKind::Service.account_for(slug);
187 let json = serde_json::to_string(token).map_err(|e| OAuthError::Decode(e.to_string()))?;
188 store.set(&account, &json)?;
189 Ok(())
190}
191
192pub fn load_token(store: &dyn CredentialStore, slug: &str) -> Result<Option<OAuthToken>, OAuthError> {
194 let account = CredentialKind::Service.account_for(slug);
195 match store.get(&account) {
196 Ok(raw) => Ok(Some(
197 serde_json::from_str(&raw).map_err(|e| OAuthError::Decode(e.to_string()))?,
198 )),
199 Err(CredentialError::NotFound { .. }) => Ok(None),
200 Err(e) => Err(e.into()),
201 }
202}
203
204pub async fn bind_loopback(port: u16) -> Result<TcpListener, OAuthError> {
208 TcpListener::bind(("127.0.0.1", port))
209 .await
210 .map_err(|e| OAuthError::Http(format!("bind 127.0.0.1:{port}: {e}")))
211}
212
213pub async fn accept_oauth_code(
217 listener: TcpListener,
218 expected_state: &str,
219 timeout: Duration,
220) -> Result<String, OAuthError> {
221 let work = async {
222 loop {
223 let (mut stream, _) = listener
224 .accept()
225 .await
226 .map_err(|e| OAuthError::Http(e.to_string()))?;
227 let mut buf = vec![0u8; 4096];
228 let n = stream.read(&mut buf).await.unwrap_or(0);
229 let req = String::from_utf8_lossy(&buf[..n]);
230 let path = req
232 .lines()
233 .next()
234 .and_then(|l| l.split_whitespace().nth(1))
235 .unwrap_or("");
236 let params = parse_query(path);
237
238 let (status, message) = if params.contains_key("error") {
239 ("400 Bad Request", "Authorization was denied. You can close this tab.")
240 } else if params.contains_key("code") {
241 ("200 OK", "Connected. You can close this tab and return to Flow Studio.")
242 } else {
243 ("404 Not Found", "Not found.")
244 };
245 let body = format!(
246 "<!doctype html><meta charset=utf-8><body style=\"font-family:system-ui;padding:2rem\">{message}</body>"
247 );
248 let resp = format!(
249 "HTTP/1.1 {status}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
250 body.len()
251 );
252 let _ = stream.write_all(resp.as_bytes()).await;
253 let _ = stream.shutdown().await;
254
255 if let Some(err) = params.get("error") {
256 return Err(OAuthError::Token {
257 status: 400,
258 body: format!("authorization denied: {err}"),
259 });
260 }
261 if let Some(code) = params.get("code") {
262 return match params.get("state") {
263 Some(s) if s == expected_state => Ok(code.clone()),
264 _ => Err(OAuthError::Token {
265 status: 400,
266 body: "OAuth state mismatch (possible CSRF)".into(),
267 }),
268 };
269 }
270 }
272 };
273 match tokio::time::timeout(timeout, work).await {
274 Ok(res) => res,
275 Err(_) => Err(OAuthError::Http(
276 "timed out waiting for the OAuth redirect".into(),
277 )),
278 }
279}
280
281fn parse_query(path: &str) -> HashMap<String, String> {
283 let mut map = HashMap::new();
284 if let Some((_, q)) = path.split_once('?') {
285 for pair in q.split('&') {
286 if let Some((k, v)) = pair.split_once('=') {
287 map.insert(k.to_string(), urldecode(v));
288 }
289 }
290 }
291 map
292}
293
/// Decodes a form/query-encoded value.
///
/// `+` becomes a space and `%XX` (two hex digits) becomes the byte `XX`.
/// A `%` that is not followed by two hex digits is kept literally. The
/// decoded bytes are interpreted as UTF-8, with invalid sequences replaced
/// by U+FFFD.
fn urldecode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_value(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    // Work on bytes throughout: slicing the `str` at `i + 3` can land inside
    // a multi-byte character and panic.
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let decoded = match (bytes.get(i + 1), bytes.get(i + 2)) {
                    (Some(&hi), Some(&lo)) => hex_value(hi)
                        .zip(hex_value(lo))
                        .map(|(h, l)| (h << 4) | l),
                    _ => None,
                };
                match decoded {
                    Some(byte) => {
                        out.push(byte);
                        i += 3;
                    }
                    None => {
                        // Malformed escape: keep the '%' and carry on.
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}
313
/// Percent-encodes `s` for use as a URL query value.
///
/// ASCII letters, digits, and `-`, `_`, `.`, `~` pass through unchanged;
/// every other byte (including each byte of a multi-byte character) is
/// written as `%` followed by two uppercase hex digits.
fn urlencode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::InMemoryCredentialStore;
    use std::sync::Arc;

    /// The authorize URL keeps the base and percent-encodes the query values.
    #[test]
    fn authorize_url_encodes_scopes_and_state() {
        let scopes = vec![String::from("read"), String::from("write things")];
        let url = build_authorize_url(
            "https://example.test/authorize",
            "client-123",
            "http://127.0.0.1:7421/callback",
            &scopes,
            "xyz",
        );

        assert!(url.starts_with("https://example.test/authorize?"));
        assert!(url.contains("client_id=client-123"));
        assert!(url.contains("scope=read%20write%20things"));
        assert!(url.contains("redirect_uri=http%3A%2F%2F127.0.0.1%3A7421%2Fcallback"));
        assert!(url.contains("state=xyz"));
    }

    /// Expiry is judged with a 60 second look-ahead; no expiry means never expired.
    #[test]
    fn token_is_expired_within_skew() {
        let with_expiry = |expires_at| OAuthToken {
            access_token: "a".into(),
            refresh_token: None,
            expires_at,
            token_type: "Bearer".into(),
            scope: String::new(),
        };

        let soon = Utc::now() + chrono::Duration::seconds(30);
        assert!(
            with_expiry(Some(soon)).is_expired(),
            "within 60s skew counts as expired"
        );

        let later = Utc::now() + chrono::Duration::seconds(3600);
        assert!(!with_expiry(Some(later)).is_expired());

        assert!(!with_expiry(None).is_expired(), "no expiry => never expired");
    }

    /// A stored token can be read back; an unknown slug yields `None`.
    #[test]
    fn store_and_load_round_trip() {
        let store: Arc<dyn CredentialStore> = Arc::new(InMemoryCredentialStore::new());
        let original = OAuthToken {
            access_token: "at".into(),
            refresh_token: Some("rt".into()),
            expires_at: None,
            token_type: "Bearer".into(),
            scope: "read".into(),
        };

        store_token(store.as_ref(), "acme.svc", &original).unwrap();

        let loaded = load_token(store.as_ref(), "acme.svc").unwrap().unwrap();
        assert_eq!(loaded.access_token, "at");
        assert_eq!(loaded.refresh_token.as_deref(), Some("rt"));

        let absent = load_token(store.as_ref(), "missing").unwrap();
        assert!(absent.is_none());
    }

    /// Callback query values are percent-decoded; a target without a query is empty.
    #[test]
    fn parse_query_decodes_callback_params() {
        let params = parse_query("/callback?code=abc%2F123&state=xyz");
        assert_eq!(params.get("code").map(String::as_str), Some("abc/123"));
        assert_eq!(params.get("state").map(String::as_str), Some("xyz"));

        assert!(parse_query("/favicon.ico").is_empty());
    }
}