Update routers/gemini.py
Browse files- routers/gemini.py +101 -101
routers/gemini.py
CHANGED
@@ -1,101 +1,101 @@
|
|
1 |
-
from fastapi import APIRouter, Depends, HTTPException
|
2 |
-
from fastapi.responses import StreamingResponse
|
3 |
-
from pydantic import BaseModel
|
4 |
-
from typing import Annotated
|
5 |
-
from google import genai
|
6 |
-
from google.genai import types
|
7 |
-
from auth import verify_token
|
8 |
-
import os
|
9 |
-
import httpx
|
10 |
-
import base64
|
11 |
-
|
12 |
-
router = APIRouter(prefix="/gemini", tags=["gemini"])
|
13 |
-
|
14 |
-
gemini_key = os.environ.get('GEMINI_KEY', '') # Changed variable name for clarity
|
15 |
-
# Configure the genai client safely
|
16 |
-
if gemini_key:
|
17 |
-
genai.
|
18 |
-
else:
|
19 |
-
print("Warning: GEMINI_KEY environment variable not set.")
|
20 |
-
# Optionally raise an error or handle the missing key appropriately
|
21 |
-
# raise ValueError("GEMINI_KEY environment variable is required.")
|
22 |
-
|
23 |
-
|
24 |
-
class LLMRequest(BaseModel):
|
25 |
-
model: str
|
26 |
-
prompt: str
|
27 |
-
|
28 |
-
class GeminiMultimodalRequest(BaseModel):
|
29 |
-
model: str
|
30 |
-
prompt: str
|
31 |
-
image: str # url or base64
|
32 |
-
|
33 |
-
@router.post("/")
|
34 |
-
async def gemini_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
|
35 |
-
if not gemini_key:
|
36 |
-
raise HTTPException(status_code=500, detail="Gemini API key not configured")
|
37 |
-
model = genai.GenerativeModel(request.model)
|
38 |
-
async def generate():
|
39 |
-
# Use generate_content_async for async streaming
|
40 |
-
response = await model.generate_content_async(
|
41 |
-
contents=[request.prompt], stream=True)
|
42 |
-
|
43 |
-
async for chunk in response:
|
44 |
-
if chunk.text:
|
45 |
-
yield chunk.text
|
46 |
-
|
47 |
-
return StreamingResponse(generate(), media_type="text/plain")
|
48 |
-
|
49 |
-
|
50 |
-
@router.post("/multimodal")
|
51 |
-
async def gemini_multimodal(request: GeminiMultimodalRequest, token: Annotated[str, Depends(verify_token)]):
|
52 |
-
if not gemini_key:
|
53 |
-
raise HTTPException(status_code=500, detail="Gemini API key not configured")
|
54 |
-
model = genai.GenerativeModel(request.model)
|
55 |
-
image_part = None
|
56 |
-
try:
|
57 |
-
if request.image.startswith('http'):
|
58 |
-
async with httpx.AsyncClient() as client:
|
59 |
-
img_response = await client.get(request.image)
|
60 |
-
img_response.raise_for_status() # Check if the request was successful
|
61 |
-
# Determine mime type dynamically if possible, default to jpeg
|
62 |
-
content_type = img_response.headers.get('Content-Type', 'image/jpeg')
|
63 |
-
image_part = types.Part.from_bytes(data=img_response.content, mime_type=content_type)
|
64 |
-
elif request.image.startswith('data:image'):
|
65 |
-
# Handle base64 data URI
|
66 |
-
header, encoded = request.image.split(',', 1)
|
67 |
-
mime_type = header.split(':')[1].split(';')[0]
|
68 |
-
image_data = base64.b64decode(encoded)
|
69 |
-
image_part = types.Part.from_bytes(data=image_data, mime_type=mime_type)
|
70 |
-
else:
|
71 |
-
# Assume raw base64
|
72 |
-
image_part = types.Part.from_bytes(data=base64.b64decode(request.image), mime_type="image/jpeg") # Default mime or raise error
|
73 |
-
|
74 |
-
except httpx.HTTPStatusError as e:
|
75 |
-
raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {e.response.status_code}")
|
76 |
-
except (ValueError, TypeError, base64.binascii.Error) as e:
|
77 |
-
raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")
|
78 |
-
except Exception as e: # Catch other potential errors during image processing
|
79 |
-
raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
|
80 |
-
|
81 |
-
|
82 |
-
if image_part is None:
|
83 |
-
raise HTTPException(status_code=400, detail="Could not process image input.")
|
84 |
-
|
85 |
-
try:
|
86 |
-
response = await model.generate_content_async(
|
87 |
-
contents=[request.prompt, image_part]
|
88 |
-
)
|
89 |
-
# Access the text content safely
|
90 |
-
response_text = response.text if hasattr(response, 'text') else "No text content generated."
|
91 |
-
return {"response": response_text}
|
92 |
-
except Exception as e:
|
93 |
-
# Log the error for debugging
|
94 |
-
print(f"Error during Gemini content generation: {e}")
|
95 |
-
# Provide a more informative error message
|
96 |
-
error_detail = f"Error generating content with Gemini model: {e}"
|
97 |
-
# Check for specific API errors if the SDK provides them
|
98 |
-
if hasattr(e, 'message'):
|
99 |
-
error_detail = f"Gemini API Error: {e.message}"
|
100 |
-
raise HTTPException(status_code=500, detail=error_detail)
|
101 |
-
|
|
|
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Annotated
from google import genai
from google.genai import types
from auth import verify_token
import os
import httpx
import base64

