import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification from transformers import pipeline from langdetect import detect # Загрузка моделей и токенизаторов @st.cache_resource def load_models(): # Модель для классификации категорий #category_model = AutoModelForSequenceClassification.from_pretrained("./news_classifier") category_model = AutoModelForSequenceClassification.from_pretrained( "./news_classifier", trust_remote_code=True, device_map=None, # Отключаем автоматическое распределение по устройствам low_cpu_mem_usage=False, # Отключаем оптимизацию памяти ) category_tokenizer = AutoTokenizer.from_pretrained("./news_classifier") # Модель для детекции фейков #fake_model = AutoModelForSequenceClassification.from_pretrained("./fake_detector") fake_model = AutoModelForSequenceClassification.from_pretrained( "./fake_detector", trust_remote_code=True, device_map=None, # Отключаем автоматическое распределение по устройствам low_cpu_mem_usage=False, # Отключаем оптимизацию памяти ) fake_tokenizer = AutoTokenizer.from_pretrained("./fake_detector") # Модель для перевода с английского на русский translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru") return category_model, category_tokenizer, fake_model, fake_tokenizer, translator category_model, category_tokenizer, fake_model, fake_tokenizer, translator = load_models() id_to_category = { 0: "Климат", 1: "Конфликты", 2: "Культура", 3: "Экономика", 4: "Обзор", 5: "Здоровье", 6: "Политика", 7: "Наука", 8: "Общество", 9: "Спорт", 10: "Путешествия" } def detect_language(text): try: return detect(text) except: return "unknown" def translate_text(text): translation = translator(text, max_length=400) return translation[0]['translation_text'] def predict_category(text): inputs = category_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) with torch.no_grad(): logits = category_model(**inputs).logits probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] # Собираем топ-95% sorted_probs, sorted_indices = torch.sort(probabilities, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=0) top_indices = sorted_indices[cumulative_probs <= 0.95] if len(top_indices) == 0: # Если даже первая вероятность >95% top_indices = sorted_indices[:1] results = [] for idx in top_indices: results.append({ "label": id_to_category[idx.item()], "score": probabilities[idx].item() }) return results def predict_fake(text): inputs = fake_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) with torch.no_grad(): logits = fake_model(**inputs).logits probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] is_fake = torch.argmax(probabilities).item() confidence = probabilities[is_fake].item() return { "is_fake": bool(is_fake), "confidence": confidence, "label": "Фейк" if is_fake else "Реальная" } # Настройка интерфейса Streamlit st.title("Анализ новостных статей") st.write("Введите текст статьи для анализа:") # Поля ввода input_text = st.text_area("Заголовок статьи / текст статьи", height=200) # Кнопка для запуска предсказания if st.button("Анализировать"): if input_text: # Определяем язык текста lang = detect_language(input_text) translated_text = None if lang == 'en': with st.spinner("Переводим текст с английского на русский..."): translated_text = translate_text(input_text) st.subheader("Переведенный текст") st.text_area("Перевод", translated_text, height=200, disabled=True) # Используем переведенный текст для анализа text_to_analyze = translated_text else: text_to_analyze = input_text # Получаем предсказания with st.spinner("Анализируем статью..."): categories = predict_category(text_to_analyze) fake_result = predict_fake(text_to_analyze) # Выводим результаты st.subheader("Результаты анализа") # Детекция фейков fake_color = "red" if fake_result["is_fake"] else "green" st.markdown(f"**Статус:** **{fake_result['label']}** (уверенность: {fake_result['confidence']:.2%})", unsafe_allow_html=True) # Категории st.write("\n**Категории статьи (топ-95% вероятности):**") for cat in categories: st.write(f"- {cat['label']}: {cat['score']:.2%}") else: st.warning("Пожалуйста, введите заголовок статьи.")