Skip to main content

flow_security/
oauth.rs

1//! OAuth2 (authorization-code + refresh) for `service`-node connections.
2//!
3//! Vendor-neutral: every endpoint (authorize / token URL), client id, scope set,
4//! and redirect URI is supplied by the caller from catalog + operator data -
5//! nothing about any specific provider is hard-coded here. Tokens are stored in
6//! the OS keyring under the `service:<slug>` account as a JSON [`OAuthToken`].
7
8use std::collections::HashMap;
9use std::time::Duration;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpListener;
16
17use crate::credentials::{CredentialError, CredentialKind, CredentialStore};
18
19#[derive(Debug, Error)]
20pub enum OAuthError {
21    #[error("oauth http error: {0}")]
22    Http(String),
23    #[error("oauth token endpoint returned status {status}: {body}")]
24    Token { status: u16, body: String },
25    #[error("no refresh token stored; reconnect the service")]
26    NoRefreshToken,
27    #[error(transparent)]
28    Credential(#[from] CredentialError),
29    #[error("token store decode error: {0}")]
30    Decode(String),
31}
32
33/// A stored OAuth2 token bundle. Serialized to JSON in the keyring.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct OAuthToken {
36    pub access_token: String,
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub refresh_token: Option<String>,
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub expires_at: Option<DateTime<Utc>>,
41    #[serde(default)]
42    pub token_type: String,
43    #[serde(default)]
44    pub scope: String,
45}
46
47impl OAuthToken {
48    /// True when the access token is expired (or within a 60s skew window).
49    pub fn is_expired(&self) -> bool {
50        match self.expires_at {
51            Some(exp) => Utc::now() + chrono::Duration::seconds(60) >= exp,
52            None => false,
53        }
54    }
55}
56
57/// Raw token-endpoint response (standard OAuth2 fields).
58#[derive(Debug, Deserialize)]
59struct TokenResponse {
60    access_token: String,
61    #[serde(default)]
62    refresh_token: Option<String>,
63    #[serde(default)]
64    expires_in: Option<i64>,
65    #[serde(default)]
66    token_type: String,
67    #[serde(default)]
68    scope: String,
69}
70
71impl TokenResponse {
72    fn into_token(self, prior_refresh: Option<String>) -> OAuthToken {
73        OAuthToken {
74            // Providers commonly omit `refresh_token` on refresh - keep the prior.
75            refresh_token: self.refresh_token.or(prior_refresh),
76            expires_at: self
77                .expires_in
78                .map(|s| Utc::now() + chrono::Duration::seconds(s)),
79            access_token: self.access_token,
80            token_type: self.token_type,
81            scope: self.scope,
82        }
83    }
84}
85
86/// Build the authorization-code consent URL the user opens in a browser.
87pub fn build_authorize_url(
88    auth_url: &str,
89    client_id: &str,
90    redirect_uri: &str,
91    scopes: &[String],
92    state: &str,
93) -> String {
94    let scope = scopes.join(" ");
95    let pairs = [
96        ("response_type", "code"),
97        ("client_id", client_id),
98        ("redirect_uri", redirect_uri),
99        ("scope", scope.as_str()),
100        ("state", state),
101        // Request a refresh token (offline access) + force the consent screen so
102        // a refresh token is always returned on first connect.
103        ("access_type", "offline"),
104        ("prompt", "consent"),
105    ];
106    let query = pairs
107        .iter()
108        .map(|(k, v)| format!("{k}={}", urlencode(v)))
109        .collect::<Vec<_>>()
110        .join("&");
111    let sep = if auth_url.contains('?') { "&" } else { "?" };
112    format!("{auth_url}{sep}{query}")
113}
114
115/// Exchange an authorization code for a token bundle.
116pub async fn exchange_code(
117    token_url: &str,
118    client_id: &str,
119    client_secret: &str,
120    code: &str,
121    redirect_uri: &str,
122) -> Result<OAuthToken, OAuthError> {
123    let params = [
124        ("grant_type", "authorization_code"),
125        ("code", code),
126        ("client_id", client_id),
127        ("client_secret", client_secret),
128        ("redirect_uri", redirect_uri),
129    ];
130    post_token(token_url, &params, None).await
131}
132
133/// Refresh an access token using a stored refresh token.
134pub async fn refresh(
135    token_url: &str,
136    client_id: &str,
137    client_secret: &str,
138    refresh_token: &str,
139) -> Result<OAuthToken, OAuthError> {
140    let params = [
141        ("grant_type", "refresh_token"),
142        ("refresh_token", refresh_token),
143        ("client_id", client_id),
144        ("client_secret", client_secret),
145    ];
146    post_token(token_url, &params, Some(refresh_token.to_string())).await
147}
148
149async fn post_token(
150    token_url: &str,
151    params: &[(&str, &str)],
152    prior_refresh: Option<String>,
153) -> Result<OAuthToken, OAuthError> {
154    let client = reqwest::Client::builder()
155        .timeout(Duration::from_secs(30))
156        .build()
157        .map_err(|e| OAuthError::Http(e.to_string()))?;
158    let resp = client
159        .post(token_url)
160        .form(params)
161        .send()
162        .await
163        .map_err(|e| OAuthError::Http(e.to_string()))?;
164    let status = resp.status();
165    let body = resp
166        .text()
167        .await
168        .map_err(|e| OAuthError::Http(e.to_string()))?;
169    if !status.is_success() {
170        return Err(OAuthError::Token {
171            status: status.as_u16(),
172            body,
173        });
174    }
175    let parsed: TokenResponse =
176        serde_json::from_str(&body).map_err(|e| OAuthError::Decode(e.to_string()))?;
177    Ok(parsed.into_token(prior_refresh))
178}
179
180/// Persist a token bundle to the keyring under `service:<slug>`.
181pub fn store_token(
182    store: &dyn CredentialStore,
183    slug: &str,
184    token: &OAuthToken,
185) -> Result<(), OAuthError> {
186    let account = CredentialKind::Service.account_for(slug);
187    let json = serde_json::to_string(token).map_err(|e| OAuthError::Decode(e.to_string()))?;
188    store.set(&account, &json)?;
189    Ok(())
190}
191
192/// Load a token bundle from the keyring. Returns `Ok(None)` when none is stored.
193pub fn load_token(store: &dyn CredentialStore, slug: &str) -> Result<Option<OAuthToken>, OAuthError> {
194    let account = CredentialKind::Service.account_for(slug);
195    match store.get(&account) {
196        Ok(raw) => Ok(Some(
197            serde_json::from_str(&raw).map_err(|e| OAuthError::Decode(e.to_string()))?,
198        )),
199        Err(CredentialError::NotFound { .. }) => Ok(None),
200        Err(e) => Err(e.into()),
201    }
202}
203
204/// Bind a one-shot loopback listener for the OAuth redirect on
205/// `127.0.0.1:<port>`. Bind **before** opening the browser so the redirect can
206/// never arrive before we're listening.
207pub async fn bind_loopback(port: u16) -> Result<TcpListener, OAuthError> {
208    TcpListener::bind(("127.0.0.1", port))
209        .await
210        .map_err(|e| OAuthError::Http(format!("bind 127.0.0.1:{port}: {e}")))
211}
212
213/// Wait on `listener` for the provider's redirect and return the `code` once a
214/// request arrives whose `state` matches `expected_state`. Serves a tiny
215/// "you can close this tab" page; ignores non-callback hits (e.g. favicon).
216pub async fn accept_oauth_code(
217    listener: TcpListener,
218    expected_state: &str,
219    timeout: Duration,
220) -> Result<String, OAuthError> {
221    let work = async {
222        loop {
223            let (mut stream, _) = listener
224                .accept()
225                .await
226                .map_err(|e| OAuthError::Http(e.to_string()))?;
227            let mut buf = vec![0u8; 4096];
228            let n = stream.read(&mut buf).await.unwrap_or(0);
229            let req = String::from_utf8_lossy(&buf[..n]);
230            // Request line: `GET /callback?code=...&state=... HTTP/1.1`.
231            let path = req
232                .lines()
233                .next()
234                .and_then(|l| l.split_whitespace().nth(1))
235                .unwrap_or("");
236            let params = parse_query(path);
237
238            let (status, message) = if params.contains_key("error") {
239                ("400 Bad Request", "Authorization was denied. You can close this tab.")
240            } else if params.contains_key("code") {
241                ("200 OK", "Connected. You can close this tab and return to Flow Studio.")
242            } else {
243                ("404 Not Found", "Not found.")
244            };
245            let body = format!(
246                "<!doctype html><meta charset=utf-8><body style=\"font-family:system-ui;padding:2rem\">{message}</body>"
247            );
248            let resp = format!(
249                "HTTP/1.1 {status}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
250                body.len()
251            );
252            let _ = stream.write_all(resp.as_bytes()).await;
253            let _ = stream.shutdown().await;
254
255            if let Some(err) = params.get("error") {
256                return Err(OAuthError::Token {
257                    status: 400,
258                    body: format!("authorization denied: {err}"),
259                });
260            }
261            if let Some(code) = params.get("code") {
262                return match params.get("state") {
263                    Some(s) if s == expected_state => Ok(code.clone()),
264                    _ => Err(OAuthError::Token {
265                        status: 400,
266                        body: "OAuth state mismatch (possible CSRF)".into(),
267                    }),
268                };
269            }
270            // Not the callback - keep waiting for the real redirect.
271        }
272    };
273    match tokio::time::timeout(timeout, work).await {
274        Ok(res) => res,
275        Err(_) => Err(OAuthError::Http(
276            "timed out waiting for the OAuth redirect".into(),
277        )),
278    }
279}
280
281/// Parse a `?k=v&k2=v2` query string off a request target into a map.
282fn parse_query(path: &str) -> HashMap<String, String> {
283    let mut map = HashMap::new();
284    if let Some((_, q)) = path.split_once('?') {
285        for pair in q.split('&') {
286            if let Some((k, v)) = pair.split_once('=') {
287                map.insert(k.to_string(), urldecode(v));
288            }
289        }
290    }
291    map
292}
293
/// Minimal percent-decoder for query components: `%XX` escapes plus `+` as
/// space.
///
/// Decoding works on raw bytes; invalid UTF-8 in the result is replaced with
/// U+FFFD. A `%` that is not followed by two hex digits is kept literally.
///
/// The input is untrusted (it comes straight off the loopback socket), so this
/// must never panic. The escape digits are therefore inspected as bytes rather
/// than by slicing the `&str`, which panics when a multi-byte character
/// follows the `%` (e.g. `%aé`).
fn urldecode(s: &str) -> String {
    // Value of one ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_value);
                let lo = bytes.get(i + 2).copied().and_then(hex_value);
                match (hi, lo) {
                    (Some(hi), Some(lo)) => {
                        out.push((hi << 4) | lo);
                        i += 3;
                    }
                    _ => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}
313
/// Percent-encode `s` for use as a query value.
///
/// RFC 3986 unreserved bytes (`A-Z a-z 0-9 - _ . ~`) pass through untouched.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn urlencode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        encoded
    })
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::InMemoryCredentialStore;
    use std::sync::Arc;

    /// A bearer token with the given expiry and no refresh token.
    fn bearer_expiring(expires_at: Option<DateTime<Utc>>) -> OAuthToken {
        OAuthToken {
            access_token: "a".into(),
            refresh_token: None,
            expires_at,
            token_type: "Bearer".into(),
            scope: String::new(),
        }
    }

    #[test]
    fn authorize_url_encodes_scopes_and_state() {
        let scopes: Vec<String> = vec!["read".into(), "write things".into()];
        let url = build_authorize_url(
            "https://example.test/authorize",
            "client-123",
            "http://127.0.0.1:7421/callback",
            &scopes,
            "xyz",
        );
        assert!(url.starts_with("https://example.test/authorize?"));
        for expected in [
            "client_id=client-123",
            "scope=read%20write%20things",
            "redirect_uri=http%3A%2F%2F127.0.0.1%3A7421%2Fcallback",
            "state=xyz",
        ] {
            assert!(url.contains(expected), "missing `{expected}` in {url}");
        }
    }

    #[test]
    fn token_is_expired_within_skew() {
        let soon = bearer_expiring(Some(Utc::now() + chrono::Duration::seconds(30)));
        assert!(soon.is_expired(), "within 60s skew counts as expired");

        let later = bearer_expiring(Some(Utc::now() + chrono::Duration::seconds(3600)));
        assert!(!later.is_expired());

        let open_ended = bearer_expiring(None);
        assert!(!open_ended.is_expired(), "no expiry => never expired");
    }

    #[test]
    fn store_and_load_round_trip() {
        let store: Arc<dyn CredentialStore> = Arc::new(InMemoryCredentialStore::new());
        let original = OAuthToken {
            access_token: "at".into(),
            refresh_token: Some("rt".into()),
            expires_at: None,
            token_type: "Bearer".into(),
            scope: "read".into(),
        };
        store_token(store.as_ref(), "acme.svc", &original).unwrap();

        let loaded = load_token(store.as_ref(), "acme.svc")
            .unwrap()
            .expect("token was just stored");
        assert_eq!(loaded.access_token, "at");
        assert_eq!(loaded.refresh_token.as_deref(), Some("rt"));

        let missing = load_token(store.as_ref(), "missing").unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn parse_query_decodes_callback_params() {
        let params = parse_query("/callback?code=abc%2F123&state=xyz");
        assert_eq!(params.get("code").map(String::as_str), Some("abc/123"));
        assert_eq!(params.get("state").map(String::as_str), Some("xyz"));
        // Targets without a query string (favicon, root) yield no params.
        let favicon = parse_query("/favicon.ico");
        assert!(favicon.is_empty());
    }
}