Skip to main content

flow_adapter_ai/providers/
openai_compat.rs

//! Shared OpenAI Chat Completions wire format.
//!
//! Both the cloud `openai` provider (api.openai.com) and the `local`
//! provider (Ollama / LM Studio / llama.cpp on localhost) speak the exact
//! same `/v1/chat/completions` request and response shape. This module owns
//! the body builder, the non-streaming response parser, the streaming SSE
//! reader, the `/v1/embeddings` helpers, and the tool-use loop so the two
//! providers can't drift.
8
9use futures_util::StreamExt;
10use serde::Deserialize;
11use serde_json::json;
12use std::time::Instant;
13
14use crate::error::{redact_error_body, CloudAiError};
15use crate::request::{
16    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
17};
18use crate::stream::{LlmStreamEvent, LlmStreamSink};
19
20/// The user message `content`: a plain string when there are no images, or an
21/// OpenAI multimodal array (a `text` part plus one `image_url` part per image)
22/// when the request carries images. The executor pre-resolves local image
23/// paths to `data:` URLs, so each entry is a ready-to-send URL.
24fn user_content(req: &CloudAiRequest) -> serde_json::Value {
25    if req.images.is_empty() {
26        return json!(req.prompt);
27    }
28    let mut parts = vec![json!({ "type": "text", "text": req.prompt })];
29    for url in &req.images {
30        parts.push(json!({ "type": "image_url", "image_url": { "url": url } }));
31    }
32    json!(parts)
33}
34
35/// Initial `messages` array (`[system?, user]`). Shared by the single-turn body
36/// and the tool loop (which appends assistant/tool turns to it).
37fn build_messages(req: &CloudAiRequest) -> Vec<serde_json::Value> {
38    let mut messages = Vec::with_capacity(2);
39    if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
40        // Chat Completions accepts a `system`-role message as the first
41        // entry - strictly preferred over folding the system text into the
42        // user message because the model is trained to weight system
43        // instructions differently.
44        messages.push(json!({ "role": "system", "content": s }));
45    }
46    messages.push(json!({ "role": "user", "content": user_content(req) }));
47    messages
48}
49
50/// OpenAI `tools` array (`type: "function"`) from the request's `ToolSpec`s,
51/// or `None` when no tools are bound.
52fn tools_json(tools: &[ToolSpec]) -> Option<serde_json::Value> {
53    if tools.is_empty() {
54        return None;
55    }
56    let arr: Vec<_> = tools
57        .iter()
58        .map(|t| {
59            json!({
60                "type": "function",
61                "function": {
62                    "name": t.name,
63                    "description": t.description,
64                    "parameters": t.parameters,
65                },
66            })
67        })
68        .collect();
69    Some(json!(arr))
70}
71
72/// Apply the sampling + reasoning params shared by every chat body.
73fn apply_params(payload: &mut serde_json::Value, req: &CloudAiRequest) {
74    if let Some(m) = req.max_tokens {
75        payload["max_tokens"] = json!(m);
76    }
77    if let Some(t) = req.temperature {
78        payload["temperature"] = json!(t);
79    }
80    if let Some(p) = req.top_p {
81        payload["top_p"] = json!(p);
82    }
83    if let Some(k) = req.top_k {
84        // `top_k` isn't part of upstream OpenAI Chat Completions, but
85        // llama.cpp / Ollama / LM Studio accept it. Cloud OpenAI safely
86        // ignores unknown fields, so emitting it unconditionally for the
87        // shared body is harmless.
88        payload["top_k"] = json!(k);
89    }
90    if let Some(stops) = req.stop.as_ref().filter(|s| !s.is_empty()) {
91        payload["stop"] = json!(stops);
92    }
93    if let Some(think) = req.reasoning {
94        // llama.cpp / vLLM honor `chat_template_kwargs.enable_thinking` to
95        // toggle a model's reasoning mode per request; cloud OpenAI ignores
96        // unknown fields, so emitting it for the shared body is harmless.
97        payload["chat_template_kwargs"] = json!({ "enable_thinking": think });
98    }
99}
100
101/// Build a Chat Completions request body from a `CloudAiRequest`.
102pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
103    let mut payload = json!({
104        "model": req.model,
105        "messages": build_messages(req),
106    });
107    apply_params(&mut payload, req);
108    if let Some(tools) = tools_json(&req.tools) {
109        payload["tools"] = tools;
110    }
111    if let Some(schema) = &req.response_schema {
112        // `strict: false` keeps the server lenient about the exact JSON-Schema
113        // subset (a hand-written schema needn't meet OpenAI's strict rules);
114        // servers that don't support `response_format` ignore the unknown field
115        // and fall back to the prompt constraint.
116        payload["response_format"] = json!({
117            "type": "json_schema",
118            "json_schema": { "name": "output", "schema": schema, "strict": false },
119        });
120    }
121    payload
122}
123
/// Removes inline chain-of-thought so only the user-facing answer remains.
///
/// Every `<think>…</think>` span is dropped (some local models inline their
/// reasoning there). An opening `<think>` with no matching close discards
/// everything after it, so a response cut off mid-thought leaks no partial
/// reasoning. Surrounding whitespace is trimmed from the result. Separate
/// `reasoning_content` fields never reach this function, so they need no
/// handling here.
pub fn strip_reasoning(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut visible = String::with_capacity(text.len());
    let mut remaining = text;
    loop {
        // No further opening tag: whatever is left is plain answer text.
        let Some((before, after_open)) = remaining.split_once(OPEN_TAG) else {
            visible.push_str(remaining);
            break;
        };
        visible.push_str(before);
        match after_open.split_once(CLOSE_TAG) {
            // Skip the reasoning span and keep scanning after its close tag.
            Some((_reasoning, tail)) => remaining = tail,
            // Unclosed `<think>`: drop the remainder entirely.
            None => break,
        }
    }
    visible.trim().to_string()
}
146
147#[derive(Debug, Deserialize)]
148struct RawResponse {
149    model: String,
150    choices: Vec<Choice>,
151    usage: Option<Usage>,
152}
153
154#[derive(Debug, Deserialize)]
155struct Choice {
156    message: ChoiceMessage,
157    finish_reason: Option<String>,
158}
159
160#[derive(Debug, Deserialize)]
161struct ChoiceMessage {
162    #[serde(default)]
163    content: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167struct Usage {
168    #[serde(default)]
169    prompt_tokens: u32,
170    #[serde(default)]
171    completion_tokens: u32,
172}
173
174/// Parse a Chat Completions response body. `provider` is the caller's name,
175/// stamped into errors and the returned `CloudAiResponse.provider`.
176pub fn parse_response(
177    provider: &str,
178    body: &str,
179    latency_ms: u64,
180) -> Result<CloudAiResponse, CloudAiError> {
181    let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
182        provider: provider.into(),
183        reason: e.to_string(),
184    })?;
185
186    let choice = raw.choices.into_iter().next().ok_or(CloudAiError::Shape {
187        provider: provider.into(),
188        detail: "no choices in response".into(),
189    })?;
190
191    let content = choice.message.content.unwrap_or_default();
192    if content.is_empty() {
193        // A reasoning ("thinking") model can spend its whole token budget
194        // in `reasoning_content` and get cut off before emitting any
195        // `content` - the server reports `finish_reason: "length"`. Surface
196        // that so the user knows to raise max tokens or disable thinking,
197        // rather than seeing an opaque "empty content".
198        let detail = match choice.finish_reason.as_deref() {
199            Some("length") => "empty assistant content (finish_reason: length - the model hit \
200                 max_tokens, likely while \"thinking\"; raise max tokens or disable the model's \
201                 reasoning/thinking mode)"
202                .to_string(),
203            other => format!(
204                "empty assistant content (finish_reason: {})",
205                other.unwrap_or("unknown")
206            ),
207        };
208        return Err(CloudAiError::Shape {
209            provider: provider.into(),
210            detail,
211        });
212    }
213
214    let usage = raw.usage.unwrap_or(Usage {
215        prompt_tokens: 0,
216        completion_tokens: 0,
217    });
218    let text = strip_reasoning(&content);
219
220    Ok(CloudAiResponse {
221        provider: provider.into(),
222        model: raw.model,
223        text,
224        finish_reason: choice.finish_reason.unwrap_or_else(|| "unknown".into()),
225        input_tokens: usage.prompt_tokens,
226        output_tokens: usage.completion_tokens,
227        latency_ms,
228    })
229}
230
231/// Build the request body for a streaming Chat Completions call. Same shape
232/// as `build_body` plus `stream: true` and a `stream_options` block that
233/// asks the server to send a final `usage` chunk so the caller can report
234/// token counts even on streaming responses.
235pub fn build_streaming_body(req: &CloudAiRequest) -> serde_json::Value {
236    let mut body = build_body(req);
237    body["stream"] = json!(true);
238    body["stream_options"] = json!({ "include_usage": true });
239    body
240}
241
242/// Stream an OpenAI-compatible Chat Completions response, emitting each
243/// content delta through `sink`. Returns the accumulated full response
244/// once the stream closes so the caller still gets a `CloudAiResponse`
245/// to surface in the node output.
246///
247/// Handles the canonical SSE shape (one `data: <json>\n` per chunk plus
248/// a terminating `data: [DONE]`). Servers that don't include a `usage`
249/// chunk leave token counts at zero, mirroring the non-streaming path.
250pub async fn stream_response(
251    builder: reqwest::RequestBuilder,
252    provider: &str,
253    call_id: &str,
254    sink: &dyn LlmStreamSink,
255) -> Result<CloudAiResponse, CloudAiError> {
256    let started = Instant::now();
257    let resp = builder.send().await.map_err(|e| {
258        let err = CloudAiError::Http {
259            provider: provider.into(),
260            reason: e.to_string(),
261        };
262        err
263    })?;
264
265    let status = resp.status();
266    if !status.is_success() {
267        let raw = resp.text().await.unwrap_or_default();
268        let err = CloudAiError::Status {
269            provider: provider.into(),
270            status: status.as_u16(),
271            body: redact_error_body(&raw),
272        };
273        sink.emit(LlmStreamEvent::error(call_id, provider, err.to_string()))
274            .await;
275        return Err(err);
276    }
277
278    let mut stream = resp.bytes_stream();
279    let mut buffer = String::new();
280    let mut full_text = String::new();
281    let mut model_name = String::new();
282    let mut finish_reason = String::from("unknown");
283    let mut input_tokens: u32 = 0;
284    let mut output_tokens: u32 = 0;
285    let mut errored: Option<String> = None;
286
287    while let Some(chunk) = stream.next().await {
288        let bytes = match chunk {
289            Ok(b) => b,
290            Err(e) => {
291                errored = Some(e.to_string());
292                break;
293            }
294        };
295        buffer.push_str(&String::from_utf8_lossy(&bytes));
296        // SSE frames are separated by `\n\n`; lines within a frame start
297        // with `data: ` (other prefixes like `event:` / `id:` are ignored).
298        while let Some(idx) = buffer.find("\n\n") {
299            let frame = buffer[..idx].to_string();
300            buffer.drain(..idx + 2);
301            for raw_line in frame.split('\n') {
302                let line = raw_line.trim_start();
303                let Some(payload) = line.strip_prefix("data:") else {
304                    continue;
305                };
306                let payload = payload.trim();
307                if payload.is_empty() {
308                    continue;
309                }
310                if payload == "[DONE]" {
311                    sink.emit(LlmStreamEvent {
312                        call_id: call_id.to_string(),
313                        provider: provider.to_string(),
314                        model: model_name.clone(),
315                        delta: String::new(),
316                        done: true,
317                        error: None,
318                    })
319                    .await;
320                    let latency_ms = started.elapsed().as_millis() as u64;
321                    return Ok(CloudAiResponse {
322                        provider: provider.to_string(),
323                        model: if model_name.is_empty() {
324                            "unknown".to_string()
325                        } else {
326                            model_name
327                        },
328                        text: full_text,
329                        finish_reason,
330                        input_tokens,
331                        output_tokens,
332                        latency_ms,
333                    });
334                }
335                let value: serde_json::Value = match serde_json::from_str(payload) {
336                    Ok(v) => v,
337                    Err(_) => continue,
338                };
339                if let Some(name) = value.get("model").and_then(|m| m.as_str()) {
340                    if model_name.is_empty() {
341                        model_name = name.to_string();
342                    }
343                }
344                if let Some(usage) = value.get("usage") {
345                    input_tokens = usage
346                        .get("prompt_tokens")
347                        .and_then(|v| v.as_u64())
348                        .map(|v| v as u32)
349                        .unwrap_or(input_tokens);
350                    output_tokens = usage
351                        .get("completion_tokens")
352                        .and_then(|v| v.as_u64())
353                        .map(|v| v as u32)
354                        .unwrap_or(output_tokens);
355                }
356                if let Some(choice) = value.get("choices").and_then(|c| c.get(0)) {
357                    if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str())
358                    {
359                        finish_reason = reason.to_string();
360                    }
361                    if let Some(delta) = choice
362                        .get("delta")
363                        .and_then(|d| d.get("content"))
364                        .and_then(|c| c.as_str())
365                    {
366                        if !delta.is_empty() {
367                            full_text.push_str(delta);
368                            sink.emit(LlmStreamEvent {
369                                call_id: call_id.to_string(),
370                                provider: provider.to_string(),
371                                model: model_name.clone(),
372                                delta: delta.to_string(),
373                                done: false,
374                                error: None,
375                            })
376                            .await;
377                        }
378                    }
379                }
380            }
381        }
382    }
383
384    if let Some(err) = errored {
385        sink.emit(LlmStreamEvent::error(call_id, provider, err.clone()))
386            .await;
387        return Err(CloudAiError::Http {
388            provider: provider.into(),
389            reason: err,
390        });
391    }
392    // Stream closed without `[DONE]` - some servers omit it. Treat the
393    // accumulated text as the final response and emit a terminal event.
394    sink.emit(LlmStreamEvent {
395        call_id: call_id.to_string(),
396        provider: provider.to_string(),
397        model: model_name.clone(),
398        delta: String::new(),
399        done: true,
400        error: None,
401    })
402    .await;
403    let latency_ms = started.elapsed().as_millis() as u64;
404    if full_text.is_empty() {
405        return Err(CloudAiError::Shape {
406            provider: provider.into(),
407            detail: "stream closed before any content was emitted".into(),
408        });
409    }
410    Ok(CloudAiResponse {
411        provider: provider.to_string(),
412        model: if model_name.is_empty() {
413            "unknown".to_string()
414        } else {
415            model_name
416        },
417        text: strip_reasoning(&full_text),
418        finish_reason,
419        input_tokens,
420        output_tokens,
421        latency_ms,
422    })
423}
424
425// ---------- embeddings (`/v1/embeddings`) ----------
426
427/// Build an OpenAI-compatible `/v1/embeddings` body. The text to embed is the
428/// request `prompt`.
429pub fn build_embeddings_body(req: &CloudAiRequest) -> serde_json::Value {
430    json!({ "model": req.model, "input": req.prompt })
431}
432
433#[derive(Debug, Deserialize)]
434struct RawEmbeddings {
435    #[serde(default)]
436    model: Option<String>,
437    data: Vec<EmbeddingDatum>,
438}
439
440#[derive(Debug, Deserialize)]
441struct EmbeddingDatum {
442    embedding: Vec<f32>,
443}
444
445/// Parse an OpenAI-compatible embeddings response into the first vector.
446pub fn parse_embeddings_response(
447    provider: &str,
448    model_fallback: &str,
449    body: &str,
450    latency_ms: u64,
451) -> Result<EmbeddingResponse, CloudAiError> {
452    let raw: RawEmbeddings = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
453        provider: provider.into(),
454        reason: e.to_string(),
455    })?;
456    let first = raw.data.into_iter().next().ok_or(CloudAiError::Shape {
457        provider: provider.into(),
458        detail: "no embedding data in response".into(),
459    })?;
460    let dims = first.embedding.len();
461    Ok(EmbeddingResponse {
462        provider: provider.into(),
463        model: raw.model.unwrap_or_else(|| model_fallback.to_string()),
464        embedding: first.embedding,
465        dims,
466        latency_ms,
467    })
468}
469
470// ---------- tool-use loop ----------
471
472#[derive(Debug, Deserialize)]
473struct RawToolCall {
474    #[serde(default)]
475    id: String,
476    #[serde(default)]
477    function: RawToolFn,
478}
479
480#[derive(Debug, Default, Deserialize)]
481struct RawToolFn {
482    #[serde(default)]
483    name: String,
484    /// OpenAI passes arguments as a JSON-encoded string.
485    #[serde(default)]
486    arguments: String,
487}
488
489struct ParsedTurn {
490    content: String,
491    tool_calls: Vec<RawToolCall>,
492    finish_reason: String,
493    model: String,
494    input_tokens: u32,
495    output_tokens: u32,
496}
497
498/// Parse one assistant turn (content + any tool calls + usage) from a
499/// non-streaming Chat Completions body.
500fn parse_turn(provider: &str, body: &str) -> Result<ParsedTurn, CloudAiError> {
501    #[derive(Deserialize)]
502    struct R {
503        #[serde(default)]
504        model: String,
505        choices: Vec<C>,
506        usage: Option<Usage>,
507    }
508    #[derive(Deserialize)]
509    struct C {
510        message: M,
511        #[serde(default)]
512        finish_reason: Option<String>,
513    }
514    #[derive(Deserialize)]
515    struct M {
516        #[serde(default)]
517        content: Option<String>,
518        #[serde(default)]
519        tool_calls: Vec<RawToolCall>,
520    }
521    let r: R = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
522        provider: provider.into(),
523        reason: e.to_string(),
524    })?;
525    let c = r.choices.into_iter().next().ok_or(CloudAiError::Shape {
526        provider: provider.into(),
527        detail: "no choices in response".into(),
528    })?;
529    let usage = r.usage.unwrap_or(Usage {
530        prompt_tokens: 0,
531        completion_tokens: 0,
532    });
533    Ok(ParsedTurn {
534        content: c.message.content.unwrap_or_default(),
535        tool_calls: c.message.tool_calls,
536        finish_reason: c.finish_reason.unwrap_or_else(|| "unknown".into()),
537        model: r.model,
538        input_tokens: usage.prompt_tokens,
539        output_tokens: usage.completion_tokens,
540    })
541}
542
543/// Run the OpenAI-compatible tool-use loop: expose `tools`, and on each turn,
544/// if the model requests tool calls, run them through `dispatcher`, append the
545/// results, and re-call - until the model answers without tool calls (returned
546/// as the `CloudAiResponse`) or `max_iters` is exhausted. `api_key` empty means
547/// no bearer header (local servers). Non-streaming: tool-call deltas don't
548/// reassemble cleanly over SSE.
549#[allow(clippy::too_many_arguments)]
550pub async fn run_tools_loop(
551    client: &reqwest::Client,
552    url: &str,
553    provider: &str,
554    api_key: &str,
555    req: &CloudAiRequest,
556    tools: &[ToolSpec],
557    dispatcher: &dyn ToolDispatcher,
558    max_iters: usize,
559) -> Result<CloudAiResponse, CloudAiError> {
560    let started = Instant::now();
561    let mut messages = build_messages(req);
562    let tools_body = tools_json(tools).unwrap_or_else(|| json!([]));
563    let mut input_tokens: u32 = 0;
564    let mut output_tokens: u32 = 0;
565
566    for _ in 0..max_iters.max(1) {
567        let mut body = json!({
568            "model": req.model,
569            "messages": messages,
570            "tools": tools_body,
571        });
572        apply_params(&mut body, req);
573
574        let mut builder = client.post(url).json(&body);
575        if !api_key.is_empty() {
576            builder = builder.bearer_auth(api_key);
577        }
578        let resp = builder.send().await.map_err(|e| CloudAiError::Http {
579            provider: provider.into(),
580            reason: e.to_string(),
581        })?;
582        let status = resp.status();
583        let raw = resp.text().await.unwrap_or_default();
584        if !status.is_success() {
585            return Err(CloudAiError::Status {
586                provider: provider.into(),
587                status: status.as_u16(),
588                body: redact_error_body(&raw),
589            });
590        }
591
592        let turn = parse_turn(provider, &raw)?;
593        input_tokens = input_tokens.saturating_add(turn.input_tokens);
594        output_tokens = output_tokens.saturating_add(turn.output_tokens);
595
596        if turn.tool_calls.is_empty() {
597            return Ok(CloudAiResponse {
598                provider: provider.into(),
599                model: if turn.model.is_empty() {
600                    req.model.clone()
601                } else {
602                    turn.model
603                },
604                text: strip_reasoning(&turn.content),
605                finish_reason: turn.finish_reason,
606                input_tokens,
607                output_tokens,
608                latency_ms: started.elapsed().as_millis() as u64,
609            });
610        }
611
612        // Echo the assistant's tool-call turn back into the history…
613        messages.push(json!({
614            "role": "assistant",
615            "content": turn.content,
616            "tool_calls": turn.tool_calls.iter().map(|tc| json!({
617                "id": tc.id,
618                "type": "function",
619                "function": { "name": tc.function.name, "arguments": tc.function.arguments },
620            })).collect::<Vec<_>>(),
621        }));
622
623        // …then run each tool and append its result (errors fed back as the
624        // result so the model can recover rather than the run hard-failing).
625        for tc in &turn.tool_calls {
626            let args: serde_json::Value =
627                serde_json::from_str(&tc.function.arguments).unwrap_or_else(|_| json!({}));
628            let content = match dispatcher.call(&tc.function.name, &args).await {
629                Ok(v) => v.to_string(),
630                Err(e) => json!({ "error": e }).to_string(),
631            };
632            messages.push(json!({
633                "role": "tool",
634                "tool_call_id": tc.id,
635                "content": content,
636            }));
637        }
638    }
639
640    Err(CloudAiError::Shape {
641        provider: provider.into(),
642        detail: format!("tool loop did not converge within {max_iters} iterations"),
643    })
644}