1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde::Deserialize;
4use serde_json::{json, Value};
5use std::time::Instant;
6
7use crate::error::{redact_error_body, CloudAiError};
8use crate::registry::{CloudAiProvider, ModelCapabilityInfo};
9use crate::request::{
10 CloudAiRequest, CloudAiResponse, EmbeddingResponse, ToolDispatcher, ToolSpec,
11};
12use crate::stream::{LlmStreamEvent, LlmStreamSink};
13
/// Provider identifier stamped on every response, error, and stream event.
const NAME: &str = "gemini";

/// Name of the environment variable this provider reports through `env_var()`.
const ENV_VAR: &str = "GOOGLE_API_KEY";

/// Base URL of the Gemini `models` collection; each call appends `/{model}:{method}`.
const ENDPOINT_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Model ids returned by `default_models()`.
const DEFAULT_MODELS: &[&str] = &[
    // 2.5 generation
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    // older generations
    "gemini-2.0-flash",
    "gemini-1.5-flash",
];
29
30pub struct GeminiProvider {
31 client: reqwest::Client,
32}
33
34impl GeminiProvider {
35 pub fn new() -> Self {
36 Self {
37 client: reqwest::Client::new(),
38 }
39 }
40}
41
42impl Default for GeminiProvider {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[async_trait]
49impl CloudAiProvider for GeminiProvider {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn env_var(&self) -> &str {
55 ENV_VAR
56 }
57
58 fn default_models(&self) -> &[&str] {
59 DEFAULT_MODELS
60 }
61
62 fn model_capabilities(&self) -> Vec<ModelCapabilityInfo> {
63 vec![
64 ModelCapabilityInfo::new(
65 "gemini-2.5-pro",
66 &["reasoning", "vision", "tool_use", "code"],
67 ),
68 ModelCapabilityInfo::new(
69 "gemini-2.5-flash",
70 &["reasoning", "vision", "tool_use", "code"],
71 ),
72 ModelCapabilityInfo::new("gemini-2.0-flash", &["vision", "tool_use", "code"]),
73 ModelCapabilityInfo::new("gemini-1.5-flash", &["vision", "tool_use", "code"]),
74 ]
75 }
76
77 async fn invoke(&self, req: &CloudAiRequest) -> Result<CloudAiResponse, CloudAiError> {
78 let url = format!("{ENDPOINT_BASE}/{}:generateContent", req.model);
85 let body = build_body(req);
86
87 let started = Instant::now();
88 let resp = self
89 .client
90 .post(&url)
91 .header("content-type", "application/json")
92 .header("x-goog-api-key", &req.api_key)
93 .json(&body)
94 .send()
95 .await
96 .map_err(|e| CloudAiError::Http {
97 provider: NAME.into(),
98 reason: e.to_string(),
99 })?;
100
101 let status = resp.status();
102 let raw_body = resp.text().await.map_err(|e| CloudAiError::Http {
103 provider: NAME.into(),
104 reason: e.to_string(),
105 })?;
106 let latency_ms = started.elapsed().as_millis() as u64;
107
108 if !status.is_success() {
109 return Err(CloudAiError::Status {
110 provider: NAME.into(),
111 status: status.as_u16(),
112 body: redact_error_body(&raw_body),
113 });
114 }
115
116 parse_response(&raw_body, &req.model, latency_ms)
117 }
118
119 async fn invoke_stream(
120 &self,
121 req: &CloudAiRequest,
122 sink: &dyn LlmStreamSink,
123 ) -> Result<CloudAiResponse, CloudAiError> {
124 let call_id = req.call_id.clone().unwrap_or_default();
125 let url = format!(
129 "{ENDPOINT_BASE}/{}:streamGenerateContent?alt=sse",
130 req.model
131 );
132 let body = build_body(req);
133 let started = Instant::now();
134 let resp = self
135 .client
136 .post(&url)
137 .header("content-type", "application/json")
138 .header("accept", "text/event-stream")
139 .header("x-goog-api-key", &req.api_key)
140 .json(&body)
141 .send()
142 .await
143 .map_err(|e| CloudAiError::Http {
144 provider: NAME.into(),
145 reason: e.to_string(),
146 })?;
147
148 let status = resp.status();
149 if !status.is_success() {
150 let raw = resp.text().await.unwrap_or_default();
151 let err = CloudAiError::Status {
152 provider: NAME.into(),
153 status: status.as_u16(),
154 body: redact_error_body(&raw),
155 };
156 sink.emit(LlmStreamEvent::error(&call_id, NAME, err.to_string()))
157 .await;
158 return Err(err);
159 }
160
161 let mut stream = resp.bytes_stream();
162 let mut buffer = String::new();
163 let mut full_text = String::new();
164 let mut finish_reason = String::from("unknown");
165 let mut input_tokens: u32 = 0;
166 let mut output_tokens: u32 = 0;
167 let mut errored: Option<String> = None;
168
169 while let Some(chunk) = stream.next().await {
170 let bytes = match chunk {
171 Ok(b) => b,
172 Err(e) => {
173 errored = Some(e.to_string());
174 break;
175 }
176 };
177 buffer.push_str(&String::from_utf8_lossy(&bytes));
178 while let Some(idx) = buffer.find("\n\n") {
179 let frame = buffer[..idx].to_string();
180 buffer.drain(..idx + 2);
181 for raw_line in frame.split('\n') {
182 let line = raw_line.trim_start();
183 let Some(payload) = line.strip_prefix("data:") else {
184 continue;
185 };
186 let payload = payload.trim();
187 if payload.is_empty() {
188 continue;
189 }
190 let value: serde_json::Value = match serde_json::from_str(payload) {
191 Ok(v) => v,
192 Err(_) => continue,
193 };
194 if let Some(usage) = value.get("usageMetadata") {
195 input_tokens = usage
196 .get("promptTokenCount")
197 .and_then(|v| v.as_u64())
198 .map(|v| v as u32)
199 .unwrap_or(input_tokens);
200 output_tokens = usage
201 .get("candidatesTokenCount")
202 .and_then(|v| v.as_u64())
203 .map(|v| v as u32)
204 .unwrap_or(output_tokens);
205 }
206 if let Some(candidate) = value.get("candidates").and_then(|c| c.get(0)) {
207 if let Some(reason) =
208 candidate.get("finishReason").and_then(|r| r.as_str())
209 {
210 finish_reason = reason.to_string();
211 }
212 if let Some(parts) = candidate
213 .get("content")
214 .and_then(|c| c.get("parts"))
215 .and_then(|p| p.as_array())
216 {
217 for part in parts {
218 if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
219 if !text.is_empty() {
220 full_text.push_str(text);
221 sink.emit(LlmStreamEvent {
222 call_id: call_id.clone(),
223 provider: NAME.to_string(),
224 model: req.model.clone(),
225 delta: text.to_string(),
226 done: false,
227 error: None,
228 })
229 .await;
230 }
231 }
232 }
233 }
234 }
235 }
236 }
237 }
238
239 if let Some(err) = errored {
240 sink.emit(LlmStreamEvent::error(&call_id, NAME, err.clone()))
241 .await;
242 return Err(CloudAiError::Http {
243 provider: NAME.into(),
244 reason: err,
245 });
246 }
247 sink.emit(LlmStreamEvent {
248 call_id: call_id.clone(),
249 provider: NAME.to_string(),
250 model: req.model.clone(),
251 delta: String::new(),
252 done: true,
253 error: None,
254 })
255 .await;
256 if full_text.is_empty() {
257 return Err(CloudAiError::Shape {
258 provider: NAME.into(),
259 detail: "stream closed before any text was emitted".into(),
260 });
261 }
262 let latency_ms = started.elapsed().as_millis() as u64;
263 Ok(CloudAiResponse {
264 provider: NAME.to_string(),
265 model: req.model.clone(),
266 text: full_text,
267 finish_reason,
268 input_tokens,
269 output_tokens,
270 latency_ms,
271 })
272 }
273
274 async fn invoke_tools(
275 &self,
276 req: &CloudAiRequest,
277 tools: &[ToolSpec],
278 dispatcher: &dyn ToolDispatcher,
279 max_iters: usize,
280 ) -> Result<CloudAiResponse, CloudAiError> {
281 let started = Instant::now();
282 let url = format!("{ENDPOINT_BASE}/{}:generateContent", req.model);
283 let mut contents = vec![json!({ "role": "user", "parts": user_parts(req) })];
284 let fn_decls: Vec<_> = tools
285 .iter()
286 .map(|t| {
287 json!({ "name": t.name, "description": t.description, "parameters": t.parameters })
288 })
289 .collect();
290 let tools_body = json!([{ "functionDeclarations": fn_decls }]);
291 let mut input_tokens: u32 = 0;
292 let mut output_tokens: u32 = 0;
293
294 for _ in 0..max_iters.max(1) {
295 let mut body = json!({ "contents": contents, "tools": tools_body });
296 if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
297 body["system_instruction"] = json!({ "parts": [{ "text": s }] });
298 }
299 if let Some(cfg) = generation_config(req) {
300 body["generationConfig"] = cfg;
301 }
302
303 let resp = self
304 .client
305 .post(&url)
306 .header("content-type", "application/json")
307 .header("x-goog-api-key", &req.api_key)
308 .json(&body)
309 .send()
310 .await
311 .map_err(|e| CloudAiError::Http {
312 provider: NAME.into(),
313 reason: e.to_string(),
314 })?;
315 let status = resp.status();
316 let raw = resp.text().await.unwrap_or_default();
317 if !status.is_success() {
318 return Err(CloudAiError::Status {
319 provider: NAME.into(),
320 status: status.as_u16(),
321 body: redact_error_body(&raw),
322 });
323 }
324 let v: Value = serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
325 provider: NAME.into(),
326 reason: e.to_string(),
327 })?;
328 if let Some(u) = v.get("usageMetadata") {
329 input_tokens = input_tokens.saturating_add(
330 u.get("promptTokenCount").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
331 );
332 output_tokens = output_tokens.saturating_add(
333 u.get("candidatesTokenCount").and_then(|x| x.as_u64()).unwrap_or(0) as u32,
334 );
335 }
336 let cand = v.get("candidates").and_then(|c| c.get(0));
337 let finish = cand
338 .and_then(|c| c.get("finishReason"))
339 .and_then(|r| r.as_str())
340 .unwrap_or("STOP")
341 .to_string();
342 let parts = cand
343 .and_then(|c| c.get("content"))
344 .and_then(|c| c.get("parts"))
345 .and_then(|p| p.as_array())
346 .cloned()
347 .unwrap_or_default();
348
349 let mut text = String::new();
350 let mut calls: Vec<(String, Value)> = Vec::new();
351 for part in &parts {
352 if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
353 text.push_str(t);
354 }
355 if let Some(fc) = part.get("functionCall") {
356 calls.push((
357 fc.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string(),
358 fc.get("args").cloned().unwrap_or_else(|| json!({})),
359 ));
360 }
361 }
362
363 if calls.is_empty() {
364 if text.is_empty() {
365 return Err(CloudAiError::Shape {
366 provider: NAME.into(),
367 detail: "empty candidate content".into(),
368 });
369 }
370 return Ok(CloudAiResponse {
371 provider: NAME.into(),
372 model: req.model.clone(),
373 text,
374 finish_reason: finish,
375 input_tokens,
376 output_tokens,
377 latency_ms: started.elapsed().as_millis() as u64,
378 });
379 }
380
381 contents.push(json!({ "role": "model", "parts": parts }));
384 let mut resp_parts = Vec::new();
385 for (name, args) in &calls {
386 let out = match dispatcher.call(name, args).await {
387 Ok(v) => v,
388 Err(e) => json!({ "error": e }),
389 };
390 resp_parts.push(json!({
391 "functionResponse": { "name": name, "response": { "result": out } }
392 }));
393 }
394 contents.push(json!({ "role": "user", "parts": resp_parts }));
395 }
396
397 Err(CloudAiError::Shape {
398 provider: NAME.into(),
399 detail: format!("tool loop did not converge within {max_iters} iterations"),
400 })
401 }
402
403 async fn embed(&self, req: &CloudAiRequest) -> Result<EmbeddingResponse, CloudAiError> {
404 let url = format!("{ENDPOINT_BASE}/{}:embedContent", req.model);
405 let body = json!({
406 "model": format!("models/{}", req.model),
407 "content": { "parts": [{ "text": req.prompt }] },
408 });
409 let started = Instant::now();
410 let resp = self
411 .client
412 .post(&url)
413 .header("content-type", "application/json")
414 .header("x-goog-api-key", &req.api_key)
415 .json(&body)
416 .send()
417 .await
418 .map_err(|e| CloudAiError::Http {
419 provider: NAME.into(),
420 reason: e.to_string(),
421 })?;
422 let status = resp.status();
423 let raw = resp.text().await.unwrap_or_default();
424 let latency_ms = started.elapsed().as_millis() as u64;
425 if !status.is_success() {
426 return Err(CloudAiError::Status {
427 provider: NAME.into(),
428 status: status.as_u16(),
429 body: redact_error_body(&raw),
430 });
431 }
432 let v: Value = serde_json::from_str(&raw).map_err(|e| CloudAiError::Parse {
433 provider: NAME.into(),
434 reason: e.to_string(),
435 })?;
436 let embedding: Vec<f32> = v
437 .get("embedding")
438 .and_then(|e| e.get("values"))
439 .and_then(|vals| vals.as_array())
440 .map(|a| a.iter().filter_map(|x| x.as_f64().map(|f| f as f32)).collect())
441 .unwrap_or_default();
442 if embedding.is_empty() {
443 return Err(CloudAiError::Shape {
444 provider: NAME.into(),
445 detail: "no embedding values in response".into(),
446 });
447 }
448 let dims = embedding.len();
449 Ok(EmbeddingResponse {
450 provider: NAME.into(),
451 model: req.model.clone(),
452 embedding,
453 dims,
454 latency_ms,
455 })
456 }
457}
458
459fn user_parts(req: &CloudAiRequest) -> Vec<serde_json::Value> {
462 let mut parts = vec![json!({ "text": req.prompt })];
463 for img in &req.images {
464 parts.push(image_part(img));
465 }
466 parts
467}
468
469fn image_part(img: &str) -> serde_json::Value {
470 if let Some(rest) = img.strip_prefix("data:") {
471 if let Some((mime, data)) = rest.split_once(";base64,") {
472 return json!({ "inlineData": { "mimeType": mime, "data": data } });
473 }
474 }
475 json!({ "fileData": { "fileUri": img } })
476}
477
478fn generation_config(req: &CloudAiRequest) -> Option<serde_json::Value> {
481 let mut config = serde_json::Map::new();
482 if let Some(m) = req.max_tokens {
483 config.insert("maxOutputTokens".into(), json!(m));
484 }
485 if let Some(t) = req.temperature {
486 config.insert("temperature".into(), json!(t));
487 }
488 if let Some(think) = req.reasoning {
489 config.insert(
492 "thinkingConfig".into(),
493 json!({ "thinkingBudget": if think { 1024 } else { 0 } }),
494 );
495 }
496 if config.is_empty() {
497 None
498 } else {
499 Some(serde_json::Value::Object(config))
500 }
501}
502
503pub fn build_body(req: &CloudAiRequest) -> serde_json::Value {
504 let mut payload = json!({
505 "contents": [{ "parts": user_parts(req) }],
506 });
507 if let Some(s) = req.system.as_deref().filter(|s| !s.is_empty()) {
508 payload["system_instruction"] = json!({
514 "parts": [{ "text": s }]
515 });
516 }
517 if let Some(cfg) = generation_config(req) {
518 payload["generationConfig"] = cfg;
519 }
520 payload
521}
522
523#[derive(Debug, Deserialize)]
524struct RawResponse {
525 candidates: Vec<Candidate>,
526 #[serde(rename = "usageMetadata", default)]
527 usage_metadata: Option<UsageMetadata>,
528}
529
530#[derive(Debug, Deserialize)]
531struct Candidate {
532 content: Option<Content>,
533 #[serde(rename = "finishReason", default)]
534 finish_reason: Option<String>,
535}
536
537#[derive(Debug, Deserialize)]
538struct Content {
539 parts: Vec<Part>,
540}
541
542#[derive(Debug, Deserialize)]
543struct Part {
544 #[serde(default)]
545 text: Option<String>,
546}
547
548#[derive(Debug, Deserialize)]
549struct UsageMetadata {
550 #[serde(rename = "promptTokenCount", default)]
551 prompt_token_count: u32,
552 #[serde(rename = "candidatesTokenCount", default)]
553 candidates_token_count: u32,
554}
555
556pub fn parse_response(
557 body: &str,
558 model: &str,
559 latency_ms: u64,
560) -> Result<CloudAiResponse, CloudAiError> {
561 let raw: RawResponse = serde_json::from_str(body).map_err(|e| CloudAiError::Parse {
562 provider: NAME.into(),
563 reason: e.to_string(),
564 })?;
565
566 let candidate = raw
567 .candidates
568 .into_iter()
569 .next()
570 .ok_or(CloudAiError::Shape {
571 provider: NAME.into(),
572 detail: "no candidates in response".into(),
573 })?;
574
575 let text = candidate
576 .content
577 .map(|c| {
578 c.parts
579 .into_iter()
580 .filter_map(|p| p.text)
581 .collect::<Vec<_>>()
582 .join("")
583 })
584 .unwrap_or_default();
585
586 if text.is_empty() {
587 return Err(CloudAiError::Shape {
588 provider: NAME.into(),
589 detail: "empty candidate content".into(),
590 });
591 }
592
593 let (input, output) = match raw.usage_metadata {
594 Some(u) => (u.prompt_token_count, u.candidates_token_count),
595 None => (0, 0),
596 };
597
598 Ok(CloudAiResponse {
599 provider: NAME.into(),
600 model: model.to_string(),
601 text,
602 finish_reason: candidate.finish_reason.unwrap_or_else(|| "unknown".into()),
603 input_tokens: input,
604 output_tokens: output,
605 latency_ms,
606 })
607}