Skip to main content

flow_adapter_ai/providers/
claude.rs

1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5use std::time::Instant;
6
7use crate::error::{redact_error_body, CloudAiError};
8use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
9use crate::request::{CloudAiRequest, CloudAiResponse, ToolDispatcher, ToolSpec};
10use crate::stream::{LlmStreamEvent, LlmStreamSink};
11
/// Provider identifier: returned by `name()` and attached to every error
/// and stream event produced in this file.
const NAME: &str = "claude";

/// Environment variable name returned by `env_var()`.
const ENV_VAR: &str = "ANTHROPIC_API_KEY";

/// URL every request in this file is POSTed to.
const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";

/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Model identifiers returned by `default_models()`.
const DEFAULT_MODELS: &[&str] = &[
    "claude-opus-4-7",
    "claude-sonnet-4-6",
    "claude-haiku-4-5",
];
18
19pub struct ClaudeProvider {
20    client: reqwest::Client,
21}
22
23impl ClaudeProvider {
24    pub fn new() -> Self {
25        Self {
26            client: reqwest::Client::new(),
27        }
28    }
29}
30
31impl Default for ClaudeProvider {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37#[async_trait]
38impl CloudAiProvider for ClaudeProvider {
39    fn name(&self) -> &str {
40        NAME
41    }
42
43    fn env_var(&self) -> &str {
44        ENV_VAR
45    }
46
47    fn default_models(&self) -> &[&str] {
48        DEFAULT_MODELS
49    }
50
51    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
52        vec![
53            ModelCapabilityInfo::new(
54                "claude-opus-4-7",
55                &["reasoning", "vision", "tool_use", "code"],
56            ),
57            ModelCapabilityInfo::new(
58                "claude-sonnet-4-6",
59                &["reasoning", "vision", "tool_use", "code"],
60            ),
61            ModelCapabilityInfo::new("claude-haiku-4-5", &["vision", "tool_use", "code"]),
62        ]
63    }
64
65    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
66        let body = build_body(req);
67
68        let started = Instant::now();
69        let resp = self
70            .client
71            .post(ENDPOINT)
72            .header("x-api-key", &req.api_key)
73            .header("anthropic-version", ANTHROPIC_VERSION)
74            .header("content-type", "application/json")
75            .json(&body)
76            .send()
77            .await
78            .map_err(|e| CloudAiError::Http {
79                provider: NAME.into(),
80                reason: e.to_string(),
81            })?;
82
83        let status = resp.status();
84        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
85            provider: NAME.into(),
86            reason: e.to_string(),
87        })?;
88        let latency_ms = started.elapsed().as_millis() as u64;
89
90        if !status.is_success() {
91            return Err(CloudAiError::Status {
92                provider: NAME.into(),
93                status: status.as_u16(),
94                body: redact_error_body(&raw_body),
95            });
96        }
97
98        parse_response(&raw_body, latency_ms)
99    }
100
101    async fn invoke_stream(
102        &self,
103        req: &CloudAiRequest,
104        sink: &dyn LlmStreamSink,
105    ) -> Result<CloudAiResponse, CloudAiError> {
106        let call_id = req.call_id.clone().unwrap_or_default();
107        let mut body = build_body(req);
108        body["stream"] = json!(true);
109        let started = Instant::now();
110        let resp = self
111            .client
112            .post(ENDPOINT)
113            .header("x-api-key", &req.api_key)
114            .header("anthropic-version", ANTHROPIC_VERSION)
115            .header("content-type", "application/json")
116            .header("accept", "text/event-stream")
117            .json(&body)
118            .send()
119            .await
120            .map_err(|e| CloudAiError::Http {
121                provider: NAME.into(),
122                reason: e.to_string(),
123            })?;
124
125        let status = resp.status();
126        if !status.is_success() {
127            let raw = resp.text().await.unwrap_or_default();
128            let err = CloudAiError::Status {
129                provider: NAME.into(),
130                status: status.as_u16(),
131                body: redact_error_body(&raw),
132            };
133            sink.emit(LlmStreamEvent::error(&call_id, NAME, err.to_string()))
134                .await;
135            return Err(err);
136        }
137
138        let mut stream = resp.bytes_stream();
139        let mut buffer = String::new();
140        let mut full_text = String::new();
141        let mut model_name = String::new();
142        let mut finish_reason = String::from("unknown");
143        let mut input_tokens: u32 = 0;
144        let mut output_tokens: u32 = 0;
145        let mut errored: Option<String> = None;
146
147        while let Some(chunk) = stream.next().await {
148            let bytes = match chunk {
149                Ok(b) => b,
150                Err(e) => {
151                    errored = Some(e.to_string());
152                    break;
153                }
154            };
155            buffer.push_str(&String::from_utf8_lossy(&bytes));
156            // Anthropic frames are separated by `\n\n` like generic SSE.
157            while let Some(idx) = buffer.find("\n\n") {
158                let frame = buffer[..idx].to_string();
159                buffer.drain(..idx + 2);
160                for raw_line in frame.split('\n') {
161                    let line = raw_line.trim_start();
162                    let Some(payload) = line.strip_prefix("data:") else {
163                        continue;
164                    };
165                    let payload = payload.trim();
166                    if payload.is_empty() {
167                        continue;
168                    }
169                    let value: serde_json::Value = match serde_json::from_str(payload) {
170                        Ok(v) => v,
171                        Err(_) => continue,
172                    };
173                    let kind = value.get("type").and_then(|t| t.as_str()).unwrap_or("");
174                    match kind {
175                        "message_start" => {
176                            if let Some(msg) = value.get("message") {
177                                if let Some(m) = msg.get("model").and_then(|m| m.as_str()) {
178                                    model_name = m.to_string();
179                                }
180                                if let Some(u) = msg.get("usage") {
181                                    input_tokens = u
182                                        .get("input_tokens")
183                                        .and_then(|v| v.as_u64())
184                                        .map(|v| v as u32)
185                                        .unwrap_or(input_tokens);
186                                }
187                            }
188                        }
189                        "content_block_delta" => {
190                            if let Some(delta) = value
191                                .get("delta")
192                                .and_then(|d| d.get("text"))
193                                .and_then(|t| t.as_str())
194                            {
195                                if !delta.is_empty() {
196                                    full_text.push_str(delta);
197                                    sink.emit(LlmStreamEvent {
198                                        call_id: call_id.clone(),
199                                        provider: NAME.to_string(),
200                                        model: model_name.clone(),
201                                        delta: delta.to_string(),
202                                        done: false,
203                                        error: None,
204                                    })
205                                    .await;
206                                }
207                            }
208                        }
209                        "message_delta" => {
210                            if let Some(reason) = value
211                                .get("delta")
212                                .and_then(|d| d.get("stop_reason"))
213                                .and_then(|r| r.as_str())
214                            {
215                                finish_reason = reason.to_string();
216                            }
217                            if let Some(u) = value.get("usage") {
218                                output_tokens = u
219                                    .get("output_tokens")
220                                    .and_then(|v| v.as_u64())
221                                    .map(|v| v as u32)
222                                    .unwrap_or(output_tokens);
223                            }
224                        }
225                        "message_stop" => {
226                            sink.emit(LlmStreamEvent {
227                                call_id: call_id.clone(),
228                                provider: NAME.to_string(),
229                                model: model_name.clone(),
230                                delta: String::new(),
231                                done: true,
232                                error: None,
233                            })
234                            .await;
235                            let latency_ms = started.elapsed().as_millis() as u64;
236                            return Ok(CloudAiResponse {
237                                provider: NAME.to_string(),
238                                model: if model_name.is_empty() {
239                                    "unknown".to_string()
240                                } else {
241                                    model_name
242                                },
243                                text: full_text,
244                                finish_reason,
245                                input_tokens,
246                                output_tokens,
247                                latency_ms,
248                            });
249                        }
250                        _ => {}
251                    }
252                }
253            }
254        }
255
256        if let Some(err) = errored {
257            sink.emit(LlmStreamEvent::error(&call_id, NAME, err.clone()))
258                .await;
259            return Err(CloudAiError::Http {
260                provider: NAME.into(),
261                reason: err,
262            });
263        }
264        // Stream ended without an explicit `message_stop`. Treat it as
265        // best-effort: emit the terminal event and return what we have.
266        sink.emit(LlmStreamEvent {
267            call_id: call_id.clone(),
268            provider: NAME.to_string(),
269            model: model_name.clone(),
270            delta: String::new(),
271            done: true,
272            error: None,
273        })
274        .await;
275        if full_text.is_empty() {
276            return Err(CloudAiError::Shape {
277                provider: NAME.into(),
278                detail: "stream closed before any text was emitted".into(),
279            });
280        }
281        let latency_ms = started.elapsed().as_millis() as u64;
282        Ok(CloudAiResponse {
283            provider: NAME.to_string(),
284            model: if model_name.is_empty() {
285                "unknown".to_string()
286            } else {
287                model_name
288            },
289            text: full_text,
290            finish_reason,
291            input_tokens,
292            output_tokens,
293            latency_ms,
294        })
295    }
296
297    async fn invoke_tools(
298        &self,
299        req: &CloudAiRequest,
300        tools: &[ToolSpec],
301        dispatcher: &dyn ToolDispatcher,
302        max_iters: usize,
303    ) -> Result<CloudAiResponse, CloudAiError> {
304        let started = Instant::now();
305        let mut messages = vec![json!({ "role": "user", "content": user_content(req) })];
306        let tools_body: Vec<_> = tools
307            .iter()
308            .map(|t| {
309                json!({ "name": t.name, "description": t.description, "input_schema": t.parameters })
310            })
311            .collect();
312        let mut input_tokens: u32 = 0;
313        let mut output_tokens: u32 = 0;
314        let mut model_out = req.model.clone();
315
316        for _ in 0..max_iters.max(1) {
317            let mut body = json!({
318                "model": req.model,
319                "max_tokens": req.max_tokens.unwrap_or(1024),
320                "messages": messages,
321                "tools": tools_body,
322            });
323            if let Some(t) = req.temperature {
324                body["temperature"] = json!(t);
325            }
326            if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
327                body["system"] = json!(s);
328            }
329            apply_thinking(&mut body, req);
330
331            let resp = self
332                .client
333                .post(ENDPOINT)
334                .header("x-api-key", &req.api_key)
335                .header("anthropic-version", ANTHROPIC_VERSION)
336                .header("content-type", "application/json")
337                .json(&body)
338                .send()
339                .await
340                .map_err(|e| CloudAiError::Http {
341                    provider: NAME.into(),
342                    reason: e.to_string(),
343                })?;
344            let status = resp.status();
345            let raw = resp.text().await.unwrap_or_default();
346            if !status.is_success() {
347                return Err(CloudAiError::Status {
348                    provider: NAME.into(),
349                    status: status.as_u16(),
350                    body: redact_error_body(&raw),
351                });
352            }
353            let v: serde_json::Value =
354                serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
355                    provider: NAME.into(),
356                    reason: e.to_string(),
357                })?;
358            if let Some(u) = v.get("usage") {
359                input_tokens = input_tokens
360                    .saturating_add(u.get("input_tokens").and_then(|x| x.as_u64()).unwrap_or(0) as u32);
361                output_tokens = output_tokens.saturating_add(
362                    u.get("output_tokens").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
363                );
364            }
365            if let Some(m) = v.get("model").and_then(|m| m.as_str()) {
366                model_out = m.to_string();
367            }
368            let stop_reason = v
369                .get("stop_reason")
370                .and_then(|r| r.as_str())
371                .unwrap_or("end_turn")
372                .to_string();
373            let content = v
374                .get("content")
375                .and_then(|c| c.as_array())
376                .cloned()
377                .unwrap_or_default();
378
379            let mut text = String::new();
380            let mut tool_uses: Vec<(String, String, serde_json::Value)> = Vec::new();
381            for block in &content {
382                match block.get("type").and_then(|t| t.as_str()) {
383                    Some("text") => {
384                        if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
385                            text.push_str(t);
386                        }
387                    }
388                    Some("tool_use") => {
389                        tool_uses.push((
390                            block.get("id").and_then(|i| i.as_str()).unwrap_or("").to_string(),
391                            block.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string(),
392                            block.get("input").cloned().unwrap_or_else(|| json!({})),
393                        ));
394                    }
395                    _ => {}
396                }
397            }
398
399            if tool_uses.is_empty() {
400                if text.is_empty() {
401                    return Err(CloudAiError::Shape {
402                        provider: NAME.into(),
403                        detail: "no text content blocks in response".into(),
404                    });
405                }
406                return Ok(CloudAiResponse {
407                    provider: NAME.into(),
408                    model: model_out,
409                    text,
410                    finish_reason: stop_reason,
411                    input_tokens,
412                    output_tokens,
413                    latency_ms: started.elapsed().as_millis() as u64,
414                });
415            }
416
417            // Echo the assistant turn, then return each tool result as a user
418            // turn of `tool_result` blocks (errors fed back so the model can
419            // recover).
420            messages.push(json!({ "role": "assistant", "content": content }));
421            let mut results = Vec::new();
422            for (id, name, input) in &tool_uses {
423                let out = match dispatcher.call(name, input).await {
424                    Ok(v) => v.to_string(),
425                    Err(e) => json!({ "error": e }).to_string(),
426                };
427                results.push(json!({ "type": "tool_result", "tool_use_id": id, "content": out }));
428            }
429            messages.push(json!({ "role": "user", "content": results }));
430        }
431
432        Err(CloudAiError::Shape {
433            provider: NAME.into(),
434            detail: format!("tool loop did not converge within {max_iters} iterations"),
435        })
436    }
437}
438
439/// User-message content: a plain string, or `text` + `image` blocks when the
440/// request carries images (Claude multimodal).
441fn user_content(req: &CloudAiRequest) -> serde_json::Value {
442    if req.images.is_empty() {
443        return json!(req.prompt);
444    }
445    let mut blocks = vec![json!({ "type": "text", "text": req.prompt })];
446    for img in &req.images {
447        blocks.push(image_block(img));
448    }
449    json!(blocks)
450}
451
452/// A Claude `image` content block: a `data:` URL becomes a base64 source, an
453/// `http(s)` URL a url source.
454fn image_block(img: &str) -> serde_json::Value {
455    if let Some(rest) = img.strip_prefix("data:") {
456        if let Some((media_type, data)) = rest.split_once(";base64,") {
457            return json!({
458                "type": "image",
459                "source": { "type": "base64", "media_type": media_type, "data": data },
460            });
461        }
462    }
463    json!({ "type": "image", "source": { "type": "url", "url": img } })
464}
465
466/// Apply the optional extended-thinking config. Claude requires
467/// `budget_tokens < max_tokens` and disallows a custom temperature while
468/// thinking, so bump max_tokens and drop temperature when enabling it.
469fn apply_thinking(payload: &mut serde_json::Value, req: &CloudAiRequest) {
470    if req.reasoning != Some(true) {
471        return;
472    }
473    let budget = 1024u32;
474    let max = payload["max_tokens"].as_u64().unwrap_or(1024) as u32;
475    if max <= budget {
476        payload["max_tokens"] = json!(budget + 1024);
477    }
478    if let Some(obj) = payload.as_object_mut() {
479        obj.remove("temperature");
480    }
481    payload["thinking"] = json!({ "type": "enabled", "budget_tokens": budget });
482}
483
484pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
485    let mut payload = json!({
486        "model": req.model,
487        "max_tokens": req.max_tokens.unwrap_or(1024),
488        "messages": [{ "role": "user", "content": user_content(req) }],
489    });
490    if let Some(t) = req.temperature {
491        payload["temperature"] = json!(t);
492    }
493    if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
494        // Claude's Messages API takes the system prompt as a top-level
495        // string, not as a message in `messages`. Sending it via the
496        // dedicated field improves instruction-following and lets
497        // Anthropic's prompt cache pin the system text across turns.
498        payload["system"] = json!(s);
499    }
500    apply_thinking(&mut payload, req);
501    payload
502}
503
504#[derive(Debug, Deserialize)]
505struct RawResponse {
506    model: String,
507    content: Vec<ContentBlock>,
508    stop_reason: Option<String>,
509    usage: Usage,
510}
511
512#[derive(Debug, Deserialize)]
513#[serde(tag = "type", rename_all = "snake_case")]
514enum ContentBlock {
515    Text {
516        text: String,
517    },
518    #[serde(other)]
519    Other,
520}
521
522#[derive(Debug, Deserialize)]
523struct Usage {
524    input_tokens: u32,
525    output_tokens: u32,
526}
527
528pub fn parse_response(body: &str, latency_ms: u64) -> Result<CloudAiResponse, CloudAiError> {
529    let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
530        provider: NAME.into(),
531        reason: e.to_string(),
532    })?;
533
534    let text = raw
535        .content
536        .iter()
537        .filter_map(|b| match b {
538            ContentBlock::Text { text } => Some(text.as_str()),
539            ContentBlock::Other => None,
540        })
541        .collect::<Vec<_>>()
542        .join("");
543
544    if text.is_empty() {
545        return Err(CloudAiError::Shape {
546            provider: NAME.into(),
547            detail: "no text content blocks in response".into(),
548        });
549    }
550
551    Ok(CloudAiResponse {
552        provider: NAME.into(),
553        model: raw.model,
554        text,
555        finish_reason: raw.stop_reason.unwrap_or_else(|| "unknown".into()),
556        input_tokens: raw.usage.input_tokens,
557        output_tokens: raw.usage.output_tokens,
558        latency_ms,
559    })
560}