// flow_models_server/lib.rs
pub mod fetch;

use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tokio::process::{Child, Command};

use flow_adapter_ai::{local_models_url, LocalOpenAiProvider};
22
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct LlamaParams {
28 pub ctx_size: Option<u32>,
30 pub n_gpu_layers: Option<i32>,
32 pub threads: Option<u32>,
34 pub batch_size: Option<u32>,
36 pub parallel: Option<u32>,
38 pub seed: Option<i64>,
40 pub flash_attn: Option<bool>,
42 pub mlock: Option<bool>,
44 pub mmap: Option<bool>,
46 pub cache_type_k: Option<String>,
48 pub cache_type_v: Option<String>,
50 pub enable_thinking: Option<bool>,
54}
55
/// Upper bound on how long `wait_until_ready` keeps probing a freshly
/// spawned server before giving up.
const READINESS_TIMEOUT: Duration = Duration::from_secs(45);
/// Pause between two consecutive readiness probes.
const READINESS_POLL: Duration = Duration::from_millis(500);
59
60#[derive(Debug, Clone, Serialize, Default)]
62#[serde(rename_all = "camelCase")]
63pub struct LlmServerStatus {
64 pub running: bool,
65 pub endpoint: Option<String>,
67 pub model_path: Option<String>,
69 pub pid: Option<u32>,
70}
71
72struct Running {
73 child: Child,
74 port: u16,
75 model_path: PathBuf,
76}
77
78#[derive(Clone)]
81pub struct LlmServerHandle {
82 inner: std::sync::Arc<Mutex<Option<Running>>>,
83}
84
85impl Default for LlmServerHandle {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl LlmServerHandle {
92 pub fn new() -> Self {
93 Self {
94 inner: std::sync::Arc::new(Mutex::new(None)),
95 }
96 }
97
98 pub async fn start(
102 &self,
103 binary: PathBuf,
104 model_path: PathBuf,
105 params: LlamaParams,
106 ) -> Result<String, String> {
107 if !binary.is_file() {
108 return Err(format!(
109 "llama-server binary not found at {}",
110 binary.display()
111 ));
112 }
113 if !model_path.is_file() {
114 return Err(format!("model file not found at {}", model_path.display()));
115 }
116
117 self.stop().await;
118
119 let port = free_loopback_port()?;
120 let mut command = Command::new(&binary);
121 if let Some(dir) = binary.parent() {
124 command.current_dir(dir);
125 }
126 command
127 .arg("--model")
128 .arg(&model_path)
129 .arg("--host")
130 .arg("127.0.0.1")
131 .arg("--port")
132 .arg(port.to_string());
133 if let Some(ctx) = params.ctx_size {
134 command.arg("--ctx-size").arg(ctx.to_string());
135 }
136 if let Some(ngl) = params.n_gpu_layers {
137 command.arg("--n-gpu-layers").arg(ngl.to_string());
138 }
139 if let Some(threads) = params.threads {
140 command.arg("--threads").arg(threads.to_string());
141 }
142 if let Some(batch) = params.batch_size {
143 command.arg("--batch-size").arg(batch.to_string());
144 }
145 if let Some(parallel) = params.parallel {
146 command.arg("--parallel").arg(parallel.to_string());
147 }
148 if let Some(seed) = params.seed {
149 command.arg("--seed").arg(seed.to_string());
150 }
151 if params.flash_attn == Some(true) {
152 command.arg("--flash-attn");
153 }
154 if params.mlock == Some(true) {
155 command.arg("--mlock");
156 }
157 if params.mmap == Some(false) {
158 command.arg("--no-mmap");
159 }
160 if let Some(k) = params.cache_type_k.as_deref().filter(|s| !s.is_empty()) {
161 command.arg("--cache-type-k").arg(k);
162 }
163 if let Some(v) = params.cache_type_v.as_deref().filter(|s| !s.is_empty()) {
164 command.arg("--cache-type-v").arg(v);
165 }
166 if params.enable_thinking == Some(false) {
167 command.arg("--reasoning-budget").arg("0");
169 }
170 let child = command
171 .kill_on_drop(true)
172 .spawn()
173 .map_err(|e| format!("spawn {}: {e}", binary.display()))?;
174
175 {
177 let mut guard = self.inner.lock().unwrap();
178 *guard = Some(Running {
179 child,
180 port,
181 model_path: model_path.clone(),
182 });
183 }
184
185 let base = format!("http://127.0.0.1:{port}");
186 match wait_until_ready(&base, &self.inner).await {
187 Ok(()) => Ok(base),
188 Err(e) => {
189 self.stop().await;
190 Err(format!(
191 "llama-server did not become ready at {}: {e}",
192 local_models_url(&base)
193 ))
194 }
195 }
196 }
197
198 pub async fn stop(&self) {
200 let running = self.inner.lock().unwrap().take();
201 if let Some(mut r) = running {
202 let _ = r.child.start_kill();
205 let _ = r.child.wait().await;
206 }
207 }
208
209 pub fn status(&self) -> LlmServerStatus {
212 let guard = self.inner.lock().unwrap();
213 match guard.as_ref() {
214 Some(r) => LlmServerStatus {
215 running: true,
216 endpoint: Some(format!("http://127.0.0.1:{}", r.port)),
217 model_path: Some(r.model_path.to_string_lossy().to_string()),
218 pid: r.child.id(),
219 },
220 None => LlmServerStatus::default(),
221 }
222 }
223}
224
/// Asks the OS for a currently unused TCP port on the loopback interface.
///
/// Binds `127.0.0.1:0`, reads back the port the OS assigned, and releases the
/// socket again. The port is therefore only *probably* free when the caller
/// uses it: nothing reserves it between this function returning and the
/// caller binding it.
///
/// # Errors
///
/// Returns a message if the bind fails or the local address cannot be read.
fn free_loopback_port() -> Result<u16, String> {
    let listener = std::net::TcpListener::bind("127.0.0.1:0")
        .map_err(|e| format!("could not find a free port: {e}"))?;
    let addr = listener
        .local_addr()
        .map_err(|e| format!("local_addr: {e}"))?;
    // `listener` is closed when it goes out of scope, releasing the port.
    Ok(addr.port())
}
238
239async fn wait_until_ready(base: &str, inner: &Mutex<Option<Running>>) -> Result<(), String> {
244 let provider = LocalOpenAiProvider::new();
245 let deadline = std::time::Instant::now() + READINESS_TIMEOUT;
246 let mut last_err = String::from("timed out");
247 while std::time::Instant::now() < deadline {
248 {
250 let mut guard = inner.lock().unwrap();
251 if let Some(r) = guard.as_mut() {
252 if let Ok(Some(status)) = r.child.try_wait() {
253 return Err(format!("llama-server exited early ({status})"));
254 }
255 }
256 }
257 match provider.list_models(base).await {
258 Ok(_) => return Ok(()),
259 Err(e) => last_err = e.to_string(),
260 }
261 tokio::time::sleep(READINESS_POLL).await;
262 }
263 Err(last_err)
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Ports handed out by the OS for a `:0` bind must be unprivileged.
    #[test]
    fn free_port_is_in_range_and_distinct() {
        let a = free_loopback_port().unwrap();
        let b = free_loopback_port().unwrap();
        assert!(a >= 1024);
        assert!(b >= 1024);
        // Distinctness is deliberately not asserted: the first listener is
        // already closed when the second bind happens, so the OS is free to
        // hand out the same port again.
    }

    /// A fresh handle reports the "no server" state.
    #[test]
    fn status_default_is_stopped() {
        let h = LlmServerHandle::new();
        let s = h.status();
        assert!(!s.running);
        assert!(s.endpoint.is_none());
    }

    /// A missing binary is rejected before anything is spawned.
    #[tokio::test]
    async fn start_rejects_missing_binary() {
        let h = LlmServerHandle::new();
        let err = h
            .start(
                PathBuf::from("/nonexistent/llama-server"),
                PathBuf::from("/nonexistent/model.gguf"),
                LlamaParams::default(),
            )
            .await
            .unwrap_err();
        assert!(err.contains("binary not found"), "got: {err}");
    }

    /// With an existing binary, a missing model file is rejected as well.
    #[tokio::test]
    async fn start_rejects_missing_model() {
        let h = LlmServerHandle::new();
        // Any existing regular file passes the binary check; the running test
        // executable is always available.
        let binary = std::env::current_exe().expect("path of the test executable");
        let err = h
            .start(
                binary,
                PathBuf::from("/nonexistent/model.gguf"),
                LlamaParams::default(),
            )
            .await
            .unwrap_err();
        assert!(err.contains("model file not found"), "got: {err}");
        assert!(!h.status().running);
    }
}