Skip to main content

flow_adapter_ai/providers/
local.rs

1//! Local OpenAI-compatible provider.
2//!
3//! Targets an on-device inference server (Ollama / LM Studio / llama.cpp)
4//! that exposes the OpenAI `/v1/chat/completions` API. Unlike the cloud
5//! `openai` provider, the endpoint is **per-request** (`req.base_url`,
6//! sourced from the user's `local_ai_base_url` setting) and authentication
7//! is optional - local servers typically ignore the bearer token.
8//!
9//! Gated by `allow_local_ai` at the executor layer (not `allow_cloud_ai`):
10//! a localhost call is not network egress, so it stays usable even when
11//! cloud AI is disabled.
12
13use async_trait::async_trait;
14use serde::Deserialize;
15use std::time::Instant;
16
17use crate::error::{redact_error_body, CloudAiError};
18use crate::providers::openai_compat::{
19    build_body, build_embeddings_body, build_streaming_body, parse_embeddings_response,
20    parse_response, run_tools_loop, stream_response,
21};
22use crate::registry::{CloudAiProvider, ProviderCategory};
23use crate::request::{
24    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
25};
26use crate::stream::LlmStreamSink;
27
/// Provider identifier: returned by `name()` and used as the `provider`
/// field of the errors built in this module.
const NAME: &str = "local";
// Value reported by `env_var()`.
// Optional: only used if the user stores a key for a server that requires one.
const ENV_VAR: &str = "FLOW_LOCAL_AI_KEY";

