1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5use std::time::Instant;
6
7use crate::error::{redact_error_body, CloudAiError};
8use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
9use crate::request::{CloudAiRequest, CloudAiResponse, ToolDispatcher, ToolSpec};
10use crate::stream::{LlmStreamEvent, LlmStreamSink};
11
/// Provider identifier; returned by `name()` and used as the `provider` field of errors and responses.
const NAME: &str = "claude";
/// Environment variable name reported by `env_var()`.
const ENV_VAR: &str = "ANTHROPIC_API_KEY";
/// URL every request in this file is POSTed to.
const ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Model identifiers returned by `default_models()`.
const DEFAULT_MODELS: &[&str] = &["claude-opus-4-7", "claude-sonnet-4-6", "claude-haiku-4-5"];
18
19pub struct ClaudeProvider {
20 client: reqwest::Client,
21}
22
23impl ClaudeProvider {
24 pub fn new() -> Self {
25 Self {
26 client: reqwest::Client::new(),
27 }
28 }
29}
30
31impl Default for ClaudeProvider {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37#[async_trait]
38impl CloudAiProvider for ClaudeProvider {
39 fn name(&self) -> &str {
40 NAME
41 }
42
43 fn env_var(&self) -> &str {
44 ENV_VAR
45 }
46
47 fn default_models(&self) -> &[&str] {
48 DEFAULT_MODELS
49 }
50
51 fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
52 vec![
53 ModelCapabilityInfo::new(
54 "claude-opus-4-7",
55 &["reasoning", "vision", "tool_use", "code"],
56 ),
57 ModelCapabilityInfo::new(
58 "claude-sonnet-4-6",
59 &["reasoning", "vision", "tool_use", "code"],
60 ),
61 ModelCapabilityInfo::new("claude-haiku-4-5", &["vision", "tool_use", "code"]),
62 ]
63 }
64
65 async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
66 let body = build_body(req);
67
68 let started = Instant::now();
69 let resp = self
70 .client
71 .post(ENDPOINT)
72 .header("x-api-key", &req.api_key)
73 .header("anthropic-version", ANTHROPIC_VERSION)
74 .header("content-type", "application/json")
75 .json(&body)
76 .send()
77 .await
78 .map_err(|e| CloudAiError::Http {
79 provider: NAME.into(),
80 reason: e.to_string(),
81 })?;
82
83 let status = resp.status();
84 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
85 provider: NAME.into(),
86 reason: e.to_string(),
87 })?;
88 let latency_ms = started.elapsed().as_millis() as u64;
89
90 if !status.is_success() {
91 return Err(CloudAiError::Status {
92 provider: NAME.into(),
93 status: status.as_u16(),
94 body: redact_error_body(&raw_body),
95 });
96 }
97
98 parse_response(&raw_body, latency_ms)
99 }
100
101 async fn invoke_stream(
102 &self,
103 req: &CloudAiRequest,
104 sink: &dyn LlmStreamSink,
105 ) -> Result<CloudAiResponse, CloudAiError> {
106 let call_id = req.call_id.clone().unwrap_or_default();
107 let mut body = build_body(req);
108 body["stream"] = json!(true);
109 let started = Instant::now();
110 let resp = self
111 .client
112 .post(ENDPOINT)
113 .header("x-api-key", &req.api_key)
114 .header("anthropic-version", ANTHROPIC_VERSION)
115 .header("content-type", "application/json")
116 .header("accept", "text/event-stream")
117 .json(&body)
118 .send()
119 .await
120 .map_err(|e| CloudAiError::Http {
121 provider: NAME.into(),
122 reason: e.to_string(),
123 })?;
124
125 let status = resp.status();
126 if !status.is_success() {
127 let raw = resp.text().await.unwrap_or_default();
128 let err = CloudAiError::Status {
129 provider: NAME.into(),
130 status: status.as_u16(),
131 body: redact_error_body(&raw),
132 };
133 sink.emit(LlmStreamEvent::error(&call_id, NAME, err.to_string()))
134 .await;
135 return Err(err);
136 }
137
138 let mut stream = resp.bytes_stream();
139 let mut buffer = String::new();
140 let mut full_text = String::new();
141 let mut model_name = String::new();
142 let mut finish_reason = String::from("unknown");
143 let mut input_tokens: u32 = 0;
144 let mut output_tokens: u32 = 0;
145 let mut errored: Option<String> = None;
146
147 while let Some(chunk) = stream.next().await {
148 let bytes = match chunk {
149 Ok(b) => b,
150 Err(e) => {
151 errored = Some(e.to_string());
152 break;
153 }
154 };
155 buffer.push_str(&String::from_utf8_lossy(&bytes));
156 while let Some(idx) = buffer.find("\n\n") {
158 let frame = buffer[..idx].to_string();
159 buffer.drain(..idx + 2);
160 for raw_line in frame.split('\n') {
161 let line = raw_line.trim_start();
162 let Some(payload) = line.strip_prefix("data:") else {
163 continue;
164 };
165 let payload = payload.trim();
166 if payload.is_empty() {
167 continue;
168 }
169 let value: serde_json::Value = match serde_json::from_str(payload) {
170 Ok(v) => v,
171 Err(_) => continue,
172 };
173 let kind = value.get("type").and_then(|t| t.as_str()).unwrap_or("");
174 match kind {
175 "message_start" => {
176 if let Some(msg) = value.get("message") {
177 if let Some(m) = msg.get("model").and_then(|m| m.as_str()) {
178 model_name = m.to_string();
179 }
180 if let Some(u) = msg.get("usage") {
181 input_tokens = u
182 .get("input_tokens")
183 .and_then(|v| v.as_u64())
184 .map(|v| v as u32)
185 .unwrap_or(input_tokens);
186 }
187 }
188 }
189 "content_block_delta" => {
190 if let Some(delta) = value
191 .get("delta")
192 .and_then(|d| d.get("text"))
193 .and_then(|t| t.as_str())
194 {
195 if !delta.is_empty() {
196 full_text.push_str(delta);
197 sink.emit(LlmStreamEvent {
198 call_id: call_id.clone(),
199 provider: NAME.to_string(),
200 model: model_name.clone(),
201 delta: delta.to_string(),
202 done: false,
203 error: None,
204 })
205 .await;
206 }
207 }
208 }
209 "message_delta" => {
210 if let Some(reason) = value
211 .get("delta")
212 .and_then(|d| d.get("stop_reason"))
213 .and_then(|r| r.as_str())
214 {
215 finish_reason = reason.to_string();
216 }
217 if let Some(u) = value.get("usage") {
218 output_tokens = u
219 .get("output_tokens")
220 .and_then(|v| v.as_u64())
221 .map(|v| v as u32)
222 .unwrap_or(output_tokens);
223 }
224 }
225 "message_stop" => {
226 sink.emit(LlmStreamEvent {
227 call_id: call_id.clone(),
228 provider: NAME.to_string(),
229 model: model_name.clone(),
230 delta: String::new(),
231 done: true,
232 error: None,
233 })
234 .await;
235 let latency_ms = started.elapsed().as_millis() as u64;
236 return Ok(CloudAiResponse {
237 provider: NAME.to_string(),
238 model: if model_name.is_empty() {
239 "unknown".to_string()
240 } else {
241 model_name
242 },
243 text: full_text,
244 finish_reason,
245 input_tokens,
246 output_tokens,
247 latency_ms,
248 });
249 }
250 _ => {}
251 }
252 }
253 }
254 }
255
256 if let Some(err) = errored {
257 sink.emit(LlmStreamEvent::error(&call_id, NAME, err.clone()))
258 .await;
259 return Err(CloudAiError::Http {
260 provider: NAME.into(),
261 reason: err,
262 });
263 }
264 sink.emit(LlmStreamEvent {
267 call_id: call_id.clone(),
268 provider: NAME.to_string(),
269 model: model_name.clone(),
270 delta: String::new(),
271 done: true,
272 error: None,
273 })
274 .await;
275 if full_text.is_empty() {
276 return Err(CloudAiError::Shape {
277 provider: NAME.into(),
278 detail: "stream closed before any text was emitted".into(),
279 });
280 }
281 let latency_ms = started.elapsed().as_millis() as u64;
282 Ok(CloudAiResponse {
283 provider: NAME.to_string(),
284 model: if model_name.is_empty() {
285 "unknown".to_string()
286 } else {
287 model_name
288 },
289 text: full_text,
290 finish_reason,
291 input_tokens,
292 output_tokens,
293 latency_ms,
294 })
295 }
296
297 async fn invoke_tools(
298 &self,
299 req: &CloudAiRequest,
300 tools: &[ToolSpec],
301 dispatcher: &dyn ToolDispatcher,
302 max_iters: usize,
303 ) -> Result<CloudAiResponse, CloudAiError> {
304 let started = Instant::now();
305 let mut messages = vec![json!({ "role": "user", "content": user_content(req) })];
306 let tools_body: Vec<_> = tools
307 .iter()
308 .map(|t| {
309 json!({ "name": t.name, "description": t.description, "input_schema": t.parameters })
310 })
311 .collect();
312 let mut input_tokens: u32 = 0;
313 let mut output_tokens: u32 = 0;
314 let mut model_out = req.model.clone();
315
316 for _ in 0..max_iters.max(1) {
317 let mut body = json!({
318 "model": req.model,
319 "max_tokens": req.max_tokens.unwrap_or(1024),
320 "messages": messages,
321 "tools": tools_body,
322 });
323 if let Some(t) = req.temperature {
324 body["temperature"] = json!(t);
325 }
326 if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
327 body["system"] = json!(s);
328 }
329 apply_thinking(&mut body, req);
330
331 let resp = self
332 .client
333 .post(ENDPOINT)
334 .header("x-api-key", &req.api_key)
335 .header("anthropic-version", ANTHROPIC_VERSION)
336 .header("content-type", "application/json")
337 .json(&body)
338 .send()
339 .await
340 .map_err(|e| CloudAiError::Http {
341 provider: NAME.into(),
342 reason: e.to_string(),
343 })?;
344 let status = resp.status();
345 let raw = resp.text().await.unwrap_or_default();
346 if !status.is_success() {
347 return Err(CloudAiError::Status {
348 provider: NAME.into(),
349 status: status.as_u16(),
350 body: redact_error_body(&raw),
351 });
352 }
353 let v: serde_json::Value =
354 serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
355 provider: NAME.into(),
356 reason: e.to_string(),
357 })?;
358 if let Some(u) = v.get("usage") {
359 input_tokens = input_tokens
360 .saturating_add(u.get("input_tokens").and_then(|x| x.as_u64()).unwrap_or(0) as u32);
361 output_tokens = output_tokens.saturating_add(
362 u.get("output_tokens").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
363 );
364 }
365 if let Some(m) = v.get("model").and_then(|m| m.as_str()) {
366 model_out = m.to_string();
367 }
368 let stop_reason = v
369 .get("stop_reason")
370 .and_then(|r| r.as_str())
371 .unwrap_or("end_turn")
372 .to_string();
373 let content = v
374 .get("content")
375 .and_then(|c| c.as_array())
376 .cloned()
377 .unwrap_or_default();
378
379 let mut text = String::new();
380 let mut tool_uses: Vec<(String, String, serde_json::Value)> = Vec::new();
381 for block in &content {
382 match block.get("type").and_then(|t| t.as_str()) {
383 Some("text") => {
384 if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
385 text.push_str(t);
386 }
387 }
388 Some("tool_use") => {
389 tool_uses.push((
390 block.get("id").and_then(|i| i.as_str()).unwrap_or("").to_string(),
391 block.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string(),
392 block.get("input").cloned().unwrap_or_else(|| json!({})),
393 ));
394 }
395 _ => {}
396 }
397 }
398
399 if tool_uses.is_empty() {
400 if text.is_empty() {
401 return Err(CloudAiError::Shape {
402 provider: NAME.into(),
403 detail: "no text content blocks in response".into(),
404 });
405 }
406 return Ok(CloudAiResponse {
407 provider: NAME.into(),
408 model: model_out,
409 text,
410 finish_reason: stop_reason,
411 input_tokens,
412 output_tokens,
413 latency_ms: started.elapsed().as_millis() as u64,
414 });
415 }
416
417 messages.push(json!({ "role": "assistant", "content": content }));
421 let mut results = Vec::new();
422 for (id, name, input) in &tool_uses {
423 let out = match dispatcher.call(name, input).await {
424 Ok(v) => v.to_string(),
425 Err(e) => json!({ "error": e }).to_string(),
426 };
427 results.push(json!({ "type": "tool_result", "tool_use_id": id, "content": out }));
428 }
429 messages.push(json!({ "role": "user", "content": results }));
430 }
431
432 Err(CloudAiError::Shape {
433 provider: NAME.into(),
434 detail: format!("tool loop did not converge within {max_iters} iterations"),
435 })
436 }
437}
438
439fn user_content(req: &CloudAiRequest) -> serde_json::Value {
442 if req.images.is_empty() {
443 return json!(req.prompt);
444 }
445 let mut blocks = vec![json!({ "type": "text", "text": req.prompt })];
446 for img in &req.images {
447 blocks.push(image_block(img));
448 }
449 json!(blocks)
450}
451
452fn image_block(img: &str) -> serde_json::Value {
455 if let Some(rest) = img.strip_prefix("data:") {
456 if let Some((media_type, data)) = rest.split_once(";base64,") {
457 return json!({
458 "type": "image",
459 "source": { "type": "base64", "media_type": media_type, "data": data },
460 });
461 }
462 }
463 json!({ "type": "image", "source": { "type": "url", "url": img } })
464}
465
466fn apply_thinking(payload: &mut serde_json::Value, req: &CloudAiRequest) {
470 if req.reasoning != Some(true) {
471 return;
472 }
473 let budget = 1024u32;
474 let max = payload["max_tokens"].as_u64().unwrap_or(1024) as u32;
475 if max <= budget {
476 payload["max_tokens"] = json!(budget + 1024);
477 }
478 if let Some(obj) = payload.as_object_mut() {
479 obj.remove("temperature");
480 }
481 payload["thinking"] = json!({ "type": "enabled", "budget_tokens": budget });
482}
483
484pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
485 let mut payload = json!({
486 "model": req.model,
487 "max_tokens": req.max_tokens.unwrap_or(1024),
488 "messages": [{ "role": "user", "content": user_content(req) }],
489 });
490 if let Some(t) = req.temperature {
491 payload["temperature"] = json!(t);
492 }
493 if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
494 payload["system"] = json!(s);
499 }
500 apply_thinking(&mut payload, req);
501 payload
502}
503
504#[derive(Debug, Deserialize)]
505struct RawResponse {
506 model: String,
507 content: Vec<ContentBlock>,
508 stop_reason: Option<String>,
509 usage: Usage,
510}
511
512#[derive(Debug, Deserialize)]
513#[serde(tag = "type", rename_all = "snake_case")]
514enum ContentBlock {
515 Text {
516 text: String,
517 },
518 #[serde(other)]
519 Other,
520}
521
522#[derive(Debug, Deserialize)]
523struct Usage {
524 input_tokens: u32,
525 output_tokens: u32,
526}
527
528pub fn parse_response(body: &str, latency_ms: u64) -> Result<CloudAiResponse, CloudAiError> {
529 let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
530 provider: NAME.into(),
531 reason: e.to_string(),
532 })?;
533
534 let text = raw
535 .content
536 .iter()
537 .filter_map(|b| match b {
538 ContentBlock::Text { text } => Some(text.as_str()),
539 ContentBlock::Other => None,
540 })
541 .collect::<Vec<_>>()
542 .join("");
543
544 if text.is_empty() {
545 return Err(CloudAiError::Shape {
546 provider: NAME.into(),
547 detail: "no text content blocks in response".into(),
548 });
549 }
550
551 Ok(CloudAiResponse {
552 provider: NAME.into(),
553 model: raw.model,
554 text,
555 finish_reason: raw.stop_reason.unwrap_or_else(|| "unknown".into()),
556 input_tokens: raw.usage.input_tokens,
557 output_tokens: raw.usage.output_tokens,
558 latency_ms,
559 })
560}