Update app.py

app.py CHANGED
@@ -1,29 +1,25 @@
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
-from
-from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering
+from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
 from typing import Optional, Dict, Any, List
 import logging
 import time
 import os
 import io
 import json
+import re
 from PIL import Image
 from docx import Document
 import fitz  # PyMuPDF
 import pandas as pd
 from functools import lru_cache
-import re
 import torch
 import numpy as np
 from pydantic import BaseModel
 import asyncio
 import google.generativeai as genai
 
-# Set the TRANSFORMERS_CACHE environment variable to a writable directory
-os.environ["HF_HOME"] = "/tmp/huggingface_cache"
-
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -37,15 +33,18 @@ os.makedirs(upload_dir, exist_ok=True)
 
 app = FastAPI(
     title="Cosmic AI Assistant",
-    description="An advanced AI assistant with space-themed interface and
+    description="An advanced AI assistant with space-themed interface and translation features",
     version="2.0.0"
 )
 
 # Mount static files
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
-#
-
+# Mount videos directory
+app.mount("/videos", StaticFiles(directory="videos"), name="videos")
+
+# Mount images directory
+app.mount("/images", StaticFiles(directory="images"), name="images")
 
 # Gemini API Configuration
 API_KEY = "AIzaSyCwmgD8KxzWiuivtySNtcZF_rfTvx9s9sY"  # Replace with your actual API key
@@ -55,14 +54,32 @@ genai.configure(api_key=API_KEY)
 MODELS = {
     "summarization": "sshleifer/distilbart-cnn-12-6",
     "image-to-text": "Salesforce/blip-image-captioning-large",
-    "question-answering": "deepset/roberta-base-squad2",
     "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
-    "
-    "
-
+    "chatbot": "gemini-1.5-pro",  # Handles both chat and text generation
+    "translation": "facebook/m2m100_418M"
+}
+
+# Supported languages for translation
+SUPPORTED_LANGUAGES = {
+    "english": "en",
+    "french": "fr",
+    "german": "de",
+    "spanish": "es",
+    "italian": "it",
+    "russian": "ru",
+    "chinese": "zh",
+    "japanese": "ja",
+    "arabic": "ar",
+    "hindi": "hi",
+    "portuguese": "pt",
+    "korean": "ko"
 }
 
-#
+# Global variables for pre-loaded translation model
+translation_model = None
+translation_tokenizer = None
+
+# Cache for model loading (excluding translation)
 @lru_cache(maxsize=8)
 def load_model(task: str, model_name: str = None):
     """Cached model loader with proper task names and error handling"""
@@ -72,11 +89,8 @@ def load_model(task: str, model_name: str = None):
 
     model_to_load = model_name or MODELS.get(task)
 
-    if task
-
-        model = genai.GenerativeModel(model_to_load)
-        logger.info(f"Gemini model loaded in {time.time() - start_time:.2f}s")
-        return model
+    if task == "chatbot":  # Gemini handles both chat and text generation
+        return genai.GenerativeModel(model_to_load)
 
     if task == "visual-qa":
         processor = ViltProcessor.from_pretrained(model_to_load)
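Note (not part of the commit): the `chatbot` branch above returns a raw `genai.GenerativeModel` handle rather than a Hugging Face pipeline. A minimal sketch of how such a handle is driven, assuming the google-generativeai package and a placeholder key:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")          # placeholder, not a real key
model = genai.GenerativeModel("gemini-1.5-pro")  # same id as MODELS["chatbot"]
reply = model.generate_content("Say hello from orbit.")
print(reply.text.strip())
```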
@@ -97,48 +111,96 @@ def load_model(task: str, model_name: str = None):
         logger.info(f"VQA raw output: {answer}")
         return answer
 
-        logger.info(f"Visual QA model loaded in {time.time() - start_time:.2f}s")
         return vqa_function
 
-
-        logger.info(f"Pipeline model loaded in {time.time() - start_time:.2f}s")
-        return model
+        return pipeline(task, model=model_to_load)
 
     except Exception as e:
         logger.error(f"Model load failed: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
 
-def
-    """Function to generate response with Gemini"""
+def get_gemini_response(user_input: str, is_generation: bool = False):
+    """Function to generate response with Gemini for both chat and text generation"""
     if not user_input:
-        return "Please
+        return "Please provide some input."
     try:
         chatbot = load_model("chatbot")
-
+        if is_generation:
+            prompt = f"Generate creative text based on this prompt: {user_input}"
+        else:
+            prompt = user_input
+        response = chatbot.generate_content(prompt)
         return response.text.strip()
     except Exception as e:
         return f"Error: {str(e)}"
 
 def translate_text(text: str, target_language: str):
-    """Translate text to any target language using
+    """Translate text to any target language using pre-loaded M2M100 model"""
     if not text:
         return "Please provide text to translate."
 
     try:
-
-
-
-
+        global translation_model, translation_tokenizer
+
+        target_lang = target_language.lower()
+        if target_lang not in SUPPORTED_LANGUAGES:
+            similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
+            if similar:
+                target_lang = similar[0]
+            else:
+                return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"
+
+        lang_code = SUPPORTED_LANGUAGES[target_lang]
+
+        if translation_model is None or translation_tokenizer is None:
+            raise Exception("Translation model not initialized")
+
+        match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
+        if match:
+            text_to_translate = match.group(1)
+        else:
+            content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
+            text_to_translate = content_match.group(1) if content_match else text
+
+        translation_tokenizer.src_lang = "en"
+        encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)
+
+        start_time = time.time()
+        generated_tokens = translation_model.generate(
+            **encoded,
+            forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
+            max_length=512,
+            num_beams=1,
+            early_stopping=True
+        )
+        translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+        logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
+
+        return translated_text
+
     except Exception as e:
+        logger.error(f"Translation error: {str(e)}", exc_info=True)
         return f"Translation error: {str(e)}"
 
 def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
-    """Enhanced intent detection with dynamic translation support
+    """Enhanced intent detection with dynamic translation support"""
     target_language = "English"  # Default
 
     if file:
         content_type = file.content_type.lower() if file.content_type else ""
         filename = file.filename.lower() if file.filename else ""
 
+        # Added: Catch "what’s this" and "does this fly" first for images
+        if content_type.startswith('image/') and text:
+            text_lower = text.lower()
+            if "what’s this" in text_lower:
+                return "visual-qa", target_language
+            if "does this fly" in text_lower:
+                return "visual-qa", target_language
+            # Added: Broaden "fly" questions for VQA
+            if "fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will']):
+                return "visual-qa", target_language
+
         if content_type.startswith('image/'):
             if text and any(q in text.lower() for q in ['what is', 'what\'s', 'describe', 'tell me about', 'explain', 'how many', 'what color', 'is there', 'are they', 'does the']):
                 return "visual-qa", target_language
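Note (not part of the commit): `translate_text` relies on the module-level `translation_model`/`translation_tokenizer` set up in the startup hunk below. A self-contained sketch of the same M2M100 call path, loading the model inline instead of pre-loading it (assumes English source text and the `facebook/m2m100_418M` weights):

```python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # source language; the app assumes English input
encoded = tokenizer("I want to explore the stars", return_tensors="pt")
tokens = model.generate(
    **encoded,
    forced_bos_token_id=tokenizer.get_lang_id("fr"),  # "fr" as in SUPPORTED_LANGUAGES
    max_length=512,
)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])
```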
@@ -156,11 +218,22 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
     if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
         return "chatbot", target_language
 
-
-
-
-
-
+    translate_patterns = [
+        r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
+        r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
+        r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
+    ]
+
+    for pattern in translate_patterns:
+        translate_match = re.search(pattern, text_lower)
+        if translate_match:
+            potential_lang = translate_match.group(1).lower()
+            if potential_lang in SUPPORTED_LANGUAGES:
+                target_language = potential_lang.capitalize()
+                return "translate", target_language
+            else:
+                logger.warning(f"Invalid language detected: {potential_lang}")
+                return "chatbot", target_language
 
     vqa_patterns = [
         r'how (many|much)',
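Note (not part of the commit): a quick check of what the first `translate_patterns` regex captures; group 1 feeds the `SUPPORTED_LANGUAGES` lookup above.

```python
import re

pattern = r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
m = re.search(pattern, "please translate this to french: good morning")
print(m.group(1))  # -> "french" (checked against SUPPORTED_LANGUAGES)
print(m.group(2))  # -> "good morning" (candidate text to translate)
```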
@@ -184,15 +257,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
 
     if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
         return "summarize", target_language
-
-    question_patterns = [
-        r'\b(what|when|where|why|how|who|which)\b',
-        r'\?',
-        r'\b(explain|tell me|describe|define)\b'
-    ]
-
-    if any(re.search(pattern, text_lower) for pattern in question_patterns):
-        return "question-answering", target_language
 
     generation_patterns = [
         r'\b(write|generate|create|compose)\b',
@@ -205,6 +269,12 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
     if len(text) > 100:
         return "summarize", target_language
 
+    if file and file.content_type and file.content_type.startswith('image/'):
+        if text and "what’s this" in text_lower:
+            return "visual-qa", target_language
+        if text and any(q in text_lower for q in ['does this', 'is this', 'can this']):
+            return "visual-qa", target_language
+
     return "chatbot", target_language
 
 class ProcessResponse(BaseModel):
@@ -212,85 +282,10 @@ class ProcessResponse(BaseModel):
     type: str
     additional_data: Optional[Dict[str, Any]] = None
 
-
-@app.get("/chatbot", response_class=HTMLResponse)
+@app.get("/chatbot")
 async def chatbot_interface():
-
-
-    <html lang="en">
-    <head>
-        <meta charset="UTF-8">
-        <meta name="viewport" content="width=device-width, initial-scale=1.0">
-        <title>Cosmic AI Chatbot</title>
-        <style>
-            body { font-family: Arial, sans-serif; background: #282c34; color: white; text-align: center; }
-            .chat-box { width: 50%; margin: 20px auto; background: #444; padding: 20px; border-radius: 10px; }
-            #chat-box { max-height: 400px; overflow-y: auto; text-align: left; }
-            input, button, select { width: 80%; padding: 10px; margin: 5px; border-radius: 5px; }
-            button { background: #0084ff; color: white; cursor: pointer; }
-            .message { margin: 10px 0; padding: 10px; border-radius: 5px; }
-            .user-message { background: #555; }
-            .bot-message { background: #666; }
-        </style>
-    </head>
-    <body>
-        <div class="chat-box">
-            <h1>Cosmic AI Chatbot</h1>
-            <div id="chat-box">
-                <div class="message bot-message">
-                    <b>Bot:</b> Hello! I am your Cosmic AI Assistant. Upload a file or ask a question, and I can:<br>
-                    - Summarize documents<br>
-                    - Describe images<br>
-                    - Answer your questions<br>
-                    - Translate text to any language<br>
-                    - Generate visualization code
-                </div>
-            </div>
-            <input type="text" id="user-input" placeholder="Type your message...">
-            <select id="translate-to">
-                <option value="">No translation</option>
-                <option value="English">English</option>
-                <option value="French">French</option>
-                <option value="German">German</option>
-                <option value="Spanish">Spanish</option>
-                <option value="Italian">Italian</option>
-                <option value="Russian">Russian</option>
-                <option value="Chinese">Chinese</option>
-                <option value="Japanese">Japanese</option>
-            </select>
-            <button onclick="sendMessage()">Send</button>
-        </div>
-        <script>
-            async function sendMessage() {
-                let inputField = document.getElementById("user-input");
-                let translateTo = document.getElementById("translate-to").value;
-                let chatBox = document.getElementById("chat-box");
-                let userMessage = inputField.value.trim();
-                if (!userMessage) return;
-
-                let messageToSend = translateTo ? `Translate this to ${translateTo}: ${userMessage}` : userMessage;
-
-                chatBox.innerHTML += `<div class="message user-message"><b>You:</b> ${userMessage}</div>`;
-                let response = await fetch("/chat", {
-                    method: "POST",
-                    headers: { "Content-Type": "application/json" },
-                    body: JSON.stringify({ message: messageToSend })
-                });
-                let result = await response.json();
-                chatBox.innerHTML += `<div class="message bot-message"><b>Bot:</b> ${result.response}</div>`;
-                inputField.value = "";
-                chatBox.scrollTop = chatBox.scrollHeight;
-            }
-
-            document.getElementById("user-input").addEventListener("keypress", function(e) {
-                if (e.key === "Enter") {
-                    sendMessage();
-                }
-            });
-        </script>
-    </body>
-    </html>
-    """
+    """Redirect to the static index.html file for the chatbot interface"""
+    return RedirectResponse(url="/static/index.html")
 
 @app.post("/chat")
 async def chat_endpoint(data: dict):
@@ -298,7 +293,7 @@ async def chat_endpoint(data: dict):
     if not message:
         raise HTTPException(status_code=400, detail="No message provided")
     try:
-        response =
+        response = get_gemini_response(message)
         return {"response": response}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
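Note (not part of the commit): with the inline HTML page removed above, `/chat` is exercised directly. A hedged client sketch, assuming the app runs on the port configured in `uvicorn.run` at the bottom of the file:

```python
import requests

resp = requests.post(
    "http://localhost:7860/chat",                          # port from uvicorn.run below
    json={"message": "Translate this to French: hello"},
    timeout=60,
)
print(resp.json()["response"])
```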
@@ -319,14 +314,30 @@ async def process_input(
 
     try:
         if intent == "chatbot":
-            response =
+            response = get_gemini_response(text)
             return {"response": response, "type": "chat"}
 
         elif intent == "translate":
             content = await extract_text_from_file(file) if file else text
-
-
-
+            if "all languages" in text.lower():
+                translations = {}
+                phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
+                for lang, code in SUPPORTED_LANGUAGES.items():
+                    translation_tokenizer.src_lang = "en"
+                    encoded = translation_tokenizer(phrase_to_translate, return_tensors="pt").to(translation_model.device)
+                    generated_tokens = translation_model.generate(
+                        **encoded,
+                        forced_bos_token_id=translation_tokenizer.get_lang_id(code),
+                        max_length=512,
+                        num_beams=1
+                    )
+                    translations[lang] = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+                response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
+                logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
+                return {"response": response, "type": "translation"}
+            else:
+                translated_text = translate_text(content, target_language)
+                return {"response": translated_text, "type": "translation"}
 
         elif intent == "summarize":
             content = await extract_text_from_file(file) if file else text
@@ -364,28 +375,6 @@ async def process_input(
             final_summary = re.sub(r'\s+', ' ', final_summary).strip()
             return {"response": final_summary, "type": "summary"}
 
-        elif intent == "question-answering":
-            context = await extract_text_from_file(file) if file else None
-
-            if not context and not text:
-                raise HTTPException(status_code=400, detail="No context provided")
-
-            qa_pipeline = load_model("question-answering")
-
-            if not context and "?" in text:
-                parts = text.split("?", 1)
-                question = parts[0] + "?"
-                context = parts[1].strip() if len(parts) > 1 and parts[1].strip() else text
-            else:
-                question = text if text else "Summarize this document"
-
-            result = qa_pipeline(
-                question=question,
-                context=context[:2000] if context else text[:2000]
-            )
-
-            return {"response": result["answer"], "type": "answer"}
-
         elif intent == "image-to-text":
             if not file or not file.content_type.startswith('image/'):
                 raise HTTPException(status_code=400, detail="An image file is required")
@@ -423,6 +412,11 @@ async def process_input(
             answer = answer.capitalize()
             if not answer.endswith(('.', '!', '?')):
                 answer += '.'
+            chatbot = load_model("chatbot")
+            if "fly" in question.lower():
+                answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
+            else:
+                answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
 
             logger.info(f"Final VQA answer: {answer}")
 
@@ -447,25 +441,19 @@ async def process_input(
             df = pd.read_excel(io.BytesIO(file_content))
 
             code = generate_visualization_code(df, text)
+            stats = df.describe().to_string()
+            response = f"Stats:\n{stats}\n\nChart Code:\n{code}"
 
-            return {"response":
+            return {"response": response, "type": "visualization_code"}
 
         elif intent == "text-generation":
-
-
-
-
-                max_length=200,
-                num_return_sequences=1,
-                temperature=0.8,
-                top_p=0.92,
-                do_sample=True
-            )
-
-            return {"response": generated[0]["generated_text"], "type": "generated_text"}
+            response = get_gemini_response(text, is_generation=True)
+            lines = response.split(". ")
+            formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
+            return {"response": formatted_poem, "type": "generated_text"}
 
         else:
-            response =
+            response = get_gemini_response(text or "Hello! How can I assist you?")
             return {"response": response, "type": "chat"}
 
     except Exception as e:
@@ -476,42 +464,51 @@ async def process_input(
         logger.info(f"Request processed in {process_time:.2f} seconds")
 
 async def extract_text_from_file(file: UploadFile) -> str:
-    """Enhanced text extraction with
+    """Enhanced text extraction with multiple fallbacks"""
     if not file:
         return ""
 
     content = await file.read()
     filename = file.filename.lower()
 
     try:
         if filename.endswith('.pdf'):
             try:
                 doc = fitz.open(stream=content, filetype="pdf")
+                if doc.is_encrypted:
+                    return "PDF is encrypted and cannot be read"
                 text = ""
                 for page in doc:
                     text += page.get_text()
                 return text
             except Exception as pdf_error:
-                logger.warning(f"PyMuPDF failed
+                logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
                 from pdfminer.high_level import extract_text
                 from io import BytesIO
                 return extract_text(BytesIO(content))
 
         elif filename.endswith(('.docx', '.doc')):
             doc = Document(io.BytesIO(content))
             return "\n".join(para.text for para in doc.paragraphs)
 
         elif filename.endswith('.txt'):
             return content.decode('utf-8', errors='replace')
 
         elif filename.endswith('.rtf'):
             text = content.decode('utf-8', errors='replace')
             text = re.sub(r'\\[a-z]+', ' ', text)
             text = re.sub(r'\{|\}|\\', '', text)
             return text
 
         else:
             raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")
 
     except Exception as e:
         logger.error(f"File extraction error: {str(e)}", exc_info=True)
-        raise HTTPException(
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
+        )
 
 def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
     """Generate visualization code based on data analysis"""
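Note (not part of the commit): the PDF branch above tries PyMuPDF first and falls back to pdfminer.six. The same two-step fallback as a standalone helper, a sketch operating on raw bytes:

```python
import fitz  # PyMuPDF

def pdf_bytes_to_text(content: bytes) -> str:
    """Extract text with PyMuPDF, falling back to pdfminer.six on failure."""
    try:
        doc = fitz.open(stream=content, filetype="pdf")
        return "".join(page.get_text() for page in doc)
    except Exception:
        from io import BytesIO
        from pdfminer.high_level import extract_text
        return extract_text(BytesIO(content))
```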
@@ -583,53 +580,6 @@ plt.savefig('distribution_plot.png')
 plt.show()
 print(df['{num_col}'].describe())"""
 
-    elif len(numeric_cols) >= 3 and ("pairplot" in request_lower or "multi" in request_lower):
-        return f"""import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-df = pd.read_excel('data.xlsx')
-plt.figure(figsize=(12, 10))
-sns.set(style="ticks")
-plot = sns.pairplot(df[{numeric_cols[:5]}], diag_kind='kde', plot_kws={{'alpha': 0.6}})
-plot.fig.suptitle('Correlation Matrix of Numeric Variables', y=1.02, fontsize=16)
-plt.tight_layout()
-plt.savefig('pairplot.png')
-plt.show()
-correlation_matrix = df[{numeric_cols[:5]}].corr()
-plt.figure(figsize=(10, 8))
-sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
-plt.title('Correlation Matrix')
-plt.tight_layout()
-plt.savefig('correlation_matrix.png')
-plt.show()"""
-
-    elif len(date_cols) >= 1 and len(numeric_cols) >= 1 and ("time" in request_lower or "trend" in request_lower):
-        date_col = date_cols[0]
-        num_col = numeric_cols[0]
-        return f"""import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-import matplotlib.dates as mdates
-df = pd.read_excel('data.xlsx')
-df['{date_col}'] = pd.to_datetime(df['{date_col}'])
-df = df.sort_values(by='{date_col}')
-plt.figure(figsize=(12, 6))
-plt.plot(df['{date_col}'], df['{num_col}'], marker='o', linestyle='-', color='#7b2cbf', linewidth=2, markersize=6)
-plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
-plt.gca().xaxis.set_major_locator(mdates.AutoDateLocator())
-plt.title('Trend of {num_col} over time', fontsize=15)
-plt.xlabel('Date', fontsize=12)
-plt.ylabel('{num_col}', fontsize=12)
-plt.grid(True, alpha=0.3)
-plt.xticks(rotation=45)
-plt.tight_layout()
-plt.savefig('time_series.png')
-plt.show()
-from scipy import stats
-x = np.arange(len(df))
-slope, intercept, r_value, p_value, std_err = stats.linregress(x, df['{num_col}'])
-print(f"Trend: {{'Positive' if slope > 0 else 'Negative'}}, Slope: {{slope:.4f}}, R²: {{r_value**2:.4f}}")"""
-
     else:
         return f"""import pandas as pd
 import matplotlib.pyplot as plt
@@ -669,7 +619,7 @@ plt.show()"""
 @app.get("/", include_in_schema=False)
 async def home():
     """Redirect to the static index.html file"""
-    return RedirectResponse(url="/static/
+    return RedirectResponse(url="/static/tito.html")
 
 @app.get("/health", include_in_schema=True)
 async def health_check():
@@ -684,36 +634,35 @@ async def list_models():
 @app.on_event("startup")
 async def startup_event():
     """Pre-load models at startup with timeout"""
+    global translation_model, translation_tokenizer
     logger.info("Starting model pre-loading...")
-
-    # Load Gemini models synchronously
-    for task in ["chatbot", "translation"]:
-        try:
-            load_model(task)  # Synchronous call
-            logger.info(f"Successfully loaded {task} model")
-        except Exception as e:
-            logger.error(f"Error pre-loading {task}: {str(e)}")
-
-    # Load Hugging Face models asynchronously
+
     async def load_model_with_timeout(task):
         try:
-            await asyncio.wait_for(load_model
+            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
             logger.info(f"Successfully loaded {task} model")
         except asyncio.TimeoutError:
             logger.warning(f"Timeout loading {task} model - will load on demand")
         except Exception as e:
             logger.error(f"Error pre-loading {task}: {str(e)}")
-
+
+    try:
+        model_name = MODELS["translation"]
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        translation_model.to(device)
+        logger.info("Translation model pre-loaded successfully")
+    except Exception as e:
+        logger.error(f"Error pre-loading translation model: {str(e)}")
+
     await asyncio.gather(
         load_model_with_timeout("summarization"),
         load_model_with_timeout("image-to-text"),
-        load_model_with_timeout("visual-qa")
+        load_model_with_timeout("visual-qa"),
+        load_model_with_timeout("chatbot")
     )
 
 if __name__ == "__main__":
     import uvicorn
-    # Ensure the upload_dir is writable
-    logger.info(f"Checking write permissions for {upload_dir}")
-    if not os.access(upload_dir, os.W_OK):
-        logger.error(f"No write permissions for {upload_dir}")
     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
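Note (not part of the commit): the key startup change is bounding each blocking loader with `asyncio.wait_for(asyncio.to_thread(...))` so a slow model download cannot hang the event loop. The pattern in isolation, as a sketch (Python 3.9+; the one-second sleep stands in for a model download):

```python
import asyncio
import time

def blocking_load(task: str) -> str:
    time.sleep(1)  # stands in for a slow model download
    return f"{task} ready"

async def main():
    try:
        # Run the blocking call in a worker thread, capped at 60 seconds
        result = await asyncio.wait_for(asyncio.to_thread(blocking_load, "summarization"), timeout=60.0)
        print(result)
    except asyncio.TimeoutError:
        print("timeout - will load on demand")

asyncio.run(main())
```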