phishapi / app.py
eric707's picture
Create app.py
0ebba2e verified
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Initialize FastAPI app; the route decorators below register handlers on this instance.
app = FastAPI()
# Load the Hugging Face tokenizer and sequence-classification model for MODEL_NAME.
# Both are created once at import time and reused by every /predict request.
MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Define input structure
class TextInput(BaseModel):
    # Request body accepted by POST /predict.
    # Documented with comments rather than a docstring so the generated
    # schema for this model stays exactly as it was.

    # The raw text to classify.
    text: str
@app.post("/predict")
def predict_spam(input_data: TextInput):
    # Classify the submitted text and return it alongside a readable label.
    # Comments are used instead of a docstring so the route's generated
    # API description is not altered.
    message = input_data.text

    # Encode the text as PyTorch tensors, truncating to at most 512 tokens.
    encoded = tokenizer(
        message,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Index of the highest-scoring class for the single submitted text.
    label_id = logits.argmax(dim=1).item()

    # Class index 1 is reported as phishing; any other index as not phishing.
    if label_id == 1:
        verdict = "Phishing Email"
    else:
        verdict = "Not Phishing Email"

    return {"text": message, "prediction": verdict}
# Root Endpoint
@app.get("/")
def home():
    # Landing route: responds with a fixed greeting payload.
    greeting = {"message": "Welcome to the Spam Classifier API!"}
    return greeting