Skip to main content

flow_adapter_ai/providers/
deepseek.rs

//! DeepSeek provider.
//!
//! Targets DeepSeek's hosted inference API at `https://api.deepseek.com`,
//! which exposes the OpenAI `/v1/chat/completions` request/response shape with
//! Bearer-token auth (API keys are named by the `DEEPSEEK_API_KEY` variable).
//! The wire format is identical to that of the cloud `openai` provider, so this
//! module reuses the shared `openai_compat` request/response/streaming/tool helpers.
7
8use async_trait::async_trait;
9use std::time::Instant;
10
11use crate::error::{redact_error_body, CloudAiError};
12use crate::providers::openai_compat::{
13    build_body, build_streaming_body, parse_response, run_tools_loop, stream_response,
14};
15use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
16use crate::request::{CloudAiRequest, CloudAiResponse, ToolDispatcher, ToolSpec};
17use crate::stream::LlmStreamSink;
18
/// Provider identifier, attached to every error and response produced here.
const NAME: &str = "deepseek";

/// Name of the environment variable associated with this provider's API key.
const ENV_VAR: &str = "DEEPSEEK_API_KEY";

/// Model identifiers reported by `default_models`.
const DEFAULT_MODELS: &[&str] = &["deepseek-chat", "deepseek-reasoner"];

/// Full URL that every request (plain, streaming, and tool-calling) is POSTed to.
const ENDPOINT: &str = "https://api.deepseek.com/chat/completions";
24
25pub struct DeepSeekProvider {
26    client: reqwest::Client,
27}
28
29impl DeepSeekProvider {
30    pub fn new() -> Self {
31        Self {
32            client: reqwest::Client::new(),
33        }
34    }
35}
36
37impl Default for DeepSeekProvider {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43#[async_trait]
44impl CloudAiProvider for DeepSeekProvider {
45    fn name(&self) -> &str {
46        NAME
47    }
48
49    fn env_var(&self) -> &str {
50        ENV_VAR
51    }
52
53    fn default_models(&self) -> &[&str] {
54        DEFAULT_MODELS
55    }
56
57    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
58        vec![
59            ModelCapabilityInfo::new("deepseek-chat", &["tool_use", "code"]),
60            ModelCapabilityInfo::new("deepseek-reasoner", &["reasoning", "code"]),
61        ]
62    }
63
64    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
65        let body = build_body(req);
66
67        let started = Instant::now();
68        let resp = self
69            .client
70            .post(ENDPOINT)
71            .bearer_auth(&req.api_key)
72            .header("content-type", "application/json")
73            .json(&body)
74            .send()
75            .await
76            .map_err(|e| CloudAiError::Http {
77                provider: NAME.into(),
78                reason: e.to_string(),
79            })?;
80
81        let status = resp.status();
82        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
83            provider: NAME.into(),
84            reason: e.to_string(),
85        })?;
86        let latency_ms = started.elapsed().as_millis() as u64;
87
88        if !status.is_success() {
89            return Err(CloudAiError::Status {
90                provider: NAME.into(),
91                status: status.as_u16(),
92                body: redact_error_body(&raw_body),
93            });
94        }
95
96        parse_response(NAME, &raw_body, latency_ms)
97    }
98
99    async fn invoke_stream(
100        &self,
101        req: &CloudAiRequest,
102        sink: &dyn LlmStreamSink,
103    ) -> Result<CloudAiResponse, CloudAiError> {
104        let call_id = req.call_id.clone().unwrap_or_default();
105        let body = build_streaming_body(req);
106        let builder = self
107            .client
108            .post(ENDPOINT)
109            .bearer_auth(&req.api_key)
110            .header("content-type", "application/json")
111            .header("accept", "text/event-stream")
112            .json(&body);
113        stream_response(builder, NAME, &call_id, sink).await
114    }
115
116    async fn invoke_tools(
117        &self,
118        req: &CloudAiRequest,
119        tools: &[ToolSpec],
120        dispatcher: &dyn ToolDispatcher,
121        max_iters: usize,
122    ) -> Result<CloudAiResponse, CloudAiError> {
123        run_tools_loop(
124            &self.client,
125            ENDPOINT,
126            NAME,
127            &req.api_key,
128            req,
129            tools,
130            dispatcher,
131            max_iters,
132        )
133        .await
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// With no system prompt set, `build_body` must carry over the model and
    /// sampling options and emit exactly one message with the `user` role.
    #[test]
    fn body_includes_max_tokens_and_temperature() {
        let request = CloudAiRequest {
            api_key: "x".into(),
            model: "deepseek-chat".into(),
            prompt: "hello".into(),
            system: None,
            base_url: None,
            max_tokens: Some(64),
            temperature: Some(0.25),
            ..Default::default()
        };

        let payload = build_body(&request);

        assert_eq!(payload["model"], "deepseek-chat");
        assert_eq!(payload["max_tokens"], 64);
        assert_eq!(payload["temperature"], 0.25);

        let messages = payload["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
    }
}
160}