flow_adapter_ai/providers/
local.rs1use async_trait::async_trait;
14use serde::Deserialize;
15use std::time::Instant;
16
17use crate::error::{redact_error_body, CloudAiError};
18use crate::providers::openai_compat::{
19 build_body, build_embeddings_body, build_streaming_body, parse_embeddings_response,
20 parse_response, run_tools_loop, stream_response,
21};
22use crate::registry::{CloudAiProvider, ProviderCategory};
23use crate::request::{
24 CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
25};
26use crate::stream::LlmStreamSink;
27
/// Registry name of this provider; every `CloudAiError` built in this file carries it.
const NAME: &str = "local";

/// Model ids reported by `default_models()`.
const DEFAULT_MODELS: &[&str] = &["qwen3.6-35b-a3b"];

/// Environment-variable name reported by `env_var()`; presumably where the
/// registry looks up an optional API key — confirm against the registry.
const ENV_VAR: &str = "FLOW_LOCAL_AI_KEY";
33
/// Reduces a user-supplied Local AI URL to the bare server root.
///
/// Surrounding whitespace and trailing slashes are removed. Then one pasted
/// OpenAI-compatible endpoint path (`/v1/chat/completions`, `/v1/embeddings`
/// or `/v1/models`) is stripped, and finally a bare `/v1` suffix. This lets
/// the `local_*_url` helpers append their own `/v1/...` path exactly once.
///
/// ```text
/// "http://127.0.0.1:8080/v1/chat/completions" -> "http://127.0.0.1:8080"
/// "http://127.0.0.1:8080/v1/models"           -> "http://127.0.0.1:8080"
/// " http://127.0.0.1:8080/v1/ "               -> "http://127.0.0.1:8080"
/// ```
pub fn normalize_local_base(input: &str) -> String {
    // Endpoint URLs users commonly paste instead of the server root. Only one
    // is stripped; `/v1` on its own is handled afterwards so a `/v1` base works too.
    const ENDPOINT_SUFFIXES: [&str; 3] = ["/v1/chat/completions", "/v1/embeddings", "/v1/models"];

    let base = input.trim().trim_end_matches('/');
    let base = ENDPOINT_SUFFIXES
        .iter()
        .find_map(|suffix| base.strip_suffix(*suffix))
        .map_or(base, |rest| rest.trim_end_matches('/'));
    let base = base
        .strip_suffix("/v1")
        .map_or(base, |rest| rest.trim_end_matches('/'));
    base.to_string()
}

/// Joins the normalized server root of `input` with `/v1/<path>`.
fn local_v1_url(input: &str, path: &str) -> String {
    format!("{}/v1/{}", normalize_local_base(input), path)
}

/// Chat-completions endpoint (`<root>/v1/chat/completions`) for a Local AI URL.
pub fn local_chat_url(input: &str) -> String {
    local_v1_url(input, "chat/completions")
}

/// Model-listing endpoint (`<root>/v1/models`) for a Local AI URL.
pub fn local_models_url(input: &str) -> String {
    local_v1_url(input, "models")
}