/// Model ids reported by `default_models()`.
const DEFAULT_MODELS: &[&str] = &["qwen3.6-35b-a3b"];
33
/// Reduce a user-entered server address to its origin/base so a fresh
/// endpoint path can be appended.
///
/// Accepted forms, all of which normalize to `http://127.0.0.1:1234`:
///
/// * a bare base: `http://127.0.0.1:1234` (with or without trailing `/`),
/// * a versioned base: `http://127.0.0.1:1234/v1`,
/// * a full endpoint URL ending in `/v1/chat/completions`,
///   `/v1/embeddings` or `/v1/models` (the endpoints this module builds).
///
/// Surrounding whitespace and trailing slashes are removed. At most one
/// endpoint suffix and then one `/v1` suffix are stripped. Blank input
/// yields an empty string.
pub fn normalize_local_base(input: &str) -> String {
    // Endpoint paths a user is likely to paste verbatim from server docs.
    const ENDPOINT_SUFFIXES: [&str; 3] = ["/v1/chat/completions", "/v1/embeddings", "/v1/models"];

    let trimmed = input.trim().trim_end_matches('/');

    // Drop a full endpoint path first...
    let without_endpoint = ENDPOINT_SUFFIXES
        .iter()
        .find_map(|suffix| trimmed.strip_suffix(suffix))
        .map_or(trimmed, |rest| rest.trim_end_matches('/'));

    // ...then a bare `/v1`, so `http://host/v1` and `http://host` agree.
    let base = without_endpoint
        .strip_suffix("/v1")
        .map_or(without_endpoint, |rest| rest.trim_end_matches('/'));

    base.to_string()
}
49
50/// Resolve the chat-completions endpoint from any accepted base form.
51pub fn local_chat_url(input: &str) -> String {
52    format!("{}/v1/chat/completions", normalize_local_base(input))
53}
54
55/// Resolve the model-listing endpoint from any accepted base form.
56pub fn local_models_url(input: &str) -> String {
57    format!("{}/v1/models", normalize_local_base(input))
58}
59
60/// Resolve the embeddings endpoint from any accepted base form.
61pub fn local_embeddings_url(input: &str) -> String {
62    format!("{}/v1/embeddings", normalize_local_base(input))
63}
64
65pub struct LocalOpenAiProvider {
66    client: reqwest::Client,
67}
68
69impl LocalOpenAiProvider {
70    pub fn new() -> Self {
71        Self {
72            client: reqwest::Client::new(),
73        }
74    }
75
76    /// Query the server's `GET /v1/models` and return the available model
77    /// ids. Used by the Settings "Test connection" button to populate the
78    /// model dropdown with what's actually loaded. `base` accepts any of
79    /// the forms `normalize_local_base` tolerates.
80    pub async fn list_models(&self, base: &str) -> Result<Vec<String>, CloudAiError> {
81        let url = local_models_url(base);
82        let resp = self
83            .client
84            .get(&url)
85            .send()
86            .await
87            .map_err(|e| CloudAiError::Http {
88                provider: NAME.into(),
89                reason: e.to_string(),
90            })?;
91        let status = resp.status();
92        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
93            provider: NAME.into(),
94            reason: e.to_string(),
95        })?;
96        if !status.is_success() {
97            return Err(CloudAiError::Status {
98                provider: NAME.into(),
99                status: status.as_u16(),
100                body: redact_error_body(&raw_body),
101            });
102        }
103        parse_model_list(&raw_body)
104    }
105}
106
107#[derive(Debug, Deserialize)]
108struct ModelList {
109    data: Vec<ModelEntry>,
110}
111
112#[derive(Debug, Deserialize)]
113struct ModelEntry {
114    id: String,
115}
116
117pub fn parse_model_list(body: &str) -> Result<Vec<String>, CloudAiError> {
118    let parsed: ModelList = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
119        provider: NAME.into(),
120        reason: e.to_string(),
121    })?;
122    Ok(parsed.data.into_iter().map(|m| m.id).collect())
123}
124
125impl Default for LocalOpenAiProvider {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131#[async_trait]
132impl CloudAiProvider for LocalOpenAiProvider {
133    fn name(&self) -> &str {
134        NAME
135    }
136
137    fn env_var(&self) -> &str {
138        ENV_VAR
139    }
140
141    fn default_models(&self) -> &[&str] {
142        DEFAULT_MODELS
143    }
144
145    fn category(&self) -> ProviderCategory {
146        ProviderCategory::Local
147    }
148
149    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
150        let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
151            CloudAiError::Shape {
152                provider: NAME.into(),
153                detail: "no base_url configured; set the Local AI base URL in Settings".into(),
154            },
155        )?;
156        let endpoint = local_chat_url(base);
157
158        let body = build_body(req);
159
160        let started = Instant::now();
161        let mut builder = self
162            .client
163            .post(&endpoint)
164            .header("content-type", "application/json");
165        // Local servers usually ignore auth; only attach a bearer token when
166        // the user actually stored a key for this provider.
167        if !req.api_key.is_empty() {
168            builder = builder.bearer_auth(&req.api_key);
169        }
170
171        let resp = builder
172            .json(&body)
173            .send()
174            .await
175            .map_err(|e| CloudAiError::Http {
176                provider: NAME.into(),
177                reason: e.to_string(),
178            })?;
179
180        let status = resp.status();
181        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
182            provider: NAME.into(),
183            reason: e.to_string(),
184        })?;
185        let latency_ms = started.elapsed().as_millis() as u64;
186
187        if !status.is_success() {
188            return Err(CloudAiError::Status {
189                provider: NAME.into(),
190                status: status.as_u16(),
191                body: redact_error_body(&raw_body),
192            });
193        }
194
195        parse_response(NAME, &raw_body, latency_ms)
196    }
197
198    async fn invoke_stream(
199        &self,
200        req: &CloudAiRequest,
201        sink: &dyn LlmStreamSink,
202    ) -> Result<CloudAiResponse, CloudAiError> {
203        let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
204            CloudAiError::Shape {
205                provider: NAME.into(),
206                detail: "no base_url configured; set the Local AI base URL in Settings".into(),
207            },
208        )?;
209        let endpoint = local_chat_url(base);
210        let call_id = req.call_id.clone().unwrap_or_default();
211        let body = build_streaming_body(req);
212        let mut builder = self
213            .client
214            .post(&endpoint)
215            .header("content-type", "application/json")
216            .header("accept", "text/event-stream");
217        if !req.api_key.is_empty() {
218            builder = builder.bearer_auth(&req.api_key);
219        }
220        let builder = builder.json(&body);
221        stream_response(builder, NAME, &call_id, sink).await
222    }
223
224    async fn invoke_tools(
225        &self,
226        req: &CloudAiRequest,
227        tools: &[ToolSpec],
228        dispatcher: &dyn ToolDispatcher,
229        max_iters: usize,
230    ) -> Result<CloudAiResponse, CloudAiError> {
231        let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
232            CloudAiError::Shape {
233                provider: NAME.into(),
234                detail: "no base_url configured; set the Local AI base URL in Settings".into(),
235            },
236        )?;
237        let endpoint = local_chat_url(base);
238        run_tools_loop(
239            &self.client,
240            &endpoint,
241            NAME,
242            &req.api_key,
243            req,
244            tools,
245            dispatcher,
246            max_iters,
247        )
248        .await
249    }
250
251    async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
252        let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
253            CloudAiError::Shape {
254                provider: NAME.into(),
255                detail: "no base_url configured; set the Local AI base URL in Settings".into(),
256            },
257        )?;
258        let endpoint = local_embeddings_url(base);
259        let body = build_embeddings_body(req);
260        let started = Instant::now();
261        let mut builder = self
262            .client
263            .post(&endpoint)
264            .header("content-type", "application/json");
265        if !req.api_key.is_empty() {
266            builder = builder.bearer_auth(&req.api_key);
267        }
268        let resp = builder
269            .json(&body)
270            .send()
271            .await
272            .map_err(|e| CloudAiError::Http {
273                provider: NAME.into(),
274                reason: e.to_string(),
275            })?;
276        let status = resp.status();
277        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
278            provider: NAME.into(),
279            reason: e.to_string(),
280        })?;
281        let latency_ms = started.elapsed().as_millis() as u64;
282        if !status.is_success() {
283            return Err(CloudAiError::Status {
284                provider: NAME.into(),
285                status: status.as_u16(),
286                body: redact_error_body(&raw_body),
287            });
288        }
289        parse_embeddings_response(NAME, &req.model, &raw_body, latency_ms)
290    }
291}