Slava256's picture
11
9b81f33 verified
raw
history blame
4.91 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Загрузка моделей и токенизаторов
@st.cache_resource
def load_models():
# Модель для классификации категорий
#category_model = AutoModelForSequenceClassification.from_pretrained("./news_classifier")
category_model = AutoModelForSequenceClassification.from_pretrained(
"./news_classifier",
trust_remote_code=True,
device_map=None, # Отключаем автоматическое распределение по устройствам
low_cpu_mem_usage=False, # Отключаем оптимизацию памяти
)
category_tokenizer = AutoTokenizer.from_pretrained("./news_classifier")
# Модель для детекции фейков
#fake_model = AutoModelForSequenceClassification.from_pretrained("./fake_detector")
fake_model = AutoModelForSequenceClassification.from_pretrained(
"./fake_detector",
trust_remote_code=True,
device_map=None, # Отключаем автоматическое распределение по устройствам
low_cpu_mem_usage=False, # Отключаем оптимизацию памяти
)
fake_tokenizer = AutoTokenizer.from_pretrained("./fake_detector")
return category_model, category_tokenizer, fake_model, fake_tokenizer
category_model, category_tokenizer, fake_model, fake_tokenizer = load_models()
id_to_category = {
0: "Климат",
1: "Конфликты",
2: "Культура",
3: "Экономика",
4: "Обзор",
5: "Здоровье",
6: "Политика",
7: "Наука",
8: "Общество",
9: "Спорт",
10: "Путешествия"
}
def predict_category(text):
inputs = category_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
with torch.no_grad():
logits = category_model(**inputs).logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
# Собираем топ-95%
sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=0)
top_indices = sorted_indices[cumulative_probs <= 0.95]
if len(top_indices) == 0: # Если даже первая вероятность >95%
top_indices = sorted_indices[:1]
results = []
for idx in top_indices:
results.append({
"label": id_to_category[idx.item()],
"score": probabilities[idx].item()
})
return results
def predict_fake(text):
inputs = fake_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
with torch.no_grad():
logits = fake_model(**inputs).logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
is_fake = torch.argmax(probabilities).item()
confidence = probabilities[is_fake].item()
return {
"is_fake": bool(is_fake),
"confidence": confidence,
"label": "Фейк" if is_fake else "Реальная"
}
# Настройка интерфейса Streamlit
st.title("Анализ новостных статей")
st.write("Введите текст статьи для анализа:")
# Поля ввода
input_title = st.text_area("Заголовок статьи", height=100)
input_text = st.text_area("Текст статьи (опционально)", height=200)
# Кнопка для запуска предсказания
if st.button("Анализировать"):
if input_title:
# Объединяем заголовок и текст, если текст есть
full_text = input_title
if input_text:
full_text += " " + input_text
# Получаем предсказания
with st.spinner("Анализируем статью..."):
categories = predict_category(full_text)
fake_result = predict_fake(full_text)
# Выводим результаты
st.subheader("Результаты анализа")
# Детекция фейков
fake_color = "red" if fake_result["is_fake"] else "green"
st.markdown(f"**Статус:** <span style='color:{fake_color}'>**{fake_result['label']}**</span> (уверенность: {fake_result['confidence']:.2%})",
unsafe_allow_html=True)
# Категории
st.write("\n**Категории статьи (топ-95% вероятности):**")
for cat in categories:
st.write(f"- {cat['label']}: {cat['score']:.2%}")
else:
st.warning("Пожалуйста, введите хотя бы заголовок статьи.")