Skip to main content

flow_adapter_ai/providers/
openai.rs

1use async_trait::async_trait;
2use std::time::Instant;
3
4use crate::error::{redact_error_body, CloudAiError};
5use crate::providers::openai_compat::{
6    build_body, build_embeddings_body, build_streaming_body, parse_embeddings_response,
7    parse_response, run_tools_loop, stream_response,
8};
9use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
10use crate::request::{
11    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
12};
13use crate::stream::LlmStreamSink;
14
/// Provider identifier: returned by `name()` and attached to every error this module builds.
const NAME: &str = "openai";

/// Environment variable name reported by `env_var()`.
const ENV_VAR: &str = "OPENAI_API_KEY";

/// URL that chat-completion requests (plain, streaming, and tool loop) are POSTed to.
const ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";

/// URL that embedding requests are POSTed to.
const EMBEDDINGS_ENDPOINT: &str = "https://api.openai.com/v1/embeddings";

/// Model identifiers reported by `default_models()`.
const DEFAULT_MODELS: &[&str] = &[
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
];
21
22pub struct OpenAiProvider {
23    client: reqwest::Client,
24}
25
26impl OpenAiProvider {
27    pub fn new() -> Self {
28        Self {
29            client: reqwest::Client::new(),
30        }
31    }
32}
33
34impl Default for OpenAiProvider {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[async_trait]
41impl CloudAiProvider for OpenAiProvider {
42    fn name(&self) -> &str {
43        NAME
44    }
45
46    fn env_var(&self) -> &str {
47        ENV_VAR
48    }
49
50    fn default_models(&self) -> &[&str] {
51        DEFAULT_MODELS
52    }
53
54    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
55        vec![
56            ModelCapabilityInfo::new("gpt-4o", &["vision", "tool_use", "code"]),
57            ModelCapabilityInfo::new("gpt-4o-mini", &["vision", "tool_use", "code"]),
58            ModelCapabilityInfo::new("gpt-4-turbo", &["vision", "tool_use", "code"]),
59        ]
60    }
61
62    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
63        let body = build_body(req);
64
65        let started = Instant::now();
66        let resp = self
67            .client
68            .post(ENDPOINT)
69            .bearer_auth(&req.api_key)
70            .header("content-type", "application/json")
71            .json(&body)
72            .send()
73            .await
74            .map_err(|e| CloudAiError::Http {
75                provider: NAME.into(),
76                reason: e.to_string(),
77            })?;
78
79        let status = resp.status();
80        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
81            provider: NAME.into(),
82            reason: e.to_string(),
83        })?;
84        let latency_ms = started.elapsed().as_millis() as u64;
85
86        if !status.is_success() {
87            return Err(CloudAiError::Status {
88                provider: NAME.into(),
89                status: status.as_u16(),
90                body: redact_error_body(&raw_body),
91            });
92        }
93
94        parse_response(NAME, &raw_body, latency_ms)
95    }
96
97    async fn invoke_stream(
98        &self,
99        req: &CloudAiRequest,
100        sink: &dyn LlmStreamSink,
101    ) -> Result<CloudAiResponse, CloudAiError> {
102        let call_id = req.call_id.clone().unwrap_or_default();
103        let body = build_streaming_body(req);
104        let builder = self
105            .client
106            .post(ENDPOINT)
107            .bearer_auth(&req.api_key)
108            .header("content-type", "application/json")
109            .header("accept", "text/event-stream")
110            .json(&body);
111        stream_response(builder, NAME, &call_id, sink).await
112    }
113
114    async fn invoke_tools(
115        &self,
116        req: &CloudAiRequest,
117        tools: &[ToolSpec],
118        dispatcher: &dyn ToolDispatcher,
119        max_iters: usize,
120    ) -> Result<CloudAiResponse, CloudAiError> {
121        run_tools_loop(
122            &self.client,
123            ENDPOINT,
124            NAME,
125            &req.api_key,
126            req,
127            tools,
128            dispatcher,
129            max_iters,
130        )
131        .await
132    }
133
134    async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
135        let body = build_embeddings_body(req);
136        let started = Instant::now();
137        let resp = self
138            .client
139            .post(EMBEDDINGS_ENDPOINT)
140            .bearer_auth(&req.api_key)
141            .header("content-type", "application/json")
142            .json(&body)
143            .send()
144            .await
145            .map_err(|e| CloudAiError::Http {
146                provider: NAME.into(),
147                reason: e.to_string(),
148            })?;
149        let status = resp.status();
150        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
151            provider: NAME.into(),
152            reason: e.to_string(),
153        })?;
154        let latency_ms = started.elapsed().as_millis() as u64;
155        if !status.is_success() {
156            return Err(CloudAiError::Status {
157                provider: NAME.into(),
158                status: status.as_u16(),
159                body: redact_error_body(&raw_body),
160            });
161        }
162        parse_embeddings_response(NAME, &req.model, &raw_body, latency_ms)
163    }
164}