// flow_adapter_ai/providers/openai.rs
use std::time::Instant;

use async_trait::async_trait;

use crate::error::{redact_error_body, CloudAiError};
use crate::providers::openai_compat::{
    build_body, build_embeddings_body, build_streaming_body, parse_embeddings_response,
    parse_response, run_tools_loop, stream_response,
};
use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
use crate::request::{
    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
};
use crate::stream::LlmStreamSink;

/// Provider identifier; returned by `name()` and attached to every error this module builds.
const NAME: &str = "openai";

/// Environment variable that is reported by `env_var()` as holding the API key.
const ENV_VAR: &str = "OPENAI_API_KEY";

/// Chat-completions endpoint used by `invoke`, `invoke_stream` and `invoke_tools`.
const ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";

/// Embeddings endpoint used by `embed`.
const EMBEDDINGS_ENDPOINT: &str = "https://api.openai.com/v1/embeddings";

/// Models returned by `default_models()`.
const DEFAULT_MODELS: &[&str] = &[
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
];

22pub struct OpenAiProvider {
23 client: reqwest::Client,
24}
25
26impl OpenAiProvider {
27 pub fn new() -> Self {
28 Self {
29 client: reqwest::Client::new(),
30 }
31 }
32}
33
34impl Default for OpenAiProvider {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[async_trait]
41impl CloudAiProvider for OpenAiProvider {
42 fn name(&self) -> &str {
43 NAME
44 }
45
46 fn env_var(&self) -> &str {
47 ENV_VAR
48 }
49
50 fn default_models(&self) -> &[&str] {
51 DEFAULT_MODELS
52 }
53
54 fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
55 vec![
56 ModelCapabilityInfo::new("gpt-4o", &["vision", "tool_use", "code"]),
57 ModelCapabilityInfo::new("gpt-4o-mini", &["vision", "tool_use", "code"]),
58 ModelCapabilityInfo::new("gpt-4-turbo", &["vision", "tool_use", "code"]),
59 ]
60 }
61
62 async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
63 let body = build_body(req);
64
65 let started = Instant::now();
66 let resp = self
67 .client
68 .post(ENDPOINT)
69 .bearer_auth(&req.api_key)
70 .header("content-type", "application/json")
71 .json(&body)
72 .send()
73 .await
74 .map_err(|e| CloudAiError::Http {
75 provider: NAME.into(),
76 reason: e.to_string(),
77 })?;
78
79 let status = resp.status();
80 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
81 provider: NAME.into(),
82 reason: e.to_string(),
83 })?;
84 let latency_ms = started.elapsed().as_millis() as u64;
85
86 if !status.is_success() {
87 return Err(CloudAiError::Status {
88 provider: NAME.into(),
89 status: status.as_u16(),
90 body: redact_error_body(&raw_body),
91 });
92 }
93
94 parse_response(NAME, &raw_body, latency_ms)
95 }
96
97 async fn invoke_stream(
98 &self,
99 req: &CloudAiRequest,
100 sink: &dyn LlmStreamSink,
101 ) -> Result<CloudAiResponse, CloudAiError> {
102 let call_id = req.call_id.clone().unwrap_or_default();
103 let body = build_streaming_body(req);
104 let builder = self
105 .client
106 .post(ENDPOINT)
107 .bearer_auth(&req.api_key)
108 .header("content-type", "application/json")
109 .header("accept", "text/event-stream")
110 .json(&body);
111 stream_response(builder, NAME, &call_id, sink).await
112 }
113
114 async fn invoke_tools(
115 &self,
116 req: &CloudAiRequest,
117 tools: &[ToolSpec],
118 dispatcher: &dyn ToolDispatcher,
119 max_iters: usize,
120 ) -> Result<CloudAiResponse, CloudAiError> {
121 run_tools_loop(
122 &self.client,
123 ENDPOINT,
124 NAME,
125 &req.api_key,
126 req,
127 tools,
128 dispatcher,
129 max_iters,
130 )
131 .await
132 }
133
134 async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
135 let body = build_embeddings_body(req);
136 let started = Instant::now();
137 let resp = self
138 .client
139 .post(EMBEDDINGS_ENDPOINT)
140 .bearer_auth(&req.api_key)
141 .header("content-type", "application/json")
142 .json(&body)
143 .send()
144 .await
145 .map_err(|e| CloudAiError::Http {
146 provider: NAME.into(),
147 reason: e.to_string(),
148 })?;
149 let status = resp.status();
150 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
151 provider: NAME.into(),
152 reason: e.to_string(),
153 })?;
154 let latency_ms = started.elapsed().as_millis() as u64;
155 if !status.is_success() {
156 return Err(CloudAiError::Status {
157 provider: NAME.into(),
158 status: status.as_u16(),
159 body: redact_error_body(&raw_body),
160 });
161 }
162 parse_embeddings_response(NAME, &req.model, &raw_body, latency_ms)
163 }
164}