router = APIRouter(prefix="/gemini", tags=["gemini"])

# API key for the Gemini service; defaults to "" when unset so route handlers
# can fail fast with a clear 500 instead of an opaque SDK error.
gemini_key = os.environ.get('GEMINI_KEY', '')

# The google-genai SDK is client-based: every call goes through a Client
# instance, and there is no global "configure" step.  The previous code
# created a Client and discarded the result, which configured nothing.
# Keep the client in a module-level name so handlers can use it.
client = None
if gemini_key:
    client = genai.Client(api_key=gemini_key)
else:
    print("Warning: GEMINI_KEY environment variable not set.")
    # Optionally raise an error or handle the missing key appropriately
    # raise ValueError("GEMINI_KEY environment variable is required.")
class LLMRequest(BaseModel):
    """Request body for the text-only Gemini chat endpoint."""

    model: str   # name of the Gemini model to query
    prompt: str  # user prompt forwarded to the model
class GeminiMultimodalRequest(BaseModel):
    """Request body for the multimodal (text + image) Gemini endpoint."""

    model: str   # name of the Gemini model to query
    prompt: str  # user prompt forwarded to the model
    image: str   # image input: URL, data URI, or base64 payload
@router.post("/")
async def gemini_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
    """Stream a plain-text completion from a Gemini model.

    Args:
        request: the model name and prompt.
        token: bearer token validated by the ``verify_token`` dependency.

    Returns:
        StreamingResponse yielding text chunks as the model produces them.

    Raises:
        HTTPException: 500 when no Gemini API key is configured.
    """
    if not gemini_key:
        raise HTTPException(status_code=500, detail="Gemini API key not configured")

    # The google-genai SDK has no genai.GenerativeModel (that API belongs to
    # the legacy google.generativeai package); all calls go through a Client.
    # Creating one per request is cheap — it only holds configuration.
    gen_client = genai.Client(api_key=gemini_key)

    async def generate():
        # client.aio exposes the async API; generate_content_stream returns
        # an async iterator of partial responses.
        stream = await gen_client.aio.models.generate_content_stream(
            model=request.model, contents=request.prompt)
        async for chunk in stream:
            if chunk.text:
                yield chunk.text

    return StreamingResponse(generate(), media_type="text/plain")
def _decode_inline_image(image: str) -> tuple[bytes, str]:
    """Decode a non-URL image string into ``(raw_bytes, mime_type)``.

    Accepts a ``data:image/<type>;base64,<payload>`` URI or a bare base64
    payload.  Raises ValueError / binascii.Error on malformed input.
    """
    if image.startswith('data:image'):
        # Data URI shape: "data:<mime>;base64,<payload>"
        header, encoded = image.split(',', 1)
        mime_type = header.split(':')[1].split(';')[0]
        return base64.b64decode(encoded), mime_type
    # Bare base64 with no header: the mime type is a guess.
    # NOTE(review): jpeg default kept from original behavior — confirm callers send JPEG.
    return base64.b64decode(image), "image/jpeg"


@router.post("/multimodal")
async def gemini_multimodal(request: GeminiMultimodalRequest, token: Annotated[str, Depends(verify_token)]):
    """Answer a prompt about an image given as URL, data URI, or base64.

    Args:
        request: model name, prompt, and image (URL / data URI / base64).
        token: bearer token validated by the ``verify_token`` dependency.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 500 when no API key is configured, 400 for bad image
        input, 500 for generation failures.
    """
    if not gemini_key:
        raise HTTPException(status_code=500, detail="Gemini API key not configured")

    # The google-genai SDK has no genai.GenerativeModel (that API belongs to
    # the legacy google.generativeai package); all calls go through a Client.
    gen_client = genai.Client(api_key=gemini_key)

    image_part = None
    try:
        if request.image.startswith('http'):
            async with httpx.AsyncClient() as http_client:
                img_response = await http_client.get(request.image)
                img_response.raise_for_status()  # Check if the request was successful
                # Determine mime type dynamically if possible, default to jpeg
                content_type = img_response.headers.get('Content-Type', 'image/jpeg')
                image_part = types.Part.from_bytes(data=img_response.content, mime_type=content_type)
        else:
            image_data, mime_type = _decode_inline_image(request.image)
            image_part = types.Part.from_bytes(data=image_data, mime_type=mime_type)
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image from URL: {e.response.status_code}")
    except (ValueError, TypeError, base64.binascii.Error) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")
    except Exception as e:  # Catch other potential errors during image processing
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}")

    if image_part is None:
        raise HTTPException(status_code=400, detail="Could not process image input.")

    try:
        response = await gen_client.aio.models.generate_content(
            model=request.model,
            contents=[request.prompt, image_part],
        )
        # response.text may be absent/None when no text candidate is returned.
        response_text = response.text if getattr(response, 'text', None) else "No text content generated."
        return {"response": response_text}
    except Exception as e:
        # Log the error for debugging
        print(f"Error during Gemini content generation: {e}")
        # Provide a more informative error message
        error_detail = f"Error generating content with Gemini model: {e}"
        # Check for specific API errors if the SDK provides them
        if hasattr(e, 'message'):
            error_detail = f"Gemini API Error: {e.message}"
        raise HTTPException(status_code=500, detail=error_detail)