# textclarity / app.py
# ganna217's picture
# update
# 589b738
import logging
import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
# Set up logging: INFO-level root configuration plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the Hugging Face cache locations at a writable directory.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file; any library that reads them at import
# time will not see these values -- confirm the cache really lands here.
cache_dir = "/tmp/hf_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# Create the cache directory. exist_ok=True already makes this call idempotent,
# so a separate (race-prone) os.path.exists() check is unnecessary.
os.makedirs(cache_dir, exist_ok=True)
# Application instance that the route decorators below attach to.
app = FastAPI()

# CORS settings so a browser frontend can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Expose the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load the zero-shot classification pipeline once, at import time, so that
# request handlers can reuse the module-level `classifier`.
# NOTE(review): the zero-shot-classification task presumably needs a checkpoint
# fine-tuned for NLI; verify that UBC-NLP/ARBERTv2 qualifies -- TODO confirm.
logger.info("Loading the model...")
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="UBC-NLP/ARBERTv2",  # Switch to a better Arabic model
        tokenizer="UBC-NLP/ARBERTv2",
        # cache_dir is a from_pretrained() option; it is passed through
        # model_kwargs so the loader receives it, rather than as a bare
        # pipeline() keyword, which is not a loader argument.
        model_kwargs={"cache_dir": cache_dir},
    )
except Exception as e:
    # Log for visibility, then re-raise: the app cannot serve without a model.
    logger.error(f"Error loading model: {str(e)}")
    raise
else:
    logger.info("Model loaded successfully!")
@app.get("/")
async def index():
    """Return the frontend entry page, static/index.html."""
    logger.info("Serving index.html")
    page = FileResponse("static/index.html")
    return page
@app.post("/classify")
async def classify_text(data: dict):
    """Classify a document against caller-supplied candidate labels.

    The request body is a JSON object with two keys:

    * ``document`` -- the text to classify (a non-empty string).
    * ``labels`` -- the candidate labels, either a list or a single
      comma-separated string.

    Returns:
        On success, a dict with the ``labels`` and ``scores`` entries of the
        classifier result. On missing or unusable input, a 400 JSON response
        carrying an ``error`` message. If classification raises, a 500 JSON
        response carrying the exception text under ``error``.
    """
    logger.info(f"Received classify request with data: {data}")

    text = data.get("document")
    labels = data.get("labels")

    # Convert labels to a list if it's a comma-separated string; blank
    # entries are dropped, so a string such as " , " yields an empty list.
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",") if label.strip()]

    # Validate after normalisation so an all-blank label string is rejected
    # here instead of failing inside the classifier. A non-string document is
    # rejected as well, since the slicing and result handling below assume a
    # single string input.
    if not text or not isinstance(text, str) or not labels:
        logger.warning("Missing text or labels in request")
        # Returning a `(body, status)` tuple does not set the HTTP status in
        # FastAPI (the tuple is serialised as a JSON array with status 200),
        # so the status code is set explicitly on the response object.
        return JSONResponse(
            status_code=400,
            content={"error": "Please provide both text and labels"},
        )

    try:
        logger.info(f"Classifying text: {text[:50]}... with labels: {labels}")
        # NOTE(review): this call is synchronous inside an async handler, so
        # it occupies the event loop while it runs -- confirm that is acceptable.
        result = classifier(text, labels, multi_label=False)
        logger.info(f"Classification result: {result}")
        return {"labels": result["labels"], "scores": result["scores"]}
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error during classification: {str(e)}")
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # Listen on the port given by the PORT environment variable, falling back
    # to 7860 when it is unset, and bind on all interfaces.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)