Hadiil commited on
Commit
0c83cdd
·
verified ·
1 Parent(s): 35946e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -21,6 +21,9 @@ from pydantic import BaseModel
21
  import asyncio
22
  import google.generativeai as genai
23
 
 
 
 
24
  # Configure logging
25
  logging.basicConfig(
26
  level=logging.INFO,
@@ -70,7 +73,10 @@ def load_model(task: str, model_name: str = None):
70
  model_to_load = model_name or MODELS.get(task)
71
 
72
  if task in ["chatbot", "translation"]:
73
- return genai.GenerativeModel(model_to_load)
 
 
 
74
 
75
  if task == "visual-qa":
76
  processor = ViltProcessor.from_pretrained(model_to_load)
@@ -91,9 +97,12 @@ def load_model(task: str, model_name: str = None):
91
  logger.info(f"VQA raw output: {answer}")
92
  return answer
93
 
 
94
  return vqa_function
95
 
96
- return pipeline(task, model=model_to_load)
 
 
97
 
98
  except Exception as e:
99
  logger.error(f"Model load failed: {str(e)}")
@@ -676,7 +685,16 @@ async def list_models():
676
  async def startup_event():
677
  """Pre-load models at startup with timeout"""
678
  logger.info("Starting model pre-loading...")
679
-
 
 
 
 
 
 
 
 
 
680
  async def load_model_with_timeout(task):
681
  try:
682
  await asyncio.wait_for(load_model(task), timeout=60.0)
@@ -685,15 +703,17 @@ async def startup_event():
685
  logger.warning(f"Timeout loading {task} model - will load on demand")
686
  except Exception as e:
687
  logger.error(f"Error pre-loading {task}: {str(e)}")
688
-
689
  await asyncio.gather(
690
  load_model_with_timeout("summarization"),
691
  load_model_with_timeout("image-to-text"),
692
- load_model_with_timeout("visual-qa"),
693
- load_model_with_timeout("chatbot"),
694
- load_model_with_timeout("translation")
695
  )
696
 
697
  if __name__ == "__main__":
698
  import uvicorn
699
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
21
  import asyncio
22
  import google.generativeai as genai
23
 
24
+ # Set the TRANSFORMERS_CACHE environment variable to a writable directory
25
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
26
+
27
  # Configure logging
28
  logging.basicConfig(
29
  level=logging.INFO,
 
73
  model_to_load = model_name or MODELS.get(task)
74
 
75
  if task in ["chatbot", "translation"]:
76
+ logger.info(f"Initializing Gemini model: {model_to_load}")
77
+ model = genai.GenerativeModel(model_to_load)
78
+ logger.info(f"Gemini model loaded in {time.time() - start_time:.2f}s")
79
+ return model
80
 
81
  if task == "visual-qa":
82
  processor = ViltProcessor.from_pretrained(model_to_load)
 
97
  logger.info(f"VQA raw output: {answer}")
98
  return answer
99
 
100
+ logger.info(f"Visual QA model loaded in {time.time() - start_time:.2f}s")
101
  return vqa_function
102
 
103
+ model = pipeline(task, model=model_to_load)
104
+ logger.info(f"Pipeline model loaded in {time.time() - start_time:.2f}s")
105
+ return model
106
 
107
  except Exception as e:
108
  logger.error(f"Model load failed: {str(e)}")
 
685
  async def startup_event():
686
  """Pre-load models at startup with timeout"""
687
  logger.info("Starting model pre-loading...")
688
+
689
+ # Load Gemini models synchronously
690
+ for task in ["chatbot", "translation"]:
691
+ try:
692
+ load_model(task) # Synchronous call
693
+ logger.info(f"Successfully loaded {task} model")
694
+ except Exception as e:
695
+ logger.error(f"Error pre-loading {task}: {str(e)}")
696
+
697
+ # Load Hugging Face models asynchronously
698
  async def load_model_with_timeout(task):
699
  try:
700
  await asyncio.wait_for(load_model(task), timeout=60.0)
 
703
  logger.warning(f"Timeout loading {task} model - will load on demand")
704
  except Exception as e:
705
  logger.error(f"Error pre-loading {task}: {str(e)}")
706
+
707
  await asyncio.gather(
708
  load_model_with_timeout("summarization"),
709
  load_model_with_timeout("image-to-text"),
710
+ load_model_with_timeout("visual-qa")
 
 
711
  )
712
 
713
  if __name__ == "__main__":
714
  import uvicorn
715
+ # Ensure the upload_dir is writable
716
+ logger.info(f"Checking write permissions for {upload_dir}")
717
+ if not os.access(upload_dir, os.W_OK):
718
+ logger.error(f"No write permissions for {upload_dir}")
719
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)