Skip to main content

flow_adapter_ai/
registry.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::error::CloudAiError;
6use crate::request::{
7    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
8};
9use crate::stream::{LlmStreamEvent, LlmStreamSink};
10
11/// Where a provider runs. Drives the Settings + node-inspector UI grouping
12/// (one "LLM" area split into Local vs Cloud) without the frontend having to
13/// hardcode provider names. `Cloud` providers make outbound network calls
14/// gated by `allow_cloud_ai`; `Local` providers hit a user-configured
15/// on-device endpoint gated by `allow_local_ai`.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ProviderCategory {
19    Local,
20    Cloud,
21}
22
23/// Cloud-AI provider abstraction. `Cloud` implementations make outbound
24/// network calls - the egress is intentional and gated by the
25/// `allow_cloud_ai` setting at the executor layer. `Local` implementations
26/// target a user-configured localhost endpoint and are gated by
27/// `allow_local_ai`.
28#[async_trait]
29pub trait CloudAiProvider: Send + Sync {
30    fn name(&self) -> &str;
31
32    /// Convention-based environment variable name for this provider's API key.
33    fn env_var(&self) -> &str;
34
35    /// Default models suggested in the UI dropdown. Users can also type any
36    /// model name the provider accepts.
37    fn default_models(&self) -> &[&str];
38
39    /// Per-model capability tags advertised to the UI, using the same
40    /// vocabulary as the Hub catalog (`code | tool_use | reasoning | vision |
41    /// embedding`). The node inspector gates capability-specific fields (image
42    /// input, thinking, tool binding) on these the same way it does for local
43    /// models resolved from the Hub. Cloud providers override this; the default
44    /// advertises nothing, so the inspector falls back to showing every field.
45    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
46        Vec::new()
47    }
48
49    /// Local vs cloud. Defaults to `Cloud` so existing providers need no
50    /// change; the local OpenAI-compatible provider overrides it.
51    fn category(&self) -> ProviderCategory {
52        ProviderCategory::Cloud
53    }
54
55    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError>;
56
57    /// Streaming variant of `invoke`. Providers that support real
58    /// streaming override this to push tokens through `sink` as they
59    /// arrive; the default impl wraps `invoke` and emits a single
60    /// `done = true` event so non-streaming providers are still
61    /// observable on the chip without per-call branching at the call
62    /// site.
63    ///
64    /// Implementations MUST emit exactly one event with `done = true`
65    /// (either with the final partial delta or alone), even on failure.
66    async fn invoke_stream(
67        &self,
68        req: &CloudAiRequest,
69        sink: &dyn LlmStreamSink,
70    ) -> Result<CloudAiResponse, CloudAiError> {
71        let call_id = req.call_id.clone().unwrap_or_default();
72        match self.invoke(req).await {
73            Ok(resp) => {
74                sink.emit(LlmStreamEvent::final_delta(&call_id, &resp)).await;
75                Ok(resp)
76            }
77            Err(e) => {
78                sink.emit(LlmStreamEvent::error(&call_id, self.name(), e.to_string()))
79                    .await;
80                Err(e)
81            }
82        }
83    }
84
85    /// Tool-use / function-calling loop. Providers that support tools override
86    /// this: they expose `tools` to the model, run each requested call through
87    /// `dispatcher`, feed the results back, and repeat until the model stops
88    /// asking (or `max_iters` is hit). The default ignores tools and does a
89    /// single `invoke`, so providers without tool support degrade gracefully.
90    async fn invoke_tools(
91        &self,
92        req: &CloudAiRequest,
93        _tools: &[ToolSpec],
94        _dispatcher: &dyn ToolDispatcher,
95        _max_iters: usize,
96    ) -> Result<CloudAiResponse, CloudAiError> {
97        self.invoke(req).await
98    }
99
100    /// Embeddings call for `task: "embedding"`. Providers with an embeddings
101    /// endpoint override this; the default reports the capability unsupported
102    /// (e.g. Anthropic exposes no embeddings API).
103    async fn embed(&self, _req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
104        Err(CloudAiError::Unsupported {
105            provider: self.name().to_string(),
106            capability: "embeddings".to_string(),
107        })
108    }
109}
110
111#[derive(Default)]
112pub struct CloudAiRegistry {
113    providers: HashMap<String, Arc<dyn CloudAiProvider>>,
114}
115
116impl CloudAiRegistry {
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    pub fn register(&mut self, provider: Arc<dyn CloudAiProvider>) {
122        self.providers.insert(provider.name().to_string(), provider);
123    }
124
125    pub fn get(&self, name: &str) -> Option<Arc<dyn CloudAiProvider>> {
126        self.providers.get(name).cloned()
127    }
128
129    pub fn list(&self) -> Vec<ProviderInfo> {
130        let mut out: Vec<_> = self
131            .providers
132            .values()
133            .map(|p| ProviderInfo {
134                name: p.name().to_string(),
135                env_var: p.env_var().to_string(),
136                default_models: p.default_models().iter().map(|s| s.to_string()).collect(),
137                model_capabilities: p.model_capabilities(),
138                category: p.category(),
139            })
140            .collect();
141        out.sort_by(|a, b| a.name.cmp(&b.name));
142        out
143    }
144}
145
146#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
147pub struct ProviderInfo {
148    pub name: String,
149    pub env_var: String,
150    pub default_models: Vec<String>,
151    /// Capability tags per model id (see [`CloudAiProvider::model_capabilities`]).
152    #[serde(default)]
153    pub model_capabilities: Vec<ModelCapabilityInfo>,
154    pub category: ProviderCategory,
155}
156
157/// Capability tags for a single model, surfaced to the node inspector so cloud
158/// models get the same capability-gated fields as local (Hub-catalog) models.
159#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
160pub struct ModelCapabilityInfo {
161    pub model: String,
162    /// Tags from the Hub vocabulary: `code | tool_use | reasoning | vision |
163    /// embedding`.
164    pub capabilities: Vec<String>,
165}
166
167impl ModelCapabilityInfo {
168    /// Build an entry from a model id and its capability tags.
169    pub fn new(model: &str, capabilities: &[&str]) -> Self {
170        Self {
171            model: model.to_string(),
172            capabilities: capabilities.iter().map(|s| s.to_string()).collect(),
173        }
174    }
175}