Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from transformers import pipeline | |
| from langdetect import detect | |
| # Загрузка моделей и токенизаторов | |
| def load_models(): | |
| # Модель для классификации категорий | |
| #category_model = AutoModelForSequenceClassification.from_pretrained("./news_classifier") | |
| category_model = AutoModelForSequenceClassification.from_pretrained( | |
| "./news_classifier", | |
| trust_remote_code=True, | |
| device_map=None, # Отключаем автоматическое распределение по устройствам | |
| low_cpu_mem_usage=False, # Отключаем оптимизацию памяти | |
| ) | |
| category_tokenizer = AutoTokenizer.from_pretrained("./news_classifier") | |
| # Модель для детекции фейков | |
| #fake_model = AutoModelForSequenceClassification.from_pretrained("./fake_detector") | |
| fake_model = AutoModelForSequenceClassification.from_pretrained( | |
| "./fake_detector", | |
| trust_remote_code=True, | |
| device_map=None, # Отключаем автоматическое распределение по устройствам | |
| low_cpu_mem_usage=False, # Отключаем оптимизацию памяти | |
| ) | |
| fake_tokenizer = AutoTokenizer.from_pretrained("./fake_detector") | |
| # Модель для перевода с английского на русский | |
| translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru") | |
| return category_model, category_tokenizer, fake_model, fake_tokenizer, translator | |
| category_model, category_tokenizer, fake_model, fake_tokenizer, translator = load_models() | |
| id_to_category = { | |
| 0: "Климат", | |
| 1: "Конфликты", | |
| 2: "Культура", | |
| 3: "Экономика", | |
| 4: "Обзор", | |
| 5: "Здоровье", | |
| 6: "Политика", | |
| 7: "Наука", | |
| 8: "Общество", | |
| 9: "Спорт", | |
| 10: "Путешествия" | |
| } | |
| def detect_language(text): | |
| try: | |
| return detect(text) | |
| except: | |
| return "unknown" | |
| def translate_text(text): | |
| translation = translator(text, max_length=400) | |
| return translation[0]['translation_text'] | |
| def predict_category(text): | |
| inputs = category_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) | |
| with torch.no_grad(): | |
| logits = category_model(**inputs).logits | |
| probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] | |
| # Собираем топ-95% | |
| sorted_probs, sorted_indices = torch.sort(probabilities, descending=True) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
| top_indices = sorted_indices[cumulative_probs <= 0.95] | |
| if len(top_indices) == 0: # Если даже первая вероятность >95% | |
| top_indices = sorted_indices[:1] | |
| results = [] | |
| for idx in top_indices: | |
| results.append({ | |
| "label": id_to_category[idx.item()], | |
| "score": probabilities[idx].item() | |
| }) | |
| return results | |
| def predict_fake(text): | |
| inputs = fake_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256) | |
| with torch.no_grad(): | |
| logits = fake_model(**inputs).logits | |
| probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] | |
| is_fake = torch.argmax(probabilities).item() | |
| confidence = probabilities[is_fake].item() | |
| return { | |
| "is_fake": bool(is_fake), | |
| "confidence": confidence, | |
| "label": "Фейк" if is_fake else "Реальная" | |
| } | |
| # Настройка интерфейса Streamlit | |
| st.title("Анализ новостных статей") | |
| st.write("Введите текст статьи для анализа:") | |
| # Поля ввода | |
| input_text = st.text_area("Заголовок статьи / текст статьи", height=200) | |
| # Кнопка для запуска предсказания | |
| if st.button("Анализировать"): | |
| if input_text: | |
| # Определяем язык текста | |
| lang = detect_language(input_text) | |
| translated_text = None | |
| if lang == 'en': | |
| with st.spinner("Переводим текст с английского на русский..."): | |
| translated_text = translate_text(input_text) | |
| st.subheader("Переведенный текст") | |
| st.text_area("Перевод", translated_text, height=200, disabled=True) | |
| # Используем переведенный текст для анализа | |
| text_to_analyze = translated_text | |
| else: | |
| text_to_analyze = input_text | |
| # Получаем предсказания | |
| with st.spinner("Анализируем статью..."): | |
| categories = predict_category(text_to_analyze) | |
| fake_result = predict_fake(text_to_analyze) | |
| # Выводим результаты | |
| st.subheader("Результаты анализа") | |
| # Детекция фейков | |
| fake_color = "red" if fake_result["is_fake"] else "green" | |
| st.markdown(f"**Статус:** <span style='color:{fake_color}'>**{fake_result['label']}**</span> (уверенность: {fake_result['confidence']:.2%})", | |
| unsafe_allow_html=True) | |
| # Категории | |
| st.write("\n**Категории статьи (топ-95% вероятности):**") | |
| for cat in categories: | |
| st.write(f"- {cat['label']}: {cat['score']:.2%}") | |
| else: | |
| st.warning("Пожалуйста, введите заголовок статьи.") |