Spaces:
Sleeping
Sleeping
from fastapi import FastAPI | |
from fastapi.responses import JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import tensorflow as tf | |
import joblib | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository that hosts both the trained model and its tokenizer.
_HF_REPO_ID = "rio3210/amharic-hate-speech-using-rnn-bidirectional"

# Fetch both artifacts from the Hub and keep their local file paths.
model_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="amharic_hate_speech_rnn_model.keras",
)
tokenizer_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="tokenizer.joblib",
)

# Deserialize the trained Keras network used by the classification handler.
keras_model = tf.keras.models.load_model(model_path)

# NOTE(review): joblib.load unpickles the downloaded file, which can execute
# arbitrary code; only load tokenizers from a repository you control/trust.
tokenizer = joblib.load(tokenizer_path)
# The ASGI application that the request handlers are attached to.
app = FastAPI()

# Cross-origin settings: every origin, method and header is accepted.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests;
# check how Starlette's CORSMiddleware resolves wildcard origins combined with
# allow_credentials=True, and consider an explicit origin list if cookies or
# auth headers are ever used.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# JSON request body accepted by the classification handler.
# (Documented with comments rather than a docstring: Pydantic would copy a
# docstring into the generated OpenAPI schema description.)
class ClassifyRequest(BaseModel):
    # The raw sentence to classify (presumably Amharic, given the model
    # repository name -- confirm with the model owner).
    text: str
def preprocess_text(text: str, tokenizer, max_length: int = 100):
    """Encode a single string into the fixed-length integer batch the model expects.

    Args:
        text: The sentence to encode.
        tokenizer: A fitted tokenizer exposing ``texts_to_sequences``.
        max_length: Target sequence length. Longer inputs are truncated from
            the end and shorter ones are padded at the end.
            NOTE(review): presumably must match the length/padding used at
            training time -- confirm.

    Returns:
        The ``pad_sequences`` output for a batch containing just ``text``.
    """
    token_ids = tokenizer.texts_to_sequences([text])  # batch of one
    pad = tf.keras.preprocessing.sequence.pad_sequences
    return pad(token_ids, maxlen=max_length, padding="post", truncating="post")
# Classification route.
# NOTE(review): no route decorator was present, so this handler was never
# registered with the app. The "/classify" path is inferred from the function
# name -- confirm it against the clients that call this service.
@app.post("/classify")
def classify_text(request_body: ClassifyRequest):
    """Classify the submitted text as hate speech ("Hate") or not ("Free").

    Args:
        request_body: Parsed JSON body carrying the ``text`` to classify.

    Returns:
        A JSONResponse (status 201) with body
        ``{"label": "Hate" | "Free", "confidence": float}``. ``confidence`` is
        the model's raw output score (the value compared against 0.5), reported
        unchanged whichever label is chosen.

    Raises:
        ValueError: If the model does not produce exactly one score.
    """
    processed_text = preprocess_text(request_body.text, tokenizer)
    # verbose=0 stops Keras from printing a progress bar on every request.
    prediction = keras_model.predict(processed_text, verbose=0)
    # Collapse the (presumably (1, 1)) output to a Python float. .item() raises
    # if there is more than one value, so a multi-output model still fails
    # loudly; float() on an ndim > 0 array is deprecated since NumPy 1.25.
    score = float(np.asarray(prediction).item())
    label = "Hate" if score > 0.5 else "Free"  # fixed decision threshold
    response = {"label": label, "confidence": score}
    # NOTE(review): 201 Created is unusual for a read-only classification (200
    # is conventional); kept unchanged for compatibility with existing clients.
    return JSONResponse(content=response, status_code=201)
# Root route.
# NOTE(review): no route decorator was present, so this handler was not
# reachable. "/" is inferred from the "Root route" comment -- confirm.
@app.get("/")
def home():
    """Return a fixed greeting; can serve as a simple liveness check."""
    return {"hello": "world"}