flow_adapter_ai/registry.rs
1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::error::CloudAiError;
6use crate::request::{
7 CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
8};
9use crate::stream::{LlmStreamEvent, LlmStreamSink};
10
11/// Where a provider runs. Drives the Settings + node-inspector UI grouping
12/// (one "LLM" area split into Local vs Cloud) without the frontend having to
13/// hardcode provider names. `Cloud` providers make outbound network calls
14/// gated by `allow_cloud_ai`; `Local` providers hit a user-configured
15/// on-device endpoint gated by `allow_local_ai`.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ProviderCategory {
19 Local,
20 Cloud,
21}
22
23/// Cloud-AI provider abstraction. `Cloud` implementations make outbound
24/// network calls - the egress is intentional and gated by the
25/// `allow_cloud_ai` setting at the executor layer. `Local` implementations
26/// target a user-configured localhost endpoint and are gated by
27/// `allow_local_ai`.
28#[async_trait]
29pub trait CloudAiProvider: Send + Sync {
30 fn name(&self) -> &str;
31
32 /// Convention-based environment variable name for this provider's API key.
33 fn env_var(&self) -> &str;
34
35 /// Default models suggested in the UI dropdown. Users can also type any
36 /// model name the provider accepts.
37 fn default_models(&self) -> &[&str];
38
39 /// Per-model capability tags advertised to the UI, using the same
40 /// vocabulary as the Hub catalog (`code | tool_use | reasoning | vision |
41 /// embedding`). The node inspector gates capability-specific fields (image
42 /// input, thinking, tool binding) on these the same way it does for local
43 /// models resolved from the Hub. Cloud providers override this; the default
44 /// advertises nothing, so the inspector falls back to showing every field.
45 fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
46 Vec::new()
47 }
48
49 /// Local vs cloud. Defaults to `Cloud` so existing providers need no
50 /// change; the local OpenAI-compatible provider overrides it.
51 fn category(&self) -> ProviderCategory {
52 ProviderCategory::Cloud
53 }
54
55 async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError>;
56
57 /// Streaming variant of `invoke`. Providers that support real
58 /// streaming override this to push tokens through `sink` as they
59 /// arrive; the default impl wraps `invoke` and emits a single
60 /// `done = true` event so non-streaming providers are still
61 /// observable on the chip without per-call branching at the call
62 /// site.
63 ///
64 /// Implementations MUST emit exactly one event with `done = true`
65 /// (either with the final partial delta or alone), even on failure.
66 async fn invoke_stream(
67 &self,
68 req: &CloudAiRequest,
69 sink: &dyn LlmStreamSink,
70 ) -> Result<CloudAiResponse, CloudAiError> {
71 let call_id = req.call_id.clone().unwrap_or_default();
72 match self.invoke(req).await {
73 Ok(resp) => {
74 sink.emit(LlmStreamEvent::final_delta(&call_id, &resp)).await;
75 Ok(resp)
76 }
77 Err(e) => {
78 sink.emit(LlmStreamEvent::error(&call_id, self.name(), e.to_string()))
79 .await;
80 Err(e)
81 }
82 }
83 }
84
85 /// Tool-use / function-calling loop. Providers that support tools override
86 /// this: they expose `tools` to the model, run each requested call through
87 /// `dispatcher`, feed the results back, and repeat until the model stops
88 /// asking (or `max_iters` is hit). The default ignores tools and does a
89 /// single `invoke`, so providers without tool support degrade gracefully.
90 async fn invoke_tools(
91 &self,
92 req: &CloudAiRequest,
93 _tools: &[ToolSpec],
94 _dispatcher: &dyn ToolDispatcher,
95 _max_iters: usize,
96 ) -> Result<CloudAiResponse, CloudAiError> {
97 self.invoke(req).await
98 }
99
100 /// Embeddings call for `task: "embedding"`. Providers with an embeddings
101 /// endpoint override this; the default reports the capability unsupported
102 /// (e.g. Anthropic exposes no embeddings API).
103 async fn embed(&self, _req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
104 Err(CloudAiError::Unsupported {
105 provider: self.name().to_string(),
106 capability: "embeddings".to_string(),
107 })
108 }
109}
110
111#[derive(Default)]
112pub struct CloudAiRegistry {
113 providers: HashMap<String, Arc<dyn CloudAiProvider>>,
114}
115
116impl CloudAiRegistry {
117 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn register(&mut self, provider: Arc<dyn CloudAiProvider>) {
122 self.providers.insert(provider.name().to_string(), provider);
123 }
124
125 pub fn get(&self, name: &str) -> Option<Arc<dyn CloudAiProvider>> {
126 self.providers.get(name).cloned()
127 }
128
129 pub fn list(&self) -> Vec<ProviderInfo> {
130 let mut out: Vec<_> = self
131 .providers
132 .values()
133 .map(|p| ProviderInfo {
134 name: p.name().to_string(),
135 env_var: p.env_var().to_string(),
136 default_models: p.default_models().iter().map(|s| s.to_string()).collect(),
137 model_capabilities: p.model_capabilities(),
138 category: p.category(),
139 })
140 .collect();
141 out.sort_by(|a, b| a.name.cmp(&b.name));
142 out
143 }
144}
145
146#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
147pub struct ProviderInfo {
148 pub name: String,
149 pub env_var: String,
150 pub default_models: Vec<String>,
151 /// Capability tags per model id (see [`CloudAiProvider::model_capabilities`]).
152 #[serde(default)]
153 pub model_capabilities: Vec<ModelCapabilityInfo>,
154 pub category: ProviderCategory,
155}
156
157/// Capability tags for a single model, surfaced to the node inspector so cloud
158/// models get the same capability-gated fields as local (Hub-catalog) models.
159#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
160pub struct ModelCapabilityInfo {
161 pub model: String,
162 /// Tags from the Hub vocabulary: `code | tool_use | reasoning | vision |
163 /// embedding`.
164 pub capabilities: Vec<String>,
165}
166
167impl ModelCapabilityInfo {
168 /// Build an entry from a model id and its capability tags.
169 pub fn new(model: &str, capabilities: &[&str]) -> Self {
170 Self {
171 model: model.to_string(),
172 capabilities: capabilities.iter().map(|s| s.to_string()).collect(),
173 }
174 }
175}