|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
|
|
# ASGI application object; the route handlers in this module register on it
# through the @app.get / @app.post decorators.
app = FastAPI()

# Hugging Face model identifier (Hub repo id or local path) of the fine-tuned
# BERT sequence-classification checkpoint used for phishing detection.
MODEL_NAME = "ealvaradob/bert-finetuned-phishing"

# Load the tokenizer and the classifier once, at import time, so every request
# reuses the same instances instead of reloading them per call. The tuple is
# evaluated left to right, so the tokenizer is loaded before the model.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
)
|
|
|
|
|
class TextInput(BaseModel):
    """Request body accepted by ``POST /predict``.

    Pydantic validates that the incoming JSON object carries a ``text`` string.
    """

    # Raw text to classify; handed to the tokenizer as-is.
    text: str
|
|
|
@app.post("/predict")
def predict_spam(input_data: TextInput):
    """Classify the submitted text with the module-level BERT model.

    Args:
        input_data: Validated request body holding the text to classify.

    Returns:
        A dict echoing the input under ``"text"`` and a human-readable label
        under ``"prediction"``: ``"Phishing Email"`` when the top-scoring class
        id is 1, otherwise ``"Not Phishing Email"``.
    """
    # Encode a single sequence; anything beyond 512 tokens is truncated.
    encoded = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits

    # A single string was encoded, so the batch has exactly one row; take the
    # arg-max class id of that row.
    label_id = int(logits.argmax(dim=1)[0])

    # NOTE(review): assumes class id 1 is the phishing label for this
    # checkpoint — confirm against model.config.id2label.
    is_phishing = label_id == 1
    label = "Phishing Email" if is_phishing else "Not Phishing Email"

    return {"text": input_data.text, "prediction": label}
|
|
|
|
|
@app.get("/")
def home():
    """Return a fixed welcome message for requests to the API root."""
    greeting = "Welcome to the Spam Classifier API!"
    return {"message": greeting}
|
|