# arabicstream / app.py
# Hugging Face Space file (uploaded by user "adhsdksdjsbdk", commit 5333011, verified).
import os

import streamlit as st
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Get the token from environment variable (optional)
hf_token = os.environ.get("HF_TOKEN")
# Define model IDs
adapter_model_id = "seniormgt/arabicmgt-test"
base_model_id = "Alibaba-NLP/gte-multilingual-base"
# Define your model
class GTEClassifier(nn.Module):
    """Binary text classifier built on the multilingual GTE encoder.

    Architecture: GTE encoder -> linear pooler over the first token ->
    tanh -> (no-op) dropout -> single-logit head, trained with
    BCE-with-logits. Attribute names (``base_model``, ``pooler``,
    ``classifier``, ...) are part of the checkpoint/PEFT contract and
    must not be renamed.
    """

    def __init__(self, model_name=base_model_id):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.config = self.base_model.config
        width = self.config.hidden_size
        self.pooler = nn.Linear(width, width)
        self.pooler_activation = nn.Tanh()
        # p=0.0 makes this a no-op; kept so the module layout matches the checkpoint.
        self.dropout = nn.Dropout(0.0)
        self.classifier = nn.Linear(width, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, **kwargs):
        """Run the encoder + head; return {"loss", "logits"}.

        ``loss`` is None unless ``labels`` is given. ``inputs_embeds``
        takes precedence over ``input_ids`` when both are supplied.
        """
        encoder_inputs = {"attention_mask": attention_mask}
        if inputs_embeds is not None:
            encoder_inputs["inputs_embeds"] = inputs_embeds
        else:
            encoder_inputs["input_ids"] = input_ids
        encoded = self.base_model(**encoder_inputs)
        # First-token ([CLS]-style) representation feeds the pooler head.
        cls_state = encoded.last_hidden_state[:, 0, :]
        pooled = self.pooler_activation(self.pooler(cls_state))
        logits = self.classifier(self.dropout(pooled)).squeeze(-1)
        loss = self.loss_fn(logits, labels.float()) if labels is not None else None
        return {"loss": loss, "logits": logits}
# Load tokenizer, base classifier, and the LoRA adapter (downloads from
# the HF Hub on first run; hf_token is only needed for gated/private repos).
tokenizer = AutoTokenizer.from_pretrained(adapter_model_id, token=hf_token, trust_remote_code=True)
base_model = GTEClassifier()
peft_model = PeftModel.from_pretrained(base_model, adapter_model_id, token=hf_token)
# Fix: eval() was commented out. Without it the module stays in train mode,
# so any dropout inside the base encoder remains active at inference time
# and predictions become non-deterministic.
peft_model.eval()
# Define prediction
def classify_text(text):
inputs = tokenizer(text, max_length=512, padding=True, return_attention_mask=True, return_tensors="pt", truncation=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
with torch.no_grad():
outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs["logits"]
probs = torch.sigmoid(logits).cpu().numpy().squeeze()
pred_label = int(probs >= 0.5)
return {"label": str(pred_label), "confidence": float(probs)}
# 🔹 Streamlit UI
st.title("Text Classification (MGT Detection)")
text = st.text_area("Enter text", height=150)
if st.button("Classify") and text.strip():
result = classify_text(text)
st.json(result)
# 🔹 FastAPI endpoint
app = FastAPI()
class Input(BaseModel):
data: list
@app.post("/predict")
async def predict(request: Request):
    """Classify the text at payload["data"][0]["text"].

    Expects a JSON body shaped like {"data": [{"text": "..."}]} and
    returns {"data": [{"label": ..., "confidence": ...}]}.

    Raises:
        HTTPException: 422 when the body does not match the expected shape.
            (Previously a malformed body surfaced as an opaque 500 from an
            unhandled KeyError/IndexError/TypeError.)
    """
    payload = await request.json()
    try:
        text = payload["data"][0]["text"]
    except (KeyError, IndexError, TypeError):
        raise HTTPException(
            status_code=422,
            detail='Expected JSON body of the form {"data": [{"text": "..."}]}',
        )
    result = classify_text(text)
    return {"data": [result]}