Skip to main content

flow_adapter_ai/providers/
gemini.rs

1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde::Deserialize;
4use serde_json::{json, Value};
5use std::time::Instant;
6
7use crate::error::{redact_error_body, CloudAiError};
8use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
9use crate::request::{
10    CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
11};
12use crate::stream::{LlmStreamEvent, LlmStreamSink};
13
/// Provider identifier stamped on responses, errors and stream events.
const NAME: &str = "gemini";
/// Environment variable holding the Google API key (reported via `env_var`).
const ENV_VAR: &str = "GOOGLE_API_KEY";
/// Base URL for model-scoped v1beta calls; `/{model}:{method}` is appended.
const ENDPOINT_BASE: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta/models"
);

/// Models offered by default, newest generation first.
///
/// This must track what the Generative Language API actually serves under
/// v1beta. By mid-2026 `gemini-2.0-pro` no longer resolves (the endpoint
/// answers 404). The current generation is therefore 2.5, and the older
/// 2.0-flash and 1.5-flash stay listed as cheaper fallbacks.
const DEFAULT_MODELS: &[&str] = &[
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.0-flash",
    "gemini-1.5-flash",
];
29
30pub struct GeminiProvider {
31    client: reqwest::Client,
32}
33
34impl GeminiProvider {
35    pub fn new() -> Self {
36        Self {
37            client: reqwest::Client::new(),
38        }
39    }
40}
41
42impl Default for GeminiProvider {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48#[async_trait]
49impl CloudAiProvider for GeminiProvider {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn env_var(&self) -> &str {
55        ENV_VAR
56    }
57
58    fn default_models(&self) -> &[&str] {
59        DEFAULT_MODELS
60    }
61
62    fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
63        vec![
64            ModelCapabilityInfo::new(
65                "gemini-2.5-pro",
66                &["reasoning", "vision", "tool_use", "code"],
67            ),
68            ModelCapabilityInfo::new(
69                "gemini-2.5-flash",
70                &["reasoning", "vision", "tool_use", "code"],
71            ),
72            ModelCapabilityInfo::new("gemini-2.0-flash", &["vision", "tool_use", "code"]),
73            ModelCapabilityInfo::new("gemini-1.5-flash", &["vision", "tool_use", "code"]),
74        ]
75    }
76
77    async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
78        // Send the API key via the `x-goog-api-key` header rather than the
79        // `?key=` query string. `reqwest` error messages on transport failures
80        // (DNS, TLS, redirect, timeout) routinely include the full URL - with
81        // the key in the query string those error strings, surfaced through
82        // CloudAiError::Http to logs and the renderer, would leak the key.
83        // Header values are not echoed by reqwest's Display impl.
84        let url = format!("{ENDPOINT_BASE}/{}:generateContent", req.model);
85        let body = build_body(req);
86
87        let started = Instant::now();
88        let resp = self
89            .client
90            .post(&url)
91            .header("content-type", "application/json")
92            .header("x-goog-api-key", &req.api_key)
93            .json(&body)
94            .send()
95            .await
96            .map_err(|e| CloudAiError::Http {
97                provider: NAME.into(),
98                reason: e.to_string(),
99            })?;
100
101        let status = resp.status();
102        let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
103            provider: NAME.into(),
104            reason: e.to_string(),
105        })?;
106        let latency_ms = started.elapsed().as_millis() as u64;
107
108        if !status.is_success() {
109            return Err(CloudAiError::Status {
110                provider: NAME.into(),
111                status: status.as_u16(),
112                body: redact_error_body(&raw_body),
113            });
114        }
115
116        parse_response(&raw_body, &req.model, latency_ms)
117    }
118
119    async fn invoke_stream(
120        &self,
121        req: &CloudAiRequest,
122        sink: &dyn LlmStreamSink,
123    ) -> Result<CloudAiResponse, CloudAiError> {
124        let call_id = req.call_id.clone().unwrap_or_default();
125        // `alt=sse` switches Gemini's stream from chunked JSON arrays to
126        // OpenAI-style `data: <json>` frames, which is far easier to
127        // parse and matches the path the openai_compat helper takes.
128        let url = format!(
129            "{ENDPOINT_BASE}/{}:streamGenerateContent?alt=sse",
130            req.model
131        );
132        let body = build_body(req);
133        let started = Instant::now();
134        let resp = self
135            .client
136            .post(&url)
137            .header("content-type", "application/json")
138            .header("accept", "text/event-stream")
139            .header("x-goog-api-key", &req.api_key)
140            .json(&body)
141            .send()
142            .await
143            .map_err(|e| CloudAiError::Http {
144                provider: NAME.into(),
145                reason: e.to_string(),
146            })?;
147
148        let status = resp.status();
149        if !status.is_success() {
150            let raw = resp.text().await.unwrap_or_default();
151            let err = CloudAiError::Status {
152                provider: NAME.into(),
153                status: status.as_u16(),
154                body: redact_error_body(&raw),
155            };
156            sink.emit(LlmStreamEvent::error(&call_id, NAME, err.to_string()))
157                .await;
158            return Err(err);
159        }
160
161        let mut stream = resp.bytes_stream();
162        let mut buffer = String::new();
163        let mut full_text = String::new();
164        let mut finish_reason = String::from("unknown");
165        let mut input_tokens: u32 = 0;
166        let mut output_tokens: u32 = 0;
167        let mut errored: Option<String> = None;
168
169        while let Some(chunk) = stream.next().await {
170            let bytes = match chunk {
171                Ok(b) => b,
172                Err(e) => {
173                    errored = Some(e.to_string());
174                    break;
175                }
176            };
177            buffer.push_str(&String::from_utf8_lossy(&bytes));
178            while let Some(idx) = buffer.find("\n\n") {
179                let frame = buffer[..idx].to_string();
180                buffer.drain(..idx + 2);
181                for raw_line in frame.split('\n') {
182                    let line = raw_line.trim_start();
183                    let Some(payload) = line.strip_prefix("data:") else {
184                        continue;
185                    };
186                    let payload = payload.trim();
187                    if payload.is_empty() {
188                        continue;
189                    }
190                    let value: serde_json::Value = match serde_json::from_str(payload) {
191                        Ok(v) => v,
192                        Err(_) => continue,
193                    };
194                    if let Some(usage) = value.get("usageMetadata") {
195                        input_tokens = usage
196                            .get("promptTokenCount")
197                            .and_then(|v| v.as_u64())
198                            .map(|v| v as u32)
199                            .unwrap_or(input_tokens);
200                        output_tokens = usage
201                            .get("candidatesTokenCount")
202                            .and_then(|v| v.as_u64())
203                            .map(|v| v as u32)
204                            .unwrap_or(output_tokens);
205                    }
206                    if let Some(candidate) = value.get("candidates").and_then(|c| c.get(0)) {
207                        if let Some(reason) =
208                            candidate.get("finishReason").and_then(|r| r.as_str())
209                        {
210                            finish_reason = reason.to_string();
211                        }
212                        if let Some(parts) = candidate
213                            .get("content")
214                            .and_then(|c| c.get("parts"))
215                            .and_then(|p| p.as_array())
216                        {
217                            for part in parts {
218                                if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
219                                    if !text.is_empty() {
220                                        full_text.push_str(text);
221                                        sink.emit(LlmStreamEvent {
222                                            call_id: call_id.clone(),
223                                            provider: NAME.to_string(),
224                                            model: req.model.clone(),
225                                            delta: text.to_string(),
226                                            done: false,
227                                            error: None,
228                                        })
229                                        .await;
230                                    }
231                                }
232                            }
233                        }
234                    }
235                }
236            }
237        }
238
239        if let Some(err) = errored {
240            sink.emit(LlmStreamEvent::error(&call_id, NAME, err.clone()))
241                .await;
242            return Err(CloudAiError::Http {
243                provider: NAME.into(),
244                reason: err,
245            });
246        }
247        sink.emit(LlmStreamEvent {
248            call_id: call_id.clone(),
249            provider: NAME.to_string(),
250            model: req.model.clone(),
251            delta: String::new(),
252            done: true,
253            error: None,
254        })
255        .await;
256        if full_text.is_empty() {
257            return Err(CloudAiError::Shape {
258                provider: NAME.into(),
259                detail: "stream closed before any text was emitted".into(),
260            });
261        }
262        let latency_ms = started.elapsed().as_millis() as u64;
263        Ok(CloudAiResponse {
264            provider: NAME.to_string(),
265            model: req.model.clone(),
266            text: full_text,
267            finish_reason,
268            input_tokens,
269            output_tokens,
270            latency_ms,
271        })
272    }
273
274    async fn invoke_tools(
275        &self,
276        req: &CloudAiRequest,
277        tools: &[ToolSpec],
278        dispatcher: &dyn ToolDispatcher,
279        max_iters: usize,
280    ) -> Result<CloudAiResponse, CloudAiError> {
281        let started = Instant::now();
282        let url = format!("{ENDPOINT_BASE}/{}:generateContent", req.model);
283        let mut contents = vec![json!({ "role": "user", "parts": user_parts(req) })];
284        let fn_decls: Vec<_> = tools
285            .iter()
286            .map(|t| {
287                json!({ "name": t.name, "description": t.description, "parameters": t.parameters })
288            })
289            .collect();
290        let tools_body = json!([{ "functionDeclarations": fn_decls }]);
291        let mut input_tokens: u32 = 0;
292        let mut output_tokens: u32 = 0;
293
294        for _ in 0..max_iters.max(1) {
295            let mut body = json!({ "contents": contents, "tools": tools_body });
296            if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
297                body["system_instruction"] = json!({ "parts": [{ "text": s }] });
298            }
299            if let Some(cfg) = generation_config(req) {
300                body["generationConfig"] = cfg;
301            }
302
303            let resp = self
304                .client
305                .post(&url)
306                .header("content-type", "application/json")
307                .header("x-goog-api-key", &req.api_key)
308                .json(&body)
309                .send()
310                .await
311                .map_err(|e| CloudAiError::Http {
312                    provider: NAME.into(),
313                    reason: e.to_string(),
314                })?;
315            let status = resp.status();
316            let raw = resp.text().await.unwrap_or_default();
317            if !status.is_success() {
318                return Err(CloudAiError::Status {
319                    provider: NAME.into(),
320                    status: status.as_u16(),
321                    body: redact_error_body(&raw),
322                });
323            }
324            let v: Value = serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
325                provider: NAME.into(),
326                reason: e.to_string(),
327            })?;
328            if let Some(u) = v.get("usageMetadata") {
329                input_tokens = input_tokens.saturating_add(
330                    u.get("promptTokenCount").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
331                );
332                output_tokens = output_tokens.saturating_add(
333                    u.get("candidatesTokenCount").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
334                );
335            }
336            let cand = v.get("candidates").and_then(|c| c.get(0));
337            let finish = cand
338                .and_then(|c| c.get("finishReason"))
339                .and_then(|r| r.as_str())
340                .unwrap_or("STOP")
341                .to_string();
342            let parts = cand
343                .and_then(|c| c.get("content"))
344                .and_then(|c| c.get("parts"))
345                .and_then(|p| p.as_array())
346                .cloned()
347                .unwrap_or_default();
348
349            let mut text = String::new();
350            let mut calls: Vec<(String, Value)> = Vec::new();
351            for part in &parts {
352                if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
353                    text.push_str(t);
354                }
355                if let Some(fc) = part.get("functionCall") {
356                    calls.push((
357                        fc.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string(),
358                        fc.get("args").cloned().unwrap_or_else(|| json!({})),
359                    ));
360                }
361            }
362
363            if calls.is_empty() {
364                if text.is_empty() {
365                    return Err(CloudAiError::Shape {
366                        provider: NAME.into(),
367                        detail: "empty candidate content".into(),
368                    });
369                }
370                return Ok(CloudAiResponse {
371                    provider: NAME.into(),
372                    model: req.model.clone(),
373                    text,
374                    finish_reason: finish,
375                    input_tokens,
376                    output_tokens,
377                    latency_ms: started.elapsed().as_millis() as u64,
378                });
379            }
380
381            // Echo the model's functionCall turn, then a user turn carrying each
382            // functionResponse (errors fed back so the model can recover).
383            contents.push(json!({ "role": "model", "parts": parts }));
384            let mut resp_parts = Vec::new();
385            for (name, args) in &calls {
386                let out = match dispatcher.call(name, args).await {
387                    Ok(v) => v,
388                    Err(e) => json!({ "error": e }),
389                };
390                resp_parts.push(json!({
391                    "functionResponse": { "name": name, "response": { "result": out } }
392                }));
393            }
394            contents.push(json!({ "role": "user", "parts": resp_parts }));
395        }
396
397        Err(CloudAiError::Shape {
398            provider: NAME.into(),
399            detail: format!("tool loop did not converge within {max_iters} iterations"),
400        })
401    }
402
403    async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
404        let url = format!("{ENDPOINT_BASE}/{}:embedContent", req.model);
405        let body = json!({
406            "model": format!("models/{}", req.model),
407            "content": { "parts": [{ "text": req.prompt }] },
408        });
409        let started = Instant::now();
410        let resp = self
411            .client
412            .post(&url)
413            .header("content-type", "application/json")
414            .header("x-goog-api-key", &req.api_key)
415            .json(&body)
416            .send()
417            .await
418            .map_err(|e| CloudAiError::Http {
419                provider: NAME.into(),
420                reason: e.to_string(),
421            })?;
422        let status = resp.status();
423        let raw = resp.text().await.unwrap_or_default();
424        let latency_ms = started.elapsed().as_millis() as u64;
425        if !status.is_success() {
426            return Err(CloudAiError::Status {
427                provider: NAME.into(),
428                status: status.as_u16(),
429                body: redact_error_body(&raw),
430            });
431        }
432        let v: Value = serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
433            provider: NAME.into(),
434            reason: e.to_string(),
435        })?;
436        let embedding: Vec<f32> = v
437            .get("embedding")
438            .and_then(|e| e.get("values"))
439            .and_then(|vals| vals.as_array())
440            .map(|a| a.iter().filter_map(|x| x.as_f64().map(|f| f as f32)).collect())
441            .unwrap_or_default();
442        if embedding.is_empty() {
443            return Err(CloudAiError::Shape {
444                provider: NAME.into(),
445                detail: "no embedding values in response".into(),
446            });
447        }
448        let dims = embedding.len();
449        Ok(EmbeddingResponse {
450            provider: NAME.into(),
451            model: req.model.clone(),
452            embedding,
453            dims,
454            latency_ms,
455        })
456    }
457}
458
459/// User-content `parts`: the prompt text plus an image part per request image
460/// (`inlineData` for `data:` URLs, `fileData` for `http(s)` URLs).
461fn user_parts(req: &CloudAiRequest) -> Vec<serde_json::Value> {
462    let mut parts = vec![json!({ "text": req.prompt })];
463    for img in &req.images {
464        parts.push(image_part(img));
465    }
466    parts
467}
468
469fn image_part(img: &str) -> serde_json::Value {
470    if let Some(rest) = img.strip_prefix("data:") {
471        if let Some((mime, data)) = rest.split_once(";base64,") {
472            return json!({ "inlineData": { "mimeType": mime, "data": data } });
473        }
474    }
475    json!({ "fileData": { "fileUri": img } })
476}
477
478/// Build the `generationConfig` block (sampling + optional thinking budget), or
479/// `None` when empty.
480fn generation_config(req: &CloudAiRequest) -> Option<serde_json::Value> {
481    let mut config = serde_json::Map::new();
482    if let Some(m) = req.max_tokens {
483        config.insert("maxOutputTokens".into(), json!(m));
484    }
485    if let Some(t) = req.temperature {
486        config.insert("temperature".into(), json!(t));
487    }
488    if let Some(think) = req.reasoning {
489        // `thinkingBudget: 0` disables reasoning; a positive budget enables it
490        // (Gemini 2.5 thinking models).
491        config.insert(
492            "thinkingConfig".into(),
493            json!({ "thinkingBudget": if think { 1024 } else { 0 } }),
494        );
495    }
496    if config.is_empty() {
497        None
498    } else {
499        Some(serde_json::Value::Object(config))
500    }
501}
502
503pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
504    let mut payload = json!({
505        "contents": [{ "parts": user_parts(req) }],
506    });
507    if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
508        // Gemini's v1beta API takes the system prompt as a top-level
509        // `system_instruction` containing the same `parts`/`text`
510        // shape used for `contents`. Sending it via this field (rather
511        // than concatenating into `contents`) is what unlocks the
512        // model's instruction-following behavior.
513        payload["system_instruction"] = json!({
514            "parts": [{ "text": s }]
515        });
516    }
517    if let Some(cfg) = generation_config(req) {
518        payload["generationConfig"] = cfg;
519    }
520    payload
521}
522
523#[derive(Debug, Deserialize)]
524struct RawResponse {
525    candidates: Vec<Candidate>,
526    #[serde(rename = "usageMetadata", default)]
527    usage_metadata: Option<UsageMetadata>,
528}
529
530#[derive(Debug, Deserialize)]
531struct Candidate {
532    content: Option<Content>,
533    #[serde(rename = "finishReason", default)]
534    finish_reason: Option<String>,
535}
536
537#[derive(Debug, Deserialize)]
538struct Content {
539    parts: Vec<Part>,
540}
541
542#[derive(Debug, Deserialize)]
543struct Part {
544    #[serde(default)]
545    text: Option<String>,
546}
547
548#[derive(Debug, Deserialize)]
549struct UsageMetadata {
550    #[serde(rename = "promptTokenCount", default)]
551    prompt_token_count: u32,
552    #[serde(rename = "candidatesTokenCount", default)]
553    candidates_token_count: u32,
554}
555
556pub fn parse_response(
557    body: &str,
558    model: &str,
559    latency_ms: u64,
560) -> Result<CloudAiResponse, CloudAiError> {
561    let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
562        provider: NAME.into(),
563        reason: e.to_string(),
564    })?;
565
566    let candidate = raw
567        .candidates
568        .into_iter()
569        .next()
570        .ok_or(CloudAiError::Shape {
571            provider: NAME.into(),
572            detail: "no candidates in response".into(),
573        })?;
574
575    let text = candidate
576        .content
577        .map(|c| {
578            c.parts
579                .into_iter()
580                .filter_map(|p| p.text)
581                .collect::<Vec<_>>()
582                .join("")
583        })
584        .unwrap_or_default();
585
586    if text.is_empty() {
587        return Err(CloudAiError::Shape {
588            provider: NAME.into(),
589            detail: "empty candidate content".into(),
590        });
591    }
592
593    let (input, output) = match raw.usage_metadata {
594        Some(u) => (u.prompt_token_count, u.candidates_token_count),
595        None => (0, 0),
596    };
597
598    Ok(CloudAiResponse {
599        provider: NAME.into(),
600        model: model.to_string(),
601        text,
602        finish_reason: candidate.finish_reason.unwrap_or_else(|| "unknown".into()),
603        input_tokens: input,
604        output_tokens: output,
605        latency_ms,
606    })
607}