import os
import torch
import torch.nn as nn
import streamlit as st
from pydantic import BaseModel
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
# Get the token from the environment variable (optional; only needed for private repos)
hf_token = os.environ.get("HF_TOKEN")

# Define model IDs: the PEFT adapter and the base encoder it was trained on
adapter_model_id = "seniormgt/arabicmgt-test"
base_model_id = "Alibaba-NLP/gte-multilingual-base"
# Define the model: GTE encoder with a CLS-pooling binary classification head
class GTEClassifier(nn.Module):
    def __init__(self, model_name=base_model_id):
        super(GTEClassifier, self).__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.config = self.base_model.config
        self.pooler = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.pooler_activation = nn.Tanh()
        self.dropout = nn.Dropout(0.0)
        self.classifier = nn.Linear(self.config.hidden_size, 1)  # single logit for binary classification
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, **kwargs):
        if inputs_embeds is not None:
            outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the [CLS] token, then apply the dense + tanh pooler
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.pooler(pooled_output)
        pooled_output = self.pooler_activation(pooled_output)
        logits = self.classifier(self.dropout(pooled_output)).squeeze(-1)
        loss = None
        if labels is not None:
            labels = labels.float()
            loss = self.loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
# Load tokenizer and base model, then attach the fine-tuned adapter
tokenizer = AutoTokenizer.from_pretrained(adapter_model_id, token=hf_token, trust_remote_code=True)
base_model = GTEClassifier()
peft_model = PeftModel.from_pretrained(base_model, adapter_model_id, token=hf_token)
peft_model.eval()  # inference mode: disables dropout
# Prediction: tokenize, run the model, and turn the logit into a label + confidence
def classify_text(text):
    inputs = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs["logits"]
    probs = torch.sigmoid(logits).cpu().numpy().squeeze()
    pred_label = int(probs >= 0.5)  # threshold the sigmoid probability at 0.5
    return {"label": str(pred_label), "confidence": float(probs)}
# 🔹 Streamlit UI
st.title("Text Classification (MGT Detection)")
text = st.text_area("Enter text", height=150)
if st.button("Classify") and text.strip():
    result = classify_text(text)
    st.json(result)
# 🔹 FastAPI endpoint
app = FastAPI()

class Input(BaseModel):
    data: list  # documents the expected payload shape: {"data": [{"text": "..."}]}

@app.post("/predict")  # the route path here is an assumption; adjust to your deployment
async def predict(request: Request):
    payload = await request.json()
    text = payload["data"][0]["text"]
    result = classify_text(text)
    return {"data": [result]}