Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -21,6 +21,9 @@ from pydantic import BaseModel
|
|
21 |
import asyncio
|
22 |
import google.generativeai as genai
|
23 |
|
|
|
|
|
|
|
24 |
# Configure logging
|
25 |
logging.basicConfig(
|
26 |
level=logging.INFO,
|
@@ -70,7 +73,10 @@ def load_model(task: str, model_name: str = None):
|
|
70 |
model_to_load = model_name or MODELS.get(task)
|
71 |
|
72 |
if task in ["chatbot", "translation"]:
|
73 |
-
|
|
|
|
|
|
|
74 |
|
75 |
if task == "visual-qa":
|
76 |
processor = ViltProcessor.from_pretrained(model_to_load)
|
@@ -91,9 +97,12 @@ def load_model(task: str, model_name: str = None):
|
|
91 |
logger.info(f"VQA raw output: {answer}")
|
92 |
return answer
|
93 |
|
|
|
94 |
return vqa_function
|
95 |
|
96 |
-
|
|
|
|
|
97 |
|
98 |
except Exception as e:
|
99 |
logger.error(f"Model load failed: {str(e)}")
|
@@ -676,7 +685,16 @@ async def list_models():
|
|
676 |
async def startup_event():
|
677 |
"""Pre-load models at startup with timeout"""
|
678 |
logger.info("Starting model pre-loading...")
|
679 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
680 |
async def load_model_with_timeout(task):
|
681 |
try:
|
682 |
await asyncio.wait_for(load_model(task), timeout=60.0)
|
@@ -685,15 +703,17 @@ async def startup_event():
|
|
685 |
logger.warning(f"Timeout loading {task} model - will load on demand")
|
686 |
except Exception as e:
|
687 |
logger.error(f"Error pre-loading {task}: {str(e)}")
|
688 |
-
|
689 |
await asyncio.gather(
|
690 |
load_model_with_timeout("summarization"),
|
691 |
load_model_with_timeout("image-to-text"),
|
692 |
-
load_model_with_timeout("visual-qa")
|
693 |
-
load_model_with_timeout("chatbot"),
|
694 |
-
load_model_with_timeout("translation")
|
695 |
)
|
696 |
|
697 |
if __name__ == "__main__":
|
698 |
import uvicorn
|
699 |
-
|
|
|
|
|
|
|
|
|
|
21 |
import asyncio
|
22 |
import google.generativeai as genai
|
23 |
|
24 |
+
# Set the TRANSFORMERS_CACHE environment variable to a writable directory
|
25 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
|
26 |
+
|
27 |
# Configure logging
|
28 |
logging.basicConfig(
|
29 |
level=logging.INFO,
|
|
|
73 |
model_to_load = model_name or MODELS.get(task)
|
74 |
|
75 |
if task in ["chatbot", "translation"]:
|
76 |
+
logger.info(f"Initializing Gemini model: {model_to_load}")
|
77 |
+
model = genai.GenerativeModel(model_to_load)
|
78 |
+
logger.info(f"Gemini model loaded in {time.time() - start_time:.2f}s")
|
79 |
+
return model
|
80 |
|
81 |
if task == "visual-qa":
|
82 |
processor = ViltProcessor.from_pretrained(model_to_load)
|
|
|
97 |
logger.info(f"VQA raw output: {answer}")
|
98 |
return answer
|
99 |
|
100 |
+
logger.info(f"Visual QA model loaded in {time.time() - start_time:.2f}s")
|
101 |
return vqa_function
|
102 |
|
103 |
+
model = pipeline(task, model=model_to_load)
|
104 |
+
logger.info(f"Pipeline model loaded in {time.time() - start_time:.2f}s")
|
105 |
+
return model
|
106 |
|
107 |
except Exception as e:
|
108 |
logger.error(f"Model load failed: {str(e)}")
|
|
|
685 |
async def startup_event():
|
686 |
"""Pre-load models at startup with timeout"""
|
687 |
logger.info("Starting model pre-loading...")
|
688 |
+
|
689 |
+
# Load Gemini models synchronously
|
690 |
+
for task in ["chatbot", "translation"]:
|
691 |
+
try:
|
692 |
+
load_model(task) # Synchronous call
|
693 |
+
logger.info(f"Successfully loaded {task} model")
|
694 |
+
except Exception as e:
|
695 |
+
logger.error(f"Error pre-loading {task}: {str(e)}")
|
696 |
+
|
697 |
+
# Load Hugging Face models asynchronously
|
698 |
async def load_model_with_timeout(task):
|
699 |
try:
|
700 |
await asyncio.wait_for(load_model(task), timeout=60.0)
|
|
|
703 |
logger.warning(f"Timeout loading {task} model - will load on demand")
|
704 |
except Exception as e:
|
705 |
logger.error(f"Error pre-loading {task}: {str(e)}")
|
706 |
+
|
707 |
await asyncio.gather(
|
708 |
load_model_with_timeout("summarization"),
|
709 |
load_model_with_timeout("image-to-text"),
|
710 |
+
load_model_with_timeout("visual-qa")
|
|
|
|
|
711 |
)
|
712 |
|
713 |
if __name__ == "__main__":
|
714 |
import uvicorn
|
715 |
+
# Ensure the upload_dir is writable
|
716 |
+
logger.info(f"Checking write permissions for {upload_dir}")
|
717 |
+
if not os.access(upload_dir, os.W_OK):
|
718 |
+
logger.error(f"No write permissions for {upload_dir}")
|
719 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|