import json

import numpy as np
import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer


@st.cache_resource
def load_model(model_path):
    # Load the fine-tuned classifier and the matching base tokenizer once,
    # then keep them cached across Streamlit reruns.
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
    model.to("cpu")
    return model, tokenizer

def predict(text, model, tokenizer):
    # Class names, one per model output index.
    with open("classes.json", "r") as f:
        classes = json.load(f)
    # Tokenize the input text.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Move tensors to CPU (if not already there).
    inputs = {k: v.to("cpu") for k, v in inputs.items()}
    # Run the model without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)
    # Apply softmax to turn the logits into a probability distribution over tags.
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()
    # Sort tag indices by probability, highest first.
    idx = np.argsort(probs)[::-1]
    # Collect the most probable tags until their cumulative probability exceeds 0.95.
    tags = []
    cumsum = 0.0
    ind = 0
    while cumsum <= 0.95 and ind < len(idx):
        tags.append(classes[idx[ind]])
        cumsum += probs[idx[ind]]
        ind += 1
    return tags
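
# Example usage (a sketch; assumes classes.json holds a JSON list of tag names
# in the same order as the model's output labels):
#   tags = predict("TITLE: Some article title, SUMMARY: short abstract", model, tokenizer)
#   # -> the smallest set of top tags whose cumulative probability exceeds 0.95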
model, tokenizer = load_model("./results/checkpoint-200")
st.title("Multilabel article classification")
st.header("Based on title and summary")
st.text_input("Input title", key="title")
st.text_input("Input summary", key="summary")
if (st.session_state["title"] or st.session_state["summary"]):
query = "TITLE:" + st.session_state['title'] + ", SUMMARY:" + st.session_state['summary']
tags = predict(query, model, tokenizer)
st.text("Predicted Tags:")
st.text(", ".join(tags))