File size: 1,182 Bytes
7379bd2
743cb9b
 
39cd2c1
743cb9b
 
7379bd2
743cb9b
 
 
7379bd2
 
743cb9b
 
 
98b2c27
39cd2c1
743cb9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import streamlit as st
import torch
from loguru import logger

from shad_mlops_transformers.model import DocumentClassifier
from shad_mlops_transformers.trainer import load_mapper

# tokenizer = AutoTokenizer.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
# model = AutoModelForTokenClassification.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)


@st.cache_resource
def load_model(n_classes: int = 68):
    """Build the document classifier and restore it via ``from_file``.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    object across script reruns instead of rebuilding it on every
    interaction. The cache is keyed on the argument values.

    Args:
        n_classes: Size of the classifier's output layer. Defaults to 68,
            the value this app was previously hardcoded to.
            NOTE(review): presumably must match the checkpoint read by
            ``from_file`` and the number of entries in the label mapper --
            confirm before passing a different value.

    Returns:
        Whatever ``DocumentClassifier.from_file()`` returns (used by the
        caller as a callable model).
    """
    classifier = DocumentClassifier(n_classes=n_classes)
    return classifier.from_file()


# Mapping loaded from the trainer module. It is inverted below and looked up
# by predicted class index, so its values are the class indices
# (presumably label -> index; verify against ``load_mapper``).
mapper = load_mapper()

if __name__ == "__main__":
    model = load_model()
    st.markdown("### Predict tags for article summary")
    # st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
    text = st.text_input("Enter your summary")

    # ``st.text_input`` yields an empty string until the user types something.
    # Skip inference in that case instead of showing a label for blank input.
    if text.strip():
        # Inference only: disable autograd so no graph is built for the
        # forward pass.
        with torch.no_grad():
            raw_predictions = model(text)

        # Index of the highest-scoring class; ``.item()`` expects exactly one
        # prediction (a single input text).
        best_class = torch.argmax(raw_predictions, dim=1)
        key = best_class.item()

        # Invert the mapper so the predicted index can be turned back into
        # its key; fall back to "unknown" for indices absent from the mapper.
        inverse_mapper = {v: k for k, v in mapper.items()}
        to_show = inverse_mapper.get(key, "unknown")

        logger.debug(f"key={key}, to_show={to_show}")
        st.markdown(f"predicted label: {to_show}")