/// Embeddings endpoint (`<root>/v1/embeddings`) for a Local AI URL.
pub fn local_embeddings_url(input: &str) -> String {
    local_v1_url(input, "embeddings")
}
64
65pub struct LocalOpenAiProvider {
66 client: reqwest::Client,
67}
68
69impl LocalOpenAiProvider {
70 pub fn new() -> Self {
71 Self {
72 client: reqwest::Client::new(),
73 }
74 }
75
76 pub async fn list_models(&self, base: &str) -> Result<Vec<String>, CloudAiError> {
81 let url = local_models_url(base);
82 let resp = self
83 .client
84 .get(&url)
85 .send()
86 .await
87 .map_err(|e| CloudAiError::Http {
88 provider: NAME.into(),
89 reason: e.to_string(),
90 })?;
91 let status = resp.status();
92 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
93 provider: NAME.into(),
94 reason: e.to_string(),
95 })?;
96 if !status.is_success() {
97 return Err(CloudAiError::Status {
98 provider: NAME.into(),
99 status: status.as_u16(),
100 body: redact_error_body(&raw_body),
101 });
102 }
103 parse_model_list(&raw_body)
104 }
105}
106
107#[derive(Debug, Deserialize)]
108struct ModelList {
109 data: Vec<ModelEntry>,
110}
111
112#[derive(Debug, Deserialize)]
113struct ModelEntry {
114 id: String,
115}
116
117pub fn parse_model_list(body: &str) -> Result<Vec<String>, CloudAiError> {
118 let parsed: ModelList = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
119 provider: NAME.into(),
120 reason: e.to_string(),
121 })?;
122 Ok(parsed.data.into_iter().map(|m| m.id).collect())
123}
124
125impl Default for LocalOpenAiProvider {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131#[async_trait]
132impl CloudAiProvider for LocalOpenAiProvider {
133 fn name(&self) -> &str {
134 NAME
135 }
136
137 fn env_var(&self) -> &str {
138 ENV_VAR
139 }
140
141 fn default_models(&self) -> &[&str] {
142 DEFAULT_MODELS
143 }
144
145 fn category(&self) -> ProviderCategory {
146 ProviderCategory::Local
147 }
148
149 async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
150 let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
151 CloudAiError::Shape {
152 provider: NAME.into(),
153 detail: "no base_url configured; set the Local AI base URL in Settings".into(),
154 },
155 )?;
156 let endpoint = local_chat_url(base);
157
158 let body = build_body(req);
159
160 let started = Instant::now();
161 let mut builder = self
162 .client
163 .post(&endpoint)
164 .header("content-type", "application/json");
165 if !req.api_key.is_empty() {
168 builder = builder.bearer_auth(&req.api_key);
169 }
170
171 let resp = builder
172 .json(&body)
173 .send()
174 .await
175 .map_err(|e| CloudAiError::Http {
176 provider: NAME.into(),
177 reason: e.to_string(),
178 })?;
179
180 let status = resp.status();
181 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
182 provider: NAME.into(),
183 reason: e.to_string(),
184 })?;
185 let latency_ms = started.elapsed().as_millis() as u64;
186
187 if !status.is_success() {
188 return Err(CloudAiError::Status {
189 provider: NAME.into(),
190 status: status.as_u16(),
191 body: redact_error_body(&raw_body),
192 });
193 }
194
195 parse_response(NAME, &raw_body, latency_ms)
196 }
197
198 async fn invoke_stream(
199 &self,
200 req: &CloudAiRequest,
201 sink: &dyn LlmStreamSink,
202 ) -> Result<CloudAiResponse, CloudAiError> {
203 let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
204 CloudAiError::Shape {
205 provider: NAME.into(),
206 detail: "no base_url configured; set the Local AI base URL in Settings".into(),
207 },
208 )?;
209 let endpoint = local_chat_url(base);
210 let call_id = req.call_id.clone().unwrap_or_default();
211 let body = build_streaming_body(req);
212 let mut builder = self
213 .client
214 .post(&endpoint)
215 .header("content-type", "application/json")
216 .header("accept", "text/event-stream");
217 if !req.api_key.is_empty() {
218 builder = builder.bearer_auth(&req.api_key);
219 }
220 let builder = builder.json(&body);
221 stream_response(builder, NAME, &call_id, sink).await
222 }
223
224 async fn invoke_tools(
225 &self,
226 req: &CloudAiRequest,
227 tools: &[ToolSpec],
228 dispatcher: &dyn ToolDispatcher,
229 max_iters: usize,
230 ) -> Result<CloudAiResponse, CloudAiError> {
231 let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
232 CloudAiError::Shape {
233 provider: NAME.into(),
234 detail: "no base_url configured; set the Local AI base URL in Settings".into(),
235 },
236 )?;
237 let endpoint = local_chat_url(base);
238 run_tools_loop(
239 &self.client,
240 &endpoint,
241 NAME,
242 &req.api_key,
243 req,
244 tools,
245 dispatcher,
246 max_iters,
247 )
248 .await
249 }
250
251 async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
252 let base = req.base_url.as_deref().filter(|s| !s.is_empty()).ok_or(
253 CloudAiError::Shape {
254 provider: NAME.into(),
255 detail: "no base_url configured; set the Local AI base URL in Settings".into(),
256 },
257 )?;
258 let endpoint = local_embeddings_url(base);
259 let body = build_embeddings_body(req);
260 let started = Instant::now();
261 let mut builder = self
262 .client
263 .post(&endpoint)
264 .header("content-type", "application/json");
265 if !req.api_key.is_empty() {
266 builder = builder.bearer_auth(&req.api_key);
267 }
268 let resp = builder
269 .json(&body)
270 .send()
271 .await
272 .map_err(|e| CloudAiError::Http {
273 provider: NAME.into(),
274 reason: e.to_string(),
275 })?;
276 let status = resp.status();
277 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
278 provider: NAME.into(),
279 reason: e.to_string(),
280 })?;
281 let latency_ms = started.elapsed().as_millis() as u64;
282 if !status.is_success() {
283 return Err(CloudAiError::Status {
284 provider: NAME.into(),
285 status: status.as_u16(),
286 body: redact_error_body(&raw_body),
287 });
288 }
289 parse_embeddings_response(NAME, &req.model, &raw_body, latency_ms)
290 }
291}