Spaces:
Running
Running
import streamlit as st | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
from torch import nn | |
# Загрузка модели и токенизатора (кешируем для ускорения) | |
def load_model(): | |
MODEL_NAME = "cointegrated/rubert-tiny2" | |
model = AutoModel.from_pretrained(MODEL_NAME, num_labels=5) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
return model, tokenizer | |
PATH = "models/model_weight_bert.pt" | |
class MyTinyBERT(nn.Module): | |
def __init__(self, model): | |
super().__init__() | |
self.bert = model | |
for param in self.bert.parameters(): | |
param.requires_grad = False | |
self.linear = nn.Sequential( | |
nn.Linear(312, 256), nn.Dropout(0.3), nn.ReLU(), nn.Linear(256, 5) | |
) | |
def forward(self, input_ids, attention_mask): | |
bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
normed_bert_out = bert_out.last_hidden_state[:, 0, :] | |
out = self.linear(normed_bert_out) | |
return out | |
def classification_myBERT(text, model, tokenizer): | |
model = MyTinyBERT(model) | |
model.load_state_dict(torch.load(PATH, weights_only=True)) | |
model.eval() | |
my_classes = {0: "Крипта", 1: "Мода", 2: "Спорт", 3: "Технологии", 4: "Финансы"} | |
t = tokenizer(text, padding=True, truncation=True, return_tensors="pt") | |
return f'Хоть я и не ChatGPT, осмелюсь предположить, что данный текст относится к следующему классу:\n{my_classes[torch.argmax(model(t["input_ids"], t["attention_mask"])).item()]}' | |
# Интерфейс Streamlit | |
def main(): | |
st.markdown( | |
"<h1 style='text-align: center;'>Классификация тематики новостей из телеграм каналов.</h1>", | |
unsafe_allow_html=True, | |
) | |
st.markdown("---") | |
col1, col2, col3 = st.columns([1, 8, 1]) # Центральная колонка шире остальных | |
with col2: | |
st.markdown( | |
"<h5 style='text-align: center;'>Использование классического алгоритма</h5>", | |
unsafe_allow_html=True, | |
) | |
# st.text("Использование классического алгоритма") | |
st.image("./images/Struct.png", width=500) | |
st.image("./images/L_A.png", width=800) | |
st.image("./images/C_M.png", width=800) | |
st.markdown( | |
"<h5 style='text-align: center;'>Стандартный rubert_tiny2</h5>", | |
unsafe_allow_html=True, | |
) | |
# st.text("Использование классического алгоритма") | |
st.image("./images/LogReg.png", width=800) | |
st.markdown( | |
"<h5 style='text-align: center;'>rubert_tiny2 с обучаемым fc слоем</h5>", | |
unsafe_allow_html=True, | |
) | |
# st.text("Использование классического алгоритма") | |
st.image("./images/myTinyBERT.png", width=800) | |
# Загрузка модели | |
model, tokenizer = load_model() | |
# Параметры генерации | |
with st.sidebar: | |
st.header("Настройки генерации") | |
prompt = st.text_area("Введите начальный текст:", height=100) | |
# Кнопка генерации | |
if st.sidebar.button("Сгенерировать текст"): | |
if not prompt: | |
st.warning("Введите начальный текст!") | |
return | |
st.subheader("Результаты:") | |
st.text(classification_myBERT(prompt, model, tokenizer)) | |
if __name__ == "__main__": | |
main() | |