Spaces:
Running
Running
import logging

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any method
# and any header. This is meant for development only; restrict it before
# production use.
# NOTE(review): Starlette's CORS docs state that for credentials to be allowed,
# origins, methods and headers must be listed explicitly rather than "*", so
# allow_credentials=True is likely ineffective with these wildcards. Confirm,
# then either list the allowed origins or drop the credentials flag.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Hugging Face Hub checkpoint used for detection.
# Alternative checkpoint: "SuperAnnotate/ai-detector".
# NOTE(review): the prediction code reads class index 0 as "human" and 1 as
# "AI". Re-check that ordering (model.config.id2label) if the checkpoint
# is switched.
model_name = "fakespot-ai/roberta-base-ai-text-detection-v1"

# Loaded once at import time; every request reuses these module-level objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
class TextRequest(BaseModel):
    """Request body for the text-classification endpoint."""

    # Raw text to classify; the prediction code hands it to the tokenizer as-is.
    text: str
logger = logging.getLogger(__name__)


@app.post("/predict")
async def predict(request: TextRequest):
    """Classify ``request.text`` as human-written or AI-generated.

    The text is tokenized (truncated to at most 512 tokens) and run through
    the sequence-classification model. The logits are then converted to
    probabilities with a softmax over the last dimension.

    Args:
        request: Request body carrying the ``text`` to classify.

    Returns:
        A dict with the original ``text``, ``human_probability`` and
        ``ai_probability`` as percentages rounded to two decimals, and a
        ``prediction`` label. The label is "AI-generated" when the AI
        probability is strictly higher, otherwise "Human-written", so ties
        go to "Human-written".

    Raises:
        HTTPException: status 500, with the error message as ``detail``, if
            tokenization or inference fails.
    """
    try:
        inputs = tokenizer(
            request.text, return_tensors="pt", truncation=True, max_length=512
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # NOTE(review): assumes class index 0 = human and 1 = AI for this
        # checkpoint; confirm against model.config.id2label.
        human_prob = predictions[0][0].item()
        ai_prob = predictions[0][1].item()
    except Exception as e:
        # HTTPException responses are not logged with a traceback by FastAPI,
        # so record the full failure server-side before converting it.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": request.text,
        "human_probability": round(human_prob * 100, 2),
        "ai_probability": round(ai_prob * 100, 2),
        "prediction": "AI-generated" if ai_prob > human_prob else "Human-written",
    }
@app.get("/")
async def root():
    """Return a static message confirming the API is up."""
    return {"message": "AI Text Detection API is running"}
if __name__ == "__main__":
    # uvicorn is imported here so it is only required when this file is
    # executed directly as a script.
    import uvicorn

    # Listen on all interfaces. Port 7860 is presumably what the hosting
    # Space expects; confirm before changing.
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)