// flow_adapter_ai/providers/nvidia.rs

//! NVIDIA provider.
//!
//! Targets NVIDIA's hosted inference API at `https://integrate.api.nvidia.com`,
//! which exposes the OpenAI `/v1/chat/completions` shape with Bearer-token auth
//! (`nvapi-...` keys). Apart from the endpoint, env var, and default model list,
//! it is identical to the cloud `openai` provider, so it reuses the shared
//! `openai_compat` request/response/streaming/tool helpers.
8
9use async_trait::async_trait;
10use std::time::Instant;
11
12use crate::error::{redact_error_body, CloudAiError};
13use crate::providers::openai_compat::{
14    build_body, build_streaming_body, parse_response, run_tools_loop, stream_response,
15};
16use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
17use crate::request::{CloudAiRequest, CloudAiResponse, ToolDispatcher, ToolSpec};
18use crate::stream::LlmStreamSink;
19
/// Provider identifier; returned by `name()` and passed as the provider label
/// to the shared `openai_compat` helpers and into error values.
const NAME: &str = "nvidia";

/// Environment variable name reported by `env_var()`.
const ENV_VAR: &str = "NVIDIA_API_KEY";

/// Chat-completions URL that every request path in this module posts to.
const ENDPOINT: &str = "https://integrate.api.nvidia.com/v1/chat/completions";

/// Model identifiers reported by `default_models()`, in listing order.
const DEFAULT_MODELS: &[&str] = &[
    "nvidia/llama-3.3-nemotron-super-49b-v1",
    "nvidia/llama-3.1-nemotron-70b-instruct",
    "minimaxai/minimax-m2.7",
    "qwen/qwen3-coder-480b-a35b-instruct",
    "mistralai/mistral-nemotron",
    "mistralai/mistral-large-3-675b-instruct-2512",
];
32
33pub struct NvidiaProvider {
34    client: reqwest::Client,
35}
36
37impl NvidiaProvider {
38    pub fn new() -> Self {
39        Self {
40            client: reqwest::Client::new(),
41        }
42    }
43}
44
45impl Default for NvidiaProvider {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51#[async_trait]
52impl CloudAiProvider for NvidiaProvider {
53    fn name(&self) -> &str {
54        NAME
55    }
56
57    fn env_var(&self) -> &str {
58        ENV_VAR
59    }
60
61    fn default_models(&self) -> &[&str] {
62        DEFAULT_MODELS
63    }
64
65    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
66        vec![
67            ModelCapabilityInfo::new(
68                "nvidia/llama-3.3-nemotron-super-49b-v1",
69                &["reasoning", "tool_use", "code"],
70            ),
71            ModelCapabilityInfo::new(
72                "nvidia/llama-3.1-nemotron-70b-instruct",
73                &["tool_use", "code"],
74            ),
75            ModelCapabilityInfo::new("minimaxai/minimax-m2.7", &["reasoning", "tool_use", "code"]),
76            ModelCapabilityInfo::new("qwen/qwen3-coder-480b-a35b-instruct", &["tool_use", "code"]),
77            ModelCapabilityInfo::new("mistralai/mistral-nemotron", &["tool_use", "code"]),
78            ModelCapabilityInfo::new(
79                "mistralai/mistral-large-3-675b-instruct-2512",
80                &["tool_use", "code"],
81            ),
82        ]
83    }
84
85    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
86        let body = build_body(req);
87
88        let started = Instant::now();
89        let resp = self
90            .client
91            .post(ENDPOINT)
92            .bearer_auth(&req.api_key)
93            .header("content-type", "application/json")
94            .json(&body)
95            .send()
96            .await
97            .map_err(|e| CloudAiError::Http {
98                provider: NAME.into(),
99                reason: e.to_string(),
100            })?;
101
102        let status = resp.status();
103        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
104            provider: NAME.into(),
105            reason: e.to_string(),
106        })?;
107        let latency_ms = started.elapsed().as_millis() as u64;
108
109        if !status.is_success() {
110            return Err(CloudAiError::Status {
111                provider: NAME.into(),
112                status: status.as_u16(),
113                body: redact_error_body(&raw_body),
114            });
115        }
116
117        parse_response(NAME, &raw_body, latency_ms)
118    }
119
120    async fn invoke_stream(
121        &self,
122        req: &CloudAiRequest,
123        sink: &dyn LlmStreamSink,
124    ) -> Result<CloudAiResponse, CloudAiError> {
125        let call_id = req.call_id.clone().unwrap_or_default();
126        let body = build_streaming_body(req);
127        let builder = self
128            .client
129            .post(ENDPOINT)
130            .bearer_auth(&req.api_key)
131            .header("content-type", "application/json")
132            .header("accept", "text/event-stream")
133            .json(&body);
134        stream_response(builder, NAME, &call_id, sink).await
135    }
136
137    async fn invoke_tools(
138        &self,
139        req: &CloudAiRequest,
140        tools: &[ToolSpec],
141        dispatcher: &dyn ToolDispatcher,
142        max_iters: usize,
143    ) -> Result<CloudAiResponse, CloudAiError> {
144        run_tools_loop(
145            &self.client,
146            ENDPOINT,
147            NAME,
148            &req.api_key,
149            req,
150            tools,
151            dispatcher,
152            max_iters,
153        )
154        .await
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Model id used by every test request.
    const MODEL: &str = "nvidia/llama-3.1-nemotron-70b-instruct";

    /// Baseline request: fixed model, prompt and key, every optional field
    /// unset. Tests override only the fields they care about.
    fn base_request() -> CloudAiRequest {
        CloudAiRequest {
            model: MODEL.into(),
            prompt: "hello".into(),
            max_tokens: None,
            temperature: None,
            api_key: "x".into(),
            system: None,
            base_url: None,
            ..Default::default()
        }
    }

    #[test]
    fn body_includes_max_tokens_and_temperature() {
        let req = CloudAiRequest {
            max_tokens: Some(64),
            temperature: Some(0.25),
            ..base_request()
        };

        let body = build_body(&req);

        assert_eq!(body["model"], MODEL);
        assert_eq!(body["max_tokens"], 64);
        assert_eq!(body["temperature"], 0.25);

        // No system => only the user message, no leading system entry.
        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
    }

    #[test]
    fn body_prepends_system_message_when_set() {
        let req = CloudAiRequest {
            system: Some("You are Flow's DSL generator.".into()),
            ..base_request()
        };

        let body = build_body(&req);
        let messages = body["messages"].as_array().unwrap();

        // The system entry comes first, followed by the user prompt.
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are Flow's DSL generator.");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "hello");
    }
}