nlp_gpt_proj / pages /tgchannels.py
cdxxi's picture
initial commit
1867879
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
from torch import nn
# Загрузка модели и токенизатора (кешируем для ускорения)
@st.cache_resource
def load_model():
MODEL_NAME = "cointegrated/rubert-tiny2"
model = AutoModel.from_pretrained(MODEL_NAME, num_labels=5)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
return model, tokenizer
PATH = "models/model_weight_bert.pt"
class MyTinyBERT(nn.Module):
def __init__(self, model):
super().__init__()
self.bert = model
for param in self.bert.parameters():
param.requires_grad = False
self.linear = nn.Sequential(
nn.Linear(312, 256), nn.Dropout(0.3), nn.ReLU(), nn.Linear(256, 5)
)
def forward(self, input_ids, attention_mask):
bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
normed_bert_out = bert_out.last_hidden_state[:, 0, :]
out = self.linear(normed_bert_out)
return out
def classification_myBERT(text, model, tokenizer):
model = MyTinyBERT(model)
model.load_state_dict(torch.load(PATH, weights_only=True))
model.eval()
my_classes = {0: "Крипта", 1: "Мода", 2: "Спорт", 3: "Технологии", 4: "Финансы"}
t = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
return f'Хоть я и не ChatGPT, осмелюсь предположить, что данный текст относится к следующему классу:\n{my_classes[torch.argmax(model(t["input_ids"], t["attention_mask"])).item()]}'
# Интерфейс Streamlit
def main():
st.markdown(
"<h1 style='text-align: center;'>Классификация тематики новостей из телеграм каналов.</h1>",
unsafe_allow_html=True,
)
st.markdown("---")
col1, col2, col3 = st.columns([1, 8, 1]) # Центральная колонка шире остальных
with col2:
st.markdown(
"<h5 style='text-align: center;'>Использование классического алгоритма</h5>",
unsafe_allow_html=True,
)
# st.text("Использование классического алгоритма")
st.image("./images/Struct.png", width=500)
st.image("./images/L_A.png", width=800)
st.image("./images/C_M.png", width=800)
st.markdown(
"<h5 style='text-align: center;'>Стандартный rubert_tiny2</h5>",
unsafe_allow_html=True,
)
# st.text("Использование классического алгоритма")
st.image("./images/LogReg.png", width=800)
st.markdown(
"<h5 style='text-align: center;'>rubert_tiny2 с обучаемым fc слоем</h5>",
unsafe_allow_html=True,
)
# st.text("Использование классического алгоритма")
st.image("./images/myTinyBERT.png", width=800)
# Загрузка модели
model, tokenizer = load_model()
# Параметры генерации
with st.sidebar:
st.header("Настройки генерации")
prompt = st.text_area("Введите начальный текст:", height=100)
# Кнопка генерации
if st.sidebar.button("Сгенерировать текст"):
if not prompt:
st.warning("Введите начальный текст!")
return
st.subheader("Результаты:")
st.text(classification_myBERT(prompt, model, tokenizer))
if __name__ == "__main__":
main()