"""Streamlit demo page: predict a tag for an article summary."""

import streamlit as st
import torch
from loguru import logger

from shad_mlops_transformers.model import DocumentClassifier
from shad_mlops_transformers.trainer import load_mapper

# NOTE: hardcoded number of output classes passed to DocumentClassifier.
# Presumably it has to match the number of labels in the mapper and in the
# saved model file -- TODO confirm and derive it instead of hardcoding.
N_CLASSES = 68


@st.cache_resource
def load_model():
    """Create the document classifier and return the result of ``from_file()``.

    Decorated with ``st.cache_resource`` so that Streamlit reruns of this
    script reuse the same object instead of rebuilding it on every
    interaction.
    """
    return DocumentClassifier(n_classes=N_CLASSES).from_file()


# Kept at module level so that ``mapper`` stays importable from this module.
# The page inverts it below, treating its values as the class indices that
# the model's argmax produces.
mapper = load_mapper()


def _predict_label(model, text, inverse_mapper):
    """Return the label in ``inverse_mapper`` for the model's top class.

    The model is called directly on the raw ``text``. The predicted class is
    the argmax over dimension 1 of its output, converted with ``.item()``, so
    the output is expected to contain a single row. Falls back to
    ``"unknown"`` when the predicted index has no entry in
    ``inverse_mapper``.
    """
    # Only the argmax of the output is used, so gradient tracking is
    # unnecessary overhead during inference.
    with torch.no_grad():
        raw_predictions = model(text)
    key = torch.argmax(raw_predictions, dim=1).item()
    to_show = inverse_mapper.get(key, "unknown")
    logger.debug(f"key={key}, to_show={to_show}")
    return to_show


def main():
    """Render the page: read a summary and show the predicted label."""
    model = load_model()
    st.markdown("### Predict tags for article summary")
    text = st.text_input("Enter your summary")

    # st.text_input returns an empty string until the user types something;
    # running the classifier on it would display a meaningless label.
    if not text.strip():
        return

    inverse_mapper = {v: k for k, v in mapper.items()}
    to_show = _predict_label(model, text, inverse_mapper)
    st.markdown(f"predicted label: {to_show}")


if __name__ == "__main__":
    main()