File size: 2,592 Bytes
f7df7d4
 
d6791c3
612cae5
d6791c3
 
 
612cae5
 
 
 
 
d6791c3
f7df7d4
 
 
 
612cae5
f7df7d4
 
 
 
d6791c3
f7df7d4
612cae5
 
 
 
589b738
612cae5
 
 
 
 
f7df7d4
d6791c3
b6ac895
612cae5
 
 
 
589b738
 
612cae5
 
 
 
 
 
d6791c3
f7df7d4
 
612cae5
f7df7d4
d6791c3
 
 
612cae5
d6791c3
 
 
 
612cae5
d6791c3
612cae5
 
 
0ad264e
612cae5
 
d6791c3
612cae5
0ad264e
d6791c3
612cae5
d6791c3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point every Hugging Face cache variable at one writable directory.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file; if the library resolves its cache
# paths at import time they may not take effect here -- confirm.
cache_dir = "/tmp/hf_cache"
for _cache_env_var in ("HF_HOME", "HUGGINGFACE_HUB_CACHE", "TRANSFORMERS_CACHE"):
    os.environ[_cache_env_var] = cache_dir

# exist_ok=True already tolerates an existing directory, so no separate
# existence check is needed (and no window opens between check and creation).
os.makedirs(cache_dir, exist_ok=True)

# ASGI application object; the middleware, static mount and routes in this
# file are all registered on it.
app = FastAPI()

# Let browser frontends on any origin call the API.
# Credentials are disabled on purpose: FastAPI's CORS documentation states
# that allow_origins / allow_methods / allow_headers may not be ["*"] when
# allow_credentials is True. Nothing in this file relies on cookies or
# Authorization headers, so wildcard access without credentials is enough.
# To support credentialed requests, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Load the zero-shot classification pipeline once, at import time, so every
# request reuses the same model instance.
#
# The cache location is passed through `model_kwargs`: per the transformers
# documentation, `model_kwargs` is forwarded to `from_pretrained(...)`, while
# other extra keyword arguments of `pipeline()` go to the pipeline object
# itself and do not affect where files are downloaded.
#
# NOTE(review): the zero-shot-classification task relies on a model with an
# NLI (entailment) classification head -- confirm this checkpoint provides one.
logger.info("Loading the model...")
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="UBC-NLP/ARBERTv2",  # Switch to a better Arabic model
        tokenizer="UBC-NLP/ARBERTv2",
        model_kwargs={"cache_dir": cache_dir},
    )
except Exception as e:
    # Log for visibility, then re-raise: the app cannot serve without a model.
    logger.error("Error loading model: %s", e)
    raise
else:
    logger.info("Model loaded successfully!")

@app.get("/")
async def index():
    """Return the frontend page stored at ``static/index.html``."""
    logger.info("Serving index.html")
    page = FileResponse("static/index.html")
    return page

@app.post("/classify")
async def classify_text(data: dict):
    """Classify a document against caller-supplied candidate labels.

    The JSON body is read as a dict with two keys:

    * ``"document"`` -- the text to classify.
    * ``"labels"`` -- either a list of labels or a single comma-separated
      string, which is split and stripped into a list.

    Returns:
        ``{"labels": ..., "scores": ...}`` taken from the classifier result
        on success. On failure, a ``JSONResponse`` whose body is
        ``{"error": <message>}`` with HTTP status 400 (missing/empty input)
        or 500 (any exception raised while classifying).
    """
    logger.info("Received classify request with data: %s", data)
    try:
        text = data.get("document")
        labels = data.get("labels")

        # Convert labels to list if it's a string
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]

        # Validate after normalisation so a string such as "," (which yields
        # no labels) is rejected as a client error instead of reaching the
        # model.
        if not text or not labels:
            logger.warning("Missing text or labels in request")
            # FastAPI does not interpret a `(body, status)` tuple; the status
            # code has to be set on an explicit response object.
            return JSONResponse(
                status_code=400,
                content={"error": "Please provide both text and labels"},
            )

        logger.info("Classifying text: %s... with labels: %s", text[:50], labels)
        result = classifier(text, labels, multi_label=False)
        logger.info("Classification result: %s", result)
        return {"labels": result["labels"], "scores": result["scores"]}
    except Exception as e:
        # Top-level boundary for this endpoint: record the traceback and
        # report the failure to the client with a 500 status.
        logger.exception("Error during classification: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})

if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, taking the port
    # from the PORT environment variable and falling back to 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
    )