1use futures_util::StreamExt;
10use serde::Deserialize;
11use serde_json::json;
12use std::time::Instant;
13
14use crate::error::{redact_error_body, CloudAiError};
15use crate::request::{
16 CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
17};
18use crate::stream::{LlmStreamEvent, LlmStreamSink};
19
20fn user_content(req: &CloudAiRequest) -> serde_json::Value {
25 if req.images.is_empty() {
26 return json!(req.prompt);
27 }
28 let mut parts = vec![json!({ "type": "text", "text": req.prompt })];
29 for url in &req.images {
30 parts.push(json!({ "type": "image_url", "image_url": { "url": url } }));
31 }
32 json!(parts)
33}
34
35fn build_messages(req: &CloudAiRequest) -> Vec<serde_json::Value> {
38 let mut messages = Vec::with_capacity(2);
39 if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
40 messages.push(json!({ "role": "system", "content": s }));
45 }
46 messages.push(json!({ "role": "user", "content": user_content(req) }));
47 messages
48}
49
50fn tools_json(tools: &[ToolSpec]) -> Option<serde_json::Value> {
53 if tools.is_empty() {
54 return None;
55 }
56 let arr: Vec<_> = tools
57 .iter()
58 .map(|t| {
59 json!({
60 "type": "function",
61 "function": {
62 "name": t.name,
63 "description": t.description,
64 "parameters": t.parameters,
65 },
66 })
67 })
68 .collect();
69 Some(json!(arr))
70}
71
72fn apply_params(payload: &mut serde_json::Value, req: &CloudAiRequest) {
74 if let Some(m) = req.max_tokens {
75 payload["max_tokens"] = json!(m);
76 }
77 if let Some(t) = req.temperature {
78 payload["temperature"] = json!(t);
79 }
80 if let Some(p) = req.top_p {
81 payload["top_p"] = json!(p);
82 }
83 if let Some(k) = req.top_k {
84 payload["top_k"] = json!(k);
89 }
90 if let Some(stops) = req.stop.as_ref().filter(|s| !s.is_empty()) {
91 payload["stop"] = json!(stops);
92 }
93 if let Some(think) = req.reasoning {
94 payload["chat_template_kwargs"] = json!({ "enable_thinking": think });
98 }
99}
100
101pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
103 let mut payload = json!({
104 "model": req.model,
105 "messages": build_messages(req),
106 });
107 apply_params(&mut payload, req);
108 if let Some(tools) = tools_json(&req.tools) {
109 payload["tools"] = tools;
110 }
111 if let Some(schema) = &req.response_schema {
112 payload["response_format"] = json!({
117 "type": "json_schema",
118 "json_schema": { "name": "output", "schema": schema, "strict": false },
119 });
120 }
121 payload
122}
123
/// Removes every `<think>…</think>` span from `text` and trims the result.
///
/// An opening `<think>` with no matching `</think>` discards everything from
/// that tag to the end of the input. Text outside the tags is kept verbatim
/// apart from the final leading/trailing whitespace trim.
pub fn strip_reasoning(text: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut visible = String::with_capacity(text.len());
    let mut remaining = text;
    loop {
        let Some((before, after_open)) = remaining.split_once(OPEN) else {
            // No further reasoning block: the tail is all visible text.
            visible.push_str(remaining);
            break;
        };
        visible.push_str(before);
        match after_open.split_once(CLOSE) {
            Some((_, after_close)) => remaining = after_close,
            // Unterminated block: nothing after the opening tag is visible.
            None => break,
        }
    }
    visible.trim().to_string()
}
146
/// Top-level JSON shape deserialized by `parse_response`.
#[derive(Debug, Deserialize)]
struct RawResponse {
    /// Model identifier reported in the body; required (no serde default).
    model: String,
    /// Returned choices; `parse_response` reads only the first one.
    choices: Vec<Choice>,
    /// Token counters, when the body includes them.
    usage: Option<Usage>,
}
153
/// One entry of the `choices` array in a `RawResponse`.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The message carried by this choice.
    message: ChoiceMessage,
    /// Why generation stopped, if the body says so.
    finish_reason: Option<String>,
}
159
/// The `message` object inside a `Choice`.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    /// Message text; `None` when the field is absent or null.
    #[serde(default)]
    content: Option<String>,
}
165
/// Token counters from the `usage` object; each counter defaults to 0 when
/// missing from the JSON.
#[derive(Debug, Deserialize)]
struct Usage {
    /// Value of `prompt_tokens`.
    #[serde(default)]
    prompt_tokens: u32,
    /// Value of `completion_tokens`.
    #[serde(default)]
    completion_tokens: u32,
}
173
174pub fn parse_response(
177 provider: &str,
178 body: &str,
179 latency_ms: u64,
180) -> Result<CloudAiResponse, CloudAiError> {
181 let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
182 provider: provider.into(),
183 reason: e.to_string(),
184 })?;
185
186 let choice = raw.choices.into_iter().next().ok_or(CloudAiError::Shape {
187 provider: provider.into(),
188 detail: "no choices in response".into(),
189 })?;
190
191 let content = choice.message.content.unwrap_or_default();
192 if content.is_empty() {
193 let detail = match choice.finish_reason.as_deref() {
199 Some("length") => "empty assistant content (finish_reason: length - the model hit \
200 max_tokens, likely while \"thinking\"; raise max tokens or disable the model's \
201 reasoning/thinking mode)"
202 .to_string(),
203 other => format!(
204 "empty assistant content (finish_reason: {})",
205 other.unwrap_or("unknown")
206 ),
207 };
208 return Err(CloudAiError::Shape {
209 provider: provider.into(),
210 detail,
211 });
212 }
213
214 let usage = raw.usage.unwrap_or(Usage {
215 prompt_tokens: 0,
216 completion_tokens: 0,
217 });
218 let text = strip_reasoning(&content);
219
220 Ok(CloudAiResponse {
221 provider: provider.into(),
222 model: raw.model,
223 text,
224 finish_reason: choice.finish_reason.unwrap_or_else(|| "unknown".into()),
225 input_tokens: usage.prompt_tokens,
226 output_tokens: usage.completion_tokens,
227 latency_ms,
228 })
229}
230
231pub fn build_streaming_body(req: &CloudAiRequest) -> serde_json::Value {
236 let mut body = build_body(req);
237 body["stream"] = json!(true);
238 body["stream_options"] = json!({ "include_usage": true });
239 body
240}
241
242pub async fn stream_response(
251 builder: reqwest::RequestBuilder,
252 provider: &str,
253 call_id: &str,
254 sink: &dyn LlmStreamSink,
255) -> Result<CloudAiResponse, CloudAiError> {
256 let started = Instant::now();
257 let resp = builder.send().await.map_err(|e| {
258 let err = CloudAiError::Http {
259 provider: provider.into(),
260 reason: e.to_string(),
261 };
262 err
263 })?;
264
265 let status = resp.status();
266 if !status.is_success() {
267 let raw = resp.text().await.unwrap_or_default();
268 let err = CloudAiError::Status {
269 provider: provider.into(),
270 status: status.as_u16(),
271 body: redact_error_body(&raw),
272 };
273 sink.emit(LlmStreamEvent::error(call_id, provider, err.to_string()))
274 .await;
275 return Err(err);
276 }
277
278 let mut stream = resp.bytes_stream();
279 let mut buffer = String::new();
280 let mut full_text = String::new();
281 let mut model_name = String::new();
282 let mut finish_reason = String::from("unknown");
283 let mut input_tokens: u32 = 0;
284 let mut output_tokens: u32 = 0;
285 let mut errored: Option<String> = None;
286
287 while let Some(chunk) = stream.next().await {
288 let bytes = match chunk {
289 Ok(b) => b,
290 Err(e) => {
291 errored = Some(e.to_string());
292 break;
293 }
294 };
295 buffer.push_str(&String::from_utf8_lossy(&bytes));
296 while let Some(idx) = buffer.find("\n\n") {
299 let frame = buffer[..idx].to_string();
300 buffer.drain(..idx + 2);
301 for raw_line in frame.split('\n') {
302 let line = raw_line.trim_start();
303 let Some(payload) = line.strip_prefix("data:") else {
304 continue;
305 };
306 let payload = payload.trim();
307 if payload.is_empty() {
308 continue;
309 }
310 if payload == "[DONE]" {
311 sink.emit(LlmStreamEvent {
312 call_id: call_id.to_string(),
313 provider: provider.to_string(),
314 model: model_name.clone(),
315 delta: String::new(),
316 done: true,
317 error: None,
318 })
319 .await;
320 let latency_ms = started.elapsed().as_millis() as u64;
321 return Ok(CloudAiResponse {
322 provider: provider.to_string(),
323 model: if model_name.is_empty() {
324 "unknown".to_string()
325 } else {
326 model_name
327 },
328 text: full_text,
329 finish_reason,
330 input_tokens,
331 output_tokens,
332 latency_ms,
333 });
334 }
335 let value: serde_json::Value = match serde_json::from_str(payload) {
336 Ok(v) => v,
337 Err(_) => continue,
338 };
339 if let Some(name) = value.get("model").and_then(|m| m.as_str()) {
340 if model_name.is_empty() {
341 model_name = name.to_string();
342 }
343 }
344 if let Some(usage) = value.get("usage") {
345 input_tokens = usage
346 .get("prompt_tokens")
347 .and_then(|v| v.as_u64())
348 .map(|v| v as u32)
349 .unwrap_or(input_tokens);
350 output_tokens = usage
351 .get("completion_tokens")
352 .and_then(|v| v.as_u64())
353 .map(|v| v as u32)
354 .unwrap_or(output_tokens);
355 }
356 if let Some(choice) = value.get("choices").and_then(|c| c.get(0)) {
357 if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str())
358 {
359 finish_reason = reason.to_string();
360 }
361 if let Some(delta) = choice
362 .get("delta")
363 .and_then(|d| d.get("content"))
364 .and_then(|c| c.as_str())
365 {
366 if !delta.is_empty() {
367 full_text.push_str(delta);
368 sink.emit(LlmStreamEvent {
369 call_id: call_id.to_string(),
370 provider: provider.to_string(),
371 model: model_name.clone(),
372 delta: delta.to_string(),
373 done: false,
374 error: None,
375 })
376 .await;
377 }
378 }
379 }
380 }
381 }
382 }
383
384 if let Some(err) = errored {
385 sink.emit(LlmStreamEvent::error(call_id, provider, err.clone()))
386 .await;
387 return Err(CloudAiError::Http {
388 provider: provider.into(),
389 reason: err,
390 });
391 }
392 sink.emit(LlmStreamEvent {
395 call_id: call_id.to_string(),
396 provider: provider.to_string(),
397 model: model_name.clone(),
398 delta: String::new(),
399 done: true,
400 error: None,
401 })
402 .await;
403 let latency_ms = started.elapsed().as_millis() as u64;
404 if full_text.is_empty() {
405 return Err(CloudAiError::Shape {
406 provider: provider.into(),
407 detail: "stream closed before any content was emitted".into(),
408 });
409 }
410 Ok(CloudAiResponse {
411 provider: provider.to_string(),
412 model: if model_name.is_empty() {
413 "unknown".to_string()
414 } else {
415 model_name
416 },
417 text: strip_reasoning(&full_text),
418 finish_reason,
419 input_tokens,
420 output_tokens,
421 latency_ms,
422 })
423}
424
425pub fn build_embeddings_body(req: &CloudAiRequest) -> serde_json::Value {
430 json!({ "model": req.model, "input": req.prompt })
431}
432
/// Top-level JSON shape deserialized by `parse_embeddings_response`.
#[derive(Debug, Deserialize)]
struct RawEmbeddings {
    /// Model identifier, when the body reports one.
    #[serde(default)]
    model: Option<String>,
    /// Returned embeddings; `parse_embeddings_response` reads only the first.
    data: Vec<EmbeddingDatum>,
}
439
/// One entry of the `data` array in a `RawEmbeddings`.
#[derive(Debug, Deserialize)]
struct EmbeddingDatum {
    /// The embedding vector for this entry.
    embedding: Vec<f32>,
}
444
445pub fn parse_embeddings_response(
447 provider: &str,
448 model_fallback: &str,
449 body: &str,
450 latency_ms: u64,
451) -> Result<EmbeddingResponse, CloudAiError> {
452 let raw: RawEmbeddings = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
453 provider: provider.into(),
454 reason: e.to_string(),
455 })?;
456 let first = raw.data.into_iter().next().ok_or(CloudAiError::Shape {
457 provider: provider.into(),
458 detail: "no embedding data in response".into(),
459 })?;
460 let dims = first.embedding.len();
461 Ok(EmbeddingResponse {
462 provider: provider.into(),
463 model: raw.model.unwrap_or_else(|| model_fallback.to_string()),
464 embedding: first.embedding,
465 dims,
466 latency_ms,
467 })
468}
469
/// One tool call as deserialized from an assistant message; both fields fall
/// back to their defaults when missing.
#[derive(Debug, Deserialize)]
struct RawToolCall {
    /// Call identifier; echoed as `tool_call_id` on the tool result message
    /// by `run_tools_loop`.
    #[serde(default)]
    id: String,
    /// The function name and arguments of the call.
    #[serde(default)]
    function: RawToolFn,
}
479
/// The `function` object of a `RawToolCall`; both fields default to an empty
/// string when missing.
#[derive(Debug, Default, Deserialize)]
struct RawToolFn {
    /// Name of the tool to invoke.
    #[serde(default)]
    name: String,
    /// Arguments as a string; `run_tools_loop` parses it as JSON before
    /// dispatching.
    #[serde(default)]
    arguments: String,
}
488
/// The result of `parse_turn`: the first choice of one response, flattened.
struct ParsedTurn {
    /// Message content; empty when the body had none.
    content: String,
    /// Tool calls requested in this turn; empty when there were none.
    tool_calls: Vec<RawToolCall>,
    /// Finish reason, or `"unknown"` when the body had none.
    finish_reason: String,
    /// Model identifier; empty when the body had none.
    model: String,
    /// `prompt_tokens` from the usage object, or 0 when usage is absent.
    input_tokens: u32,
    /// `completion_tokens` from the usage object, or 0 when usage is absent.
    output_tokens: u32,
}
497
498fn parse_turn(provider: &str, body: &str) -> Result<ParsedTurn, CloudAiError> {
501 #[derive(Deserialize)]
502 struct R {
503 #[serde(default)]
504 model: String,
505 choices: Vec<C>,
506 usage: Option<Usage>,
507 }
508 #[derive(Deserialize)]
509 struct C {
510 message: M,
511 #[serde(default)]
512 finish_reason: Option<String>,
513 }
514 #[derive(Deserialize)]
515 struct M {
516 #[serde(default)]
517 content: Option<String>,
518 #[serde(default)]
519 tool_calls: Vec<RawToolCall>,
520 }
521 let r: R = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
522 provider: provider.into(),
523 reason: e.to_string(),
524 })?;
525 let c = r.choices.into_iter().next().ok_or(CloudAiError::Shape {
526 provider: provider.into(),
527 detail: "no choices in response".into(),
528 })?;
529 let usage = r.usage.unwrap_or(Usage {
530 prompt_tokens: 0,
531 completion_tokens: 0,
532 });
533 Ok(ParsedTurn {
534 content: c.message.content.unwrap_or_default(),
535 tool_calls: c.message.tool_calls,
536 finish_reason: c.finish_reason.unwrap_or_else(|| "unknown".into()),
537 model: r.model,
538 input_tokens: usage.prompt_tokens,
539 output_tokens: usage.completion_tokens,
540 })
541}
542
543#[allow(clippy::too_many_arguments)]
550pub async fn run_tools_loop(
551 client: &reqwest::Client,
552 url: &str,
553 provider: &str,
554 api_key: &str,
555 req: &CloudAiRequest,
556 tools: &[ToolSpec],
557 dispatcher: &dyn ToolDispatcher,
558 max_iters: usize,
559) -> Result<CloudAiResponse, CloudAiError> {
560 let started = Instant::now();
561 let mut messages = build_messages(req);
562 let tools_body = tools_json(tools).unwrap_or_else(|| json!([]));
563 let mut input_tokens: u32 = 0;
564 let mut output_tokens: u32 = 0;
565
566 for _ in 0..max_iters.max(1) {
567 let mut body = json!({
568 "model": req.model,
569 "messages": messages,
570 "tools": tools_body,
571 });
572 apply_params(&mut body, req);
573
574 let mut builder = client.post(url).json(&body);
575 if !api_key.is_empty() {
576 builder = builder.bearer_auth(api_key);
577 }
578 let resp = builder.send().await.map_err(|e| CloudAiError::Http {
579 provider: provider.into(),
580 reason: e.to_string(),
581 })?;
582 let status = resp.status();
583 let raw = resp.text().await.unwrap_or_default();
584 if !status.is_success() {
585 return Err(CloudAiError::Status {
586 provider: provider.into(),
587 status: status.as_u16(),
588 body: redact_error_body(&raw),
589 });
590 }
591
592 let turn = parse_turn(provider, &raw)?;
593 input_tokens = input_tokens.saturating_add(turn.input_tokens);
594 output_tokens = output_tokens.saturating_add(turn.output_tokens);
595
596 if turn.tool_calls.is_empty() {
597 return Ok(CloudAiResponse {
598 provider: provider.into(),
599 model: if turn.model.is_empty() {
600 req.model.clone()
601 } else {
602 turn.model
603 },
604 text: strip_reasoning(&turn.content),
605 finish_reason: turn.finish_reason,
606 input_tokens,
607 output_tokens,
608 latency_ms: started.elapsed().as_millis() as u64,
609 });
610 }
611
612 messages.push(json!({
614 "role": "assistant",
615 "content": turn.content,
616 "tool_calls": turn.tool_calls.iter().map(|tc| json!({
617 "id": tc.id,
618 "type": "function",
619 "function": { "name": tc.function.name, "arguments": tc.function.arguments },
620 })).collect::<Vec<_>>(),
621 }));
622
623 for tc in &turn.tool_calls {
626 let args: serde_json::Value =
627 serde_json::from_str(&tc.function.arguments).unwrap_or_else(|_| json!({}));
628 let content = match dispatcher.call(&tc.function.name, &args).await {
629 Ok(v) => v.to_string(),
630 Err(e) => json!({ "error": e }).to_string(),
631 };
632 messages.push(json!({
633 "role": "tool",
634 "tool_call_id": tc.id,
635 "content": content,
636 }));
637 }
638 }
639
640 Err(CloudAiError::Shape {
641 provider: provider.into(),
642 detail: format!("tool loop did not converge within {max_iters} iterations"),
643 })
644}