Hadiil committed
Commit 110a761 · verified · 1 Parent(s): b755f09

Update app.py

Files changed (1)
  1. app.py +187 -238
app.py CHANGED
@@ -1,29 +1,25 @@
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
-from fastapi.templating import Jinja2Templates
-from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering
+from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
 from typing import Optional, Dict, Any, List
 import logging
 import time
 import os
 import io
 import json
+import re
 from PIL import Image
 from docx import Document
 import fitz # PyMuPDF
 import pandas as pd
 from functools import lru_cache
-import re
 import torch
 import numpy as np
 from pydantic import BaseModel
 import asyncio
 import google.generativeai as genai
 
-# Set the TRANSFORMERS_CACHE environment variable to a writable directory
-os.environ["HF_HOME"] = "/tmp/huggingface_cache"
-
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -37,15 +33,18 @@ os.makedirs(upload_dir, exist_ok=True)
 
 app = FastAPI(
     title="Cosmic AI Assistant",
-    description="An advanced AI assistant with space-themed interface and Gemini-powered features",
+    description="An advanced AI assistant with space-themed interface and translation features",
     version="2.0.0"
 )
 
-# Mount static files
+# Mount static files
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
-# Setup templates
-templates = Jinja2Templates(directory="templates")
+# Mount videos directory
+app.mount("/videos", StaticFiles(directory="videos"), name="videos")
+
+# Mount images directory
+app.mount("/images", StaticFiles(directory="images"), name="images")
 
 # Gemini API Configuration
 API_KEY = "AIzaSyCwmgD8KxzWiuivtySNtcZF_rfTvx9s9sY" # Replace with your actual API key
@@ -55,14 +54,32 @@ genai.configure(api_key=API_KEY)
 MODELS = {
     "summarization": "sshleifer/distilbart-cnn-12-6",
     "image-to-text": "Salesforce/blip-image-captioning-large",
-    "question-answering": "deepset/roberta-base-squad2",
     "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
-    "text-generation": "gpt2-medium",
-    "chatbot": "gemini-1.5-pro",
-    "translation": "gemini-1.5-pro"
+    "chatbot": "gemini-1.5-pro", # Handles both chat and text generation
+    "translation": "facebook/m2m100_418M"
+}
+
+# Supported languages for translation
+SUPPORTED_LANGUAGES = {
+    "english": "en",
+    "french": "fr",
+    "german": "de",
+    "spanish": "es",
+    "italian": "it",
+    "russian": "ru",
+    "chinese": "zh",
+    "japanese": "ja",
+    "arabic": "ar",
+    "hindi": "hi",
+    "portuguese": "pt",
+    "korean": "ko"
 }
 
-# Cache for model loading
+# Global variables for pre-loaded translation model
+translation_model = None
+translation_tokenizer = None
+
+# Cache for model loading (excluding translation)
 @lru_cache(maxsize=8)
 def load_model(task: str, model_name: str = None):
     """Cached model loader with proper task names and error handling"""
@@ -72,11 +89,8 @@ def load_model(task: str, model_name: str = None):
 
         model_to_load = model_name or MODELS.get(task)
 
-        if task in ["chatbot", "translation"]:
-            logger.info(f"Initializing Gemini model: {model_to_load}")
-            model = genai.GenerativeModel(model_to_load)
-            logger.info(f"Gemini model loaded in {time.time() - start_time:.2f}s")
-            return model
+        if task == "chatbot": # Gemini handles both chat and text generation
+            return genai.GenerativeModel(model_to_load)
 
         if task == "visual-qa":
             processor = ViltProcessor.from_pretrained(model_to_load)
@@ -97,48 +111,96 @@
                 logger.info(f"VQA raw output: {answer}")
                 return answer
 
-            logger.info(f"Visual QA model loaded in {time.time() - start_time:.2f}s")
             return vqa_function
 
-        model = pipeline(task, model=model_to_load)
-        logger.info(f"Pipeline model loaded in {time.time() - start_time:.2f}s")
-        return model
+        return pipeline(task, model=model_to_load)
 
     except Exception as e:
         logger.error(f"Model load failed: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
 
-def get_chatbot_response(user_input: str):
-    """Function to generate response with Gemini"""
+def get_gemini_response(user_input: str, is_generation: bool = False):
+    """Function to generate response with Gemini for both chat and text generation"""
     if not user_input:
-        return "Please ask a question."
+        return "Please provide some input."
     try:
         chatbot = load_model("chatbot")
-        response = chatbot.generate_content(user_input)
+        if is_generation:
+            prompt = f"Generate creative text based on this prompt: {user_input}"
+        else:
+            prompt = user_input
+        response = chatbot.generate_content(prompt)
         return response.text.strip()
     except Exception as e:
         return f"Error: {str(e)}"
 
 def translate_text(text: str, target_language: str):
-    """Translate text to any target language using Gemini"""
+    """Translate text to any target language using pre-loaded M2M100 model"""
     if not text:
         return "Please provide text to translate."
+
     try:
-        translator = load_model("translation")
-        prompt = f"Translate this text to {target_language}: {text}"
-        response = translator.generate_content(prompt)
-        return response.text.strip()
+        global translation_model, translation_tokenizer
+
+        target_lang = target_language.lower()
+        if target_lang not in SUPPORTED_LANGUAGES:
+            similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
+            if similar:
+                target_lang = similar[0]
+            else:
+                return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"
+
+        lang_code = SUPPORTED_LANGUAGES[target_lang]
+
+        if translation_model is None or translation_tokenizer is None:
+            raise Exception("Translation model not initialized")
+
+        match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
+        if match:
+            text_to_translate = match.group(1)
+        else:
+            content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
+            text_to_translate = content_match.group(1) if content_match else text
+
+        translation_tokenizer.src_lang = "en"
+        encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)
+
+        start_time = time.time()
+        generated_tokens = translation_model.generate(
+            **encoded,
+            forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
+            max_length=512,
+            num_beams=1,
+            early_stopping=True
+        )
+        translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+        logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
+
+        return translated_text
+
     except Exception as e:
+        logger.error(f"Translation error: {str(e)}", exc_info=True)
         return f"Translation error: {str(e)}"
 
 def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
-    """Enhanced intent detection with dynamic translation support including bracketed languages"""
+    """Enhanced intent detection with dynamic translation support"""
     target_language = "English" # Default
 
     if file:
         content_type = file.content_type.lower() if file.content_type else ""
         filename = file.filename.lower() if file.filename else ""
 
+        # Added: Catch "what’s this" and "does this fly" first for images
+        if content_type.startswith('image/') and text:
+            text_lower = text.lower()
+            if "what’s this" in text_lower:
+                return "visual-qa", target_language
+            if "does this fly" in text_lower:
+                return "visual-qa", target_language
+            # Added: Broaden "fly" questions for VQA
+            if "fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will']):
+                return "visual-qa", target_language
+
         if content_type.startswith('image/'):
            if text and any(q in text.lower() for q in ['what is', 'what\'s', 'describe', 'tell me about', 'explain','how many', 'what color', 'is there', 'are they', 'does the']):
                 return "visual-qa", target_language
@@ -156,11 +218,22 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
         if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
             return "chatbot", target_language
 
-        # Dynamic translation detection with optional brackets
-        translate_match = re.search(r'translate.*to\s+\[?([a-zA-Z]+)\]?:?', text_lower)
-        if translate_match:
-            target_language = translate_match.group(1).capitalize()
-            return "translate", target_language
+        translate_patterns = [
+            r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
+            r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
+            r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
+        ]
+
+        for pattern in translate_patterns:
+            translate_match = re.search(pattern, text_lower)
+            if translate_match:
+                potential_lang = translate_match.group(1).lower()
+                if potential_lang in SUPPORTED_LANGUAGES:
+                    target_language = potential_lang.capitalize()
+                    return "translate", target_language
+                else:
+                    logger.warning(f"Invalid language detected: {potential_lang}")
+                    return "chatbot", target_language
 
         vqa_patterns = [
             r'how (many|much)',
@@ -184,15 +257,6 @@
 
         if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
             return "summarize", target_language
-
-        question_patterns = [
-            r'\b(what|when|where|why|how|who|which)\b',
-            r'\?',
-            r'\b(explain|tell me|describe|define)\b'
-        ]
-
-        if any(re.search(pattern, text_lower) for pattern in question_patterns):
-            return "question-answering", target_language
 
         generation_patterns = [
             r'\b(write|generate|create|compose)\b',
@@ -205,6 +269,12 @@
         if len(text) > 100:
             return "summarize", target_language
 
+    if file and file.content_type and file.content_type.startswith('image/'):
+        if text and "what’s this" in text_lower:
+            return "visual-qa", target_language
+        if text and any(q in text_lower for q in ['does this', 'is this', 'can this']):
+            return "visual-qa", target_language
+
     return "chatbot", target_language
 
 class ProcessResponse(BaseModel):
@@ -212,85 +282,10 @@
     type: str
     additional_data: Optional[Dict[str, Any]] = None
 
-# Chatbot Web Interface with Translation Option
-@app.get("/chatbot", response_class=HTMLResponse)
+@app.get("/chatbot")
 async def chatbot_interface():
-    return """
-    <!DOCTYPE html>
-    <html lang="en">
-    <head>
-        <meta charset="UTF-8">
-        <meta name="viewport" content="width=device-width, initial-scale=1.0">
-        <title>Cosmic AI Chatbot</title>
-        <style>
-            body { font-family: Arial, sans-serif; background: #282c34; color: white; text-align: center; }
-            .chat-box { width: 50%; margin: 20px auto; background: #444; padding: 20px; border-radius: 10px; }
-            #chat-box { max-height: 400px; overflow-y: auto; text-align: left; }
-            input, button, select { width: 80%; padding: 10px; margin: 5px; border-radius: 5px; }
-            button { background: #0084ff; color: white; cursor: pointer; }
-            .message { margin: 10px 0; padding: 10px; border-radius: 5px; }
-            .user-message { background: #555; }
-            .bot-message { background: #666; }
-        </style>
-    </head>
-    <body>
-        <div class="chat-box">
-            <h1>Cosmic AI Chatbot</h1>
-            <div id="chat-box">
-                <div class="message bot-message">
-                    <b>Bot:</b> Hello! I am your Cosmic AI Assistant. Upload a file or ask a question, and I can:<br>
-                    - Summarize documents<br>
-                    - Describe images<br>
-                    - Answer your questions<br>
-                    - Translate text to any language<br>
-                    - Generate visualization code
-                </div>
-            </div>
-            <input type="text" id="user-input" placeholder="Type your message...">
-            <select id="translate-to">
-                <option value="">No translation</option>
-                <option value="English">English</option>
-                <option value="French">French</option>
-                <option value="German">German</option>
-                <option value="Spanish">Spanish</option>
-                <option value="Italian">Italian</option>
-                <option value="Russian">Russian</option>
-                <option value="Chinese">Chinese</option>
-                <option value="Japanese">Japanese</option>
-            </select>
-            <button onclick="sendMessage()">Send</button>
-        </div>
-        <script>
-            async function sendMessage() {
-                let inputField = document.getElementById("user-input");
-                let translateTo = document.getElementById("translate-to").value;
-                let chatBox = document.getElementById("chat-box");
-                let userMessage = inputField.value.trim();
-                if (!userMessage) return;
-
-                let messageToSend = translateTo ? `Translate this to ${translateTo}: ${userMessage}` : userMessage;
-
-                chatBox.innerHTML += `<div class="message user-message"><b>You:</b> ${userMessage}</div>`;
-                let response = await fetch("/chat", {
-                    method: "POST",
-                    headers: { "Content-Type": "application/json" },
-                    body: JSON.stringify({ message: messageToSend })
-                });
-                let result = await response.json();
-                chatBox.innerHTML += `<div class="message bot-message"><b>Bot:</b> ${result.response}</div>`;
-                inputField.value = "";
-                chatBox.scrollTop = chatBox.scrollHeight;
-            }
-
-            document.getElementById("user-input").addEventListener("keypress", function(e) {
-                if (e.key === "Enter") {
-                    sendMessage();
-                }
-            });
-        </script>
-    </body>
-    </html>
-    """
+    """Redirect to the static index.html file for the chatbot interface"""
+    return RedirectResponse(url="/static/index.html")
 
 @app.post("/chat")
 async def chat_endpoint(data: dict):
@@ -298,7 +293,7 @@ async def chat_endpoint(data: dict):
     if not message:
         raise HTTPException(status_code=400, detail="No message provided")
    try:
-        response = get_chatbot_response(message)
+        response = get_gemini_response(message)
         return {"response": response}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
@@ -319,14 +314,30 @@ async def process_input(
 
     try:
         if intent == "chatbot":
-            response = get_chatbot_response(text)
+            response = get_gemini_response(text)
             return {"response": response, "type": "chat"}
 
         elif intent == "translate":
             content = await extract_text_from_file(file) if file else text
-            content = re.sub(r'translate.*to\s+\[?[a-zA-Z]+\]?:?\s*', '', content, flags=re.IGNORECASE).strip()
-            translated_text = translate_text(content, target_language)
-            return {"response": translated_text, "type": "translation"}
+            if "all languages" in text.lower():
+                translations = {}
+                phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
+                for lang, code in SUPPORTED_LANGUAGES.items():
+                    translation_tokenizer.src_lang = "en"
+                    encoded = translation_tokenizer(phrase_to_translate, return_tensors="pt").to(translation_model.device)
+                    generated_tokens = translation_model.generate(
+                        **encoded,
+                        forced_bos_token_id=translation_tokenizer.get_lang_id(code),
+                        max_length=512,
+                        num_beams=1
+                    )
+                    translations[lang] = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+                response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
+                logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
+                return {"response": response, "type": "translation"}
+            else:
+                translated_text = translate_text(content, target_language)
+                return {"response": translated_text, "type": "translation"}
 
         elif intent == "summarize":
             content = await extract_text_from_file(file) if file else text
@@ -364,28 +375,6 @@
             final_summary = re.sub(r'\s+', ' ', final_summary).strip()
             return {"response": final_summary, "type": "summary"}
 
-        elif intent == "question-answering":
-            context = await extract_text_from_file(file) if file else None
-
-            if not context and not text:
-                raise HTTPException(status_code=400, detail="No context provided")
-
-            qa_pipeline = load_model("question-answering")
-
-            if not context and "?" in text:
-                parts = text.split("?", 1)
-                question = parts[0] + "?"
-                context = parts[1].strip() if len(parts) > 1 and parts[1].strip() else text
-            else:
-                question = text if text else "Summarize this document"
-
-            result = qa_pipeline(
-                question=question,
-                context=context[:2000] if context else text[:2000]
-            )
-
-            return {"response": result["answer"], "type": "answer"}
-
         elif intent == "image-to-text":
             if not file or not file.content_type.startswith('image/'):
                 raise HTTPException(status_code=400, detail="An image file is required")
@@ -423,6 +412,11 @@ async def process_input(
             answer = answer.capitalize()
             if not answer.endswith(('.', '!', '?')):
                 answer += '.'
+            chatbot = load_model("chatbot")
+            if "fly" in question.lower():
+                answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
+            else:
+                answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
 
             logger.info(f"Final VQA answer: {answer}")
 
@@ -447,25 +441,19 @@
             df = pd.read_excel(io.BytesIO(file_content))
 
             code = generate_visualization_code(df, text)
+            stats = df.describe().to_string()
+            response = f"Stats:\n{stats}\n\nChart Code:\n{code}"
 
-            return {"response": code, "type": "visualization_code"}
+            return {"response": response, "type": "visualization_code"}
 
         elif intent == "text-generation":
-            generator = load_model("text-generation")
-
-            generated = generator(
-                text,
-                max_length=200,
-                num_return_sequences=1,
-                temperature=0.8,
-                top_p=0.92,
-                do_sample=True
-            )
-
-            return {"response": generated[0]["generated_text"], "type": "generated_text"}
+            response = get_gemini_response(text, is_generation=True)
+            lines = response.split(". ")
+            formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
+            return {"response": formatted_poem, "type": "generated_text"}
 
         else:
-            response = get_chatbot_response(text or "Hello! How can I assist you?")
+            response = get_gemini_response(text or "Hello! How can I assist you?")
             return {"response": response, "type": "chat"}
 
     except Exception as e:
@@ -476,42 +464,51 @@ async def process_input(
     logger.info(f"Request processed in {process_time:.2f} seconds")
 
 async def extract_text_from_file(file: UploadFile) -> str:
-    """Enhanced text extraction with better error handling and format support"""
+    """Enhanced text extraction with multiple fallbacks"""
     if not file:
         return ""
-
+
     content = await file.read()
     filename = file.filename.lower()
-
+
     try:
         if filename.endswith('.pdf'):
             try:
                 doc = fitz.open(stream=content, filetype="pdf")
+                if doc.is_encrypted:
+                    return "PDF is encrypted and cannot be read"
                 text = ""
                 for page in doc:
                     text += page.get_text()
                 return text
             except Exception as pdf_error:
-                logger.warning(f"PyMuPDF failed, trying pdfminer: {str(pdf_error)}")
+                logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
                 from pdfminer.high_level import extract_text
                 from io import BytesIO
                 return extract_text(BytesIO(content))
-
+
         elif filename.endswith(('.docx', '.doc')):
             doc = Document(io.BytesIO(content))
             return "\n".join(para.text for para in doc.paragraphs)
+
        elif filename.endswith('.txt'):
             return content.decode('utf-8', errors='replace')
+
         elif filename.endswith('.rtf'):
             text = content.decode('utf-8', errors='replace')
             text = re.sub(r'\\[a-z]+', ' ', text)
             text = re.sub(r'\{|\}|\\', '', text)
             return text
+
         else:
             raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")
+
     except Exception as e:
         logger.error(f"File extraction error: {str(e)}", exc_info=True)
-        raise HTTPException(status_code=500, detail=f"Error extracting text: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
+        )
 
 def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
     """Generate visualization code based on data analysis"""
@@ -583,53 +580,6 @@ plt.savefig('distribution_plot.png')
 plt.show()
 print(df['{num_col}'].describe())"""
 
-    elif len(numeric_cols) >= 3 and ("pairplot" in request_lower or "multi" in request_lower):
-        return f"""import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-df = pd.read_excel('data.xlsx')
-plt.figure(figsize=(12, 10))
-sns.set(style="ticks")
-plot = sns.pairplot(df[{numeric_cols[:5]}], diag_kind='kde', plot_kws={{'alpha': 0.6}})
-plot.fig.suptitle('Correlation Matrix of Numeric Variables', y=1.02, fontsize=16)
-plt.tight_layout()
-plt.savefig('pairplot.png')
-plt.show()
-correlation_matrix = df[{numeric_cols[:5]}].corr()
-plt.figure(figsize=(10, 8))
-sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
-plt.title('Correlation Matrix')
-plt.tight_layout()
-plt.savefig('correlation_matrix.png')
-plt.show()"""
-
-    elif len(date_cols) >= 1 and len(numeric_cols) >= 1 and ("time" in request_lower or "trend" in request_lower):
-        date_col = date_cols[0]
-        num_col = numeric_cols[0]
-        return f"""import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-import matplotlib.dates as mdates
-df = pd.read_excel('data.xlsx')
-df['{date_col}'] = pd.to_datetime(df['{date_col}'])
-df = df.sort_values(by='{date_col}')
-plt.figure(figsize=(12, 6))
-plt.plot(df['{date_col}'], df['{num_col}'], marker='o', linestyle='-', color='#7b2cbf', linewidth=2, markersize=6)
-plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
-plt.gca().xaxis.set_major_locator(mdates.AutoDateLocator())
-plt.title('Trend of {num_col} over time', fontsize=15)
-plt.xlabel('Date', fontsize=12)
-plt.ylabel('{num_col}', fontsize=12)
-plt.grid(True, alpha=0.3)
-plt.xticks(rotation=45)
-plt.tight_layout()
-plt.savefig('time_series.png')
-plt.show()
-from scipy import stats
-x = np.arange(len(df))
-slope, intercept, r_value, p_value, std_err = stats.linregress(x, df['{num_col}'])
-print(f"Trend: {{'Positive' if slope > 0 else 'Negative'}}, Slope: {{slope:.4f}}, R²: {{r_value**2:.4f}}")"""
-
     else:
         return f"""import pandas as pd
 import matplotlib.pyplot as plt
@@ -669,7 +619,7 @@ plt.show()"""
 @app.get("/", include_in_schema=False)
 async def home():
     """Redirect to the static index.html file"""
-    return RedirectResponse(url="/static/index.html")
+    return RedirectResponse(url="/static/tito.html")
 
 @app.get("/health", include_in_schema=True)
 async def health_check():
@@ -684,36 +634,35 @@ async def list_models():
 @app.on_event("startup")
 async def startup_event():
     """Pre-load models at startup with timeout"""
+    global translation_model, translation_tokenizer
     logger.info("Starting model pre-loading...")
-
-    # Load Gemini models synchronously
-    for task in ["chatbot", "translation"]:
-        try:
-            load_model(task) # Synchronous call
-            logger.info(f"Successfully loaded {task} model")
-        except Exception as e:
-            logger.error(f"Error pre-loading {task}: {str(e)}")
-
-    # Load Hugging Face models asynchronously
+
     async def load_model_with_timeout(task):
         try:
-            await asyncio.wait_for(load_model(task), timeout=60.0)
+            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
             logger.info(f"Successfully loaded {task} model")
         except asyncio.TimeoutError:
             logger.warning(f"Timeout loading {task} model - will load on demand")
         except Exception as e:
             logger.error(f"Error pre-loading {task}: {str(e)}")
-
+
+    try:
+        model_name = MODELS["translation"]
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        translation_model.to(device)
+        logger.info("Translation model pre-loaded successfully")
+    except Exception as e:
+        logger.error(f"Error pre-loading translation model: {str(e)}")
+
     await asyncio.gather(
         load_model_with_timeout("summarization"),
         load_model_with_timeout("image-to-text"),
-        load_model_with_timeout("visual-qa")
+        load_model_with_timeout("visual-qa"),
+        load_model_with_timeout("chatbot")
     )
 
 if __name__ == "__main__":
     import uvicorn
-    # Ensure the upload_dir is writable
-    logger.info(f"Checking write permissions for {upload_dir}")
-    if not os.access(upload_dir, os.W_OK):
-        logger.error(f"No write permissions for {upload_dir}")
     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
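For reference, a minimal standalone sketch (not part of the commit) of the M2M100 translation path this commit switches to; the model name, src_lang, forced_bos_token_id, and get_lang_id usage mirror translate_text() in the diff above.

# Sketch only: replicates the facebook/m2m100_418M flow used by translate_text()
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # the app always translates from English
encoded = tokenizer("I want to explore the stars", return_tensors="pt")
generated = model.generate(
    **encoded,
    forced_bos_token_id=tokenizer.get_lang_id("fr"),  # "fr" is the SUPPORTED_LANGUAGES code for French
    max_length=512,
    num_beams=1,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])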
 
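And a hedged client-side sketch for the /chat endpoint kept by this commit, assuming the app is running locally on the port passed to uvicorn.run() (7860); the {"message": ...} request shape matches what the removed web interface posted to /chat.

# Sketch only: POST a message to /chat and print the Gemini reply
import requests

resp = requests.post(
    "http://localhost:7860/chat",
    json={"message": "Hello, cosmic assistant!"},  # /chat expects a "message" key
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])  # chat_endpoint returns {"response": ...}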