Spaces:
Sleeping
Sleeping
File size: 1,182 Bytes
7379bd2 743cb9b 39cd2c1 743cb9b 7379bd2 743cb9b 7379bd2 743cb9b 98b2c27 39cd2c1 743cb9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import streamlit as st
import torch
from loguru import logger
from shad_mlops_transformers.model import DocumentClassifier
from shad_mlops_transformers.trainer import load_mapper
# tokenizer = AutoTokenizer.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
# model = AutoModelForTokenClassification.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)
@st.cache_resource
def load_model(n_classes: int = 68):
    """Build the document classifier and load its saved state, cached by Streamlit.

    ``st.cache_resource`` memoizes the returned object per distinct argument
    value, so script reruns reuse the same model instead of rebuilding it.

    Args:
        n_classes: Number of output classes passed to ``DocumentClassifier``.
            Defaults to 68, the value that used to be hardcoded here.
            NOTE(review): presumably this has to match the saved weights and
            the size of the label mapper -- confirm before changing it.

    Returns:
        The object returned by ``DocumentClassifier(...).from_file()``.
    """
    return DocumentClassifier(n_classes=n_classes).from_file()
# Mapping loaded from the trainer package; it is inverted below to translate
# the predicted class index back into a label.
# NOTE(review): presumably label -> index; confirm against load_mapper.
mapper = load_mapper()

if __name__ == "__main__":
    model = load_model()
    st.markdown("### Predict tags for article summary")
    # st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
    text = st.text_input("Enter your summary")

    # st.text_input yields an empty string until the user submits something;
    # skip inference then instead of showing a prediction for blank input.
    if text.strip():
        # Inference only: no need to track gradients for the forward pass.
        with torch.no_grad():
            raw_predictions = model(text)
        best_class = torch.argmax(raw_predictions, dim=1)

        inverse_mapper = {v: k for k, v in mapper.items()}
        key = best_class.item()
        # Fall back to "unknown" when the predicted index has no mapper entry.
        to_show = inverse_mapper.get(key, "unknown")
        logger.debug(f"key={key}, to_show={to_show}")
        st.markdown(f"predicted label: {to_show}")
|