Spaces:
Sleeping
Sleeping
File size: 2,592 Bytes
f7df7d4 d6791c3 612cae5 d6791c3 612cae5 d6791c3 f7df7d4 612cae5 f7df7d4 d6791c3 f7df7d4 612cae5 589b738 612cae5 f7df7d4 d6791c3 b6ac895 612cae5 589b738 612cae5 d6791c3 f7df7d4 612cae5 f7df7d4 d6791c3 612cae5 d6791c3 612cae5 d6791c3 612cae5 0ad264e 612cae5 d6791c3 612cae5 0ad264e d6791c3 612cae5 d6791c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import logging
import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
# Set up logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the Hugging Face cache locations at a directory under /tmp.
# NOTE(review): transformers is imported at the top of this file, i.e. before
# these variables are assigned. If the library reads them at import time they
# will not take effect here -- confirm, or set them before that import.
cache_dir = "/tmp/hf_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# exist_ok=True already makes this a no-op when the directory is present, so
# a separate (race-prone) os.path.exists() check is unnecessary.
os.makedirs(cache_dir, exist_ok=True)
app = FastAPI()

# Add CORS middleware to allow frontend requests from any origin.
# Credentials are deliberately disabled: wildcard origins must not be combined
# with allow_credentials=True, because that lets any site issue
# cookie-bearing requests. Nothing in this module reads cookies or auth
# headers, so credentialed cross-origin requests are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the contents of the local "static" directory under the /static prefix.
# The directory path is relative, so it is resolved against the process's
# working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load the zero-shot classification pipeline once, at import time, so every
# request reuses the same model instance. A load failure is logged and
# re-raised so the application does not start without a classifier.
# NOTE(review): zero-shot classification pipelines are normally built on an
# NLI-fine-tuned checkpoint -- confirm that UBC-NLP/ARBERTv2 is one.
logger.info("Loading the model...")
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="UBC-NLP/ARBERTv2",  # Switch to a better Arabic model
        tokenizer="UBC-NLP/ARBERTv2",
        # from_pretrained() options have to go through model_kwargs; a bare
        # cache_dir= keyword is handed to the pipeline constructor instead
        # and never reaches the model/tokenizer download step.
        model_kwargs={"cache_dir": cache_dir},
    )
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error("Error loading model: %s", e)
    raise
@app.get("/")
async def index():
    """Serve the front-end entry page.

    Logs the request and returns a ``FileResponse`` for
    ``static/index.html`` (a relative path).
    """
    logger.info("Serving index.html")
    landing_page = FileResponse("static/index.html")
    return landing_page
@app.post("/classify")
async def classify_text(data: dict):
    """Classify a document against candidate labels with the zero-shot model.

    Args:
        data: Parsed JSON request body. ``document`` holds the text to
            classify; ``labels`` holds the candidate labels, either as a
            list or as one comma-separated string.

    Returns:
        On success, a dict with the ``labels`` and ``scores`` entries taken
        from the classifier result. On failure, a ``JSONResponse`` whose body
        is ``{"error": <message>}`` with status 400 (text or labels missing
        or empty) or 500 (the classifier raised).
    """
    logger.info(f"Received classify request with data: {data}")
    text = data.get("document")
    labels = data.get("labels")

    # Normalise a comma-separated string to a list *before* validating, so
    # input such as " , " is rejected as empty instead of reaching the model.
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",") if label.strip()]

    if not text or not labels:
        logger.warning("Missing text or labels in request")
        # Returning a (body, status) tuple does not set the HTTP status in
        # FastAPI -- the tuple is serialised as a JSON array with status 200.
        # An explicit response object is required.
        return JSONResponse(
            status_code=400,
            content={"error": "Please provide both text and labels"},
        )

    try:
        logger.info(f"Classifying text: {text[:50]}... with labels: {labels}")
        result = classifier(text, labels, multi_label=False)
        logger.info(f"Classification result: {result}")
        return {"labels": result["labels"], "scores": result["scores"]}
    except Exception as e:
        # Top-level request boundary: log with traceback and report a 500
        # instead of letting the exception escape.
        logger.exception("Error during classification: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # Script entry point: start uvicorn on all interfaces. The listening port
    # is taken from the PORT environment variable, falling back to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")