Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Загрузка моделей и токенизаторов | |
def load_models(): | |
# Модель для классификации категорий | |
#category_model = AutoModelForSequenceClassification.from_pretrained("./news_classifier") | |
category_model = AutoModelForSequenceClassification.from_pretrained( | |
"./news_classifier", | |
trust_remote_code=True, | |
device_map=None, # Отключаем автоматическое распределение по устройствам | |
low_cpu_mem_usage=False, # Отключаем оптимизацию памяти | |
) | |
category_tokenizer = AutoTokenizer.from_pretrained("./news_classifier") | |
# Модель для детекции фейков | |
#fake_model = AutoModelForSequenceClassification.from_pretrained("./fake_detector") | |
fake_model = AutoModelForSequenceClassification.from_pretrained( | |
"./fake_detector", | |
trust_remote_code=True, | |
device_map=None, # Отключаем автоматическое распределение по устройствам | |
low_cpu_mem_usage=False, # Отключаем оптимизацию памяти | |
) | |
fake_tokenizer = AutoTokenizer.from_pretrained("./fake_detector") | |
return category_model, category_tokenizer, fake_model, fake_tokenizer | |
category_model, category_tokenizer, fake_model, fake_tokenizer = load_models() | |
id_to_category = { | |
0: "Климат", | |
1: "Конфликты", | |
2: "Культура", | |
3: "Экономика", | |
4: "Обзор", | |
5: "Здоровье", | |
6: "Политика", | |
7: "Наука", | |
8: "Общество", | |
9: "Спорт", | |
10: "Путешествия" | |
} | |
def predict_category(text): | |
inputs = category_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) | |
with torch.no_grad(): | |
logits = category_model(**inputs).logits | |
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] | |
# Собираем топ-95% | |
sorted_probs, sorted_indices = torch.sort(probabilities, descending=True) | |
cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
top_indices = sorted_indices[cumulative_probs <= 0.95] | |
if len(top_indices) == 0: # Если даже первая вероятность >95% | |
top_indices = sorted_indices[:1] | |
results = [] | |
for idx in top_indices: | |
results.append({ | |
"label": id_to_category[idx.item()], | |
"score": probabilities[idx].item() | |
}) | |
return results | |
def predict_fake(text): | |
inputs = fake_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) | |
with torch.no_grad(): | |
logits = fake_model(**inputs).logits | |
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] | |
is_fake = torch.argmax(probabilities).item() | |
confidence = probabilities[is_fake].item() | |
return { | |
"is_fake": bool(is_fake), | |
"confidence": confidence, | |
"label": "Фейк" if is_fake else "Реальная" | |
} | |
# Настройка интерфейса Streamlit | |
st.title("Анализ новостных статей") | |
st.write("Введите текст статьи для анализа:") | |
# Поля ввода | |
input_title = st.text_area("Заголовок статьи", height=100) | |
input_text = st.text_area("Текст статьи (опционально)", height=200) | |
# Кнопка для запуска предсказания | |
if st.button("Анализировать"): | |
if input_title: | |
# Объединяем заголовок и текст, если текст есть | |
full_text = input_title | |
if input_text: | |
full_text += " " + input_text | |
# Получаем предсказания | |
with st.spinner("Анализируем статью..."): | |
categories = predict_category(full_text) | |
fake_result = predict_fake(full_text) | |
# Выводим результаты | |
st.subheader("Результаты анализа") | |
# Детекция фейков | |
fake_color = "red" if fake_result["is_fake"] else "green" | |
st.markdown(f"**Статус:** <span style='color:{fake_color}'>**{fake_result['label']}**</span> (уверенность: {fake_result['confidence']:.2%})", | |
unsafe_allow_html=True) | |
# Категории | |
st.write("\n**Категории статьи (топ-95% вероятности):**") | |
for cat in categories: | |
st.write(f"- {cat['label']}: {cat['score']:.2%}") | |
else: | |
st.warning("Пожалуйста, введите хотя бы заголовок статьи.") |