import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Display name -> Hugging Face Hub path for every selectable model.
models = {
    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1",
}

# Labels returned to the UI.
SPAM_LABEL = "Спам"
NOT_SPAM_LABEL = "Не спам"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizers = {}
model_instances = {}
for name, path in models.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(path)
    # Every model goes to the same device and into eval mode, so that the
    # predict_* functions can move their inputs to `device` unconditionally.
    model_instances[name] = (
        AutoModelForSequenceClassification.from_pretrained(path).to(device).eval()
    )


def clean_text(text: str) -> str:
    """Normalize *text* before it is fed to the spamNS_v1 model.

    Steps, in order:
      1. delete every substring that starts with ``http`` and runs up to the
         next whitespace character;
      2. replace each run of characters other than ``А-Яа-я``, digits and
         the space character with a single space;
      3. lowercase and strip leading/trailing whitespace.
    """
    text = re.sub(r'http\S+', '', text)
    # NOTE(review): "ё"/"Ё" lie outside the А-я code point range, so they are
    # replaced by a space here. Presumably this matches the preprocessing the
    # model was trained with - confirm before widening the character class.
    text = re.sub(r'[^А-Яа-я0-9 ]+', ' ', text)
    return text.lower().strip()


def predict_spam_deberta(text: str) -> str:
    """Classify *text* with the ``RUSpam/spam_deberta_v4`` model.

    The raw text is tokenized (truncated to 256 tokens) and the class with
    the highest logit is taken; class index 1 is reported as spam.

    Returns:
        ``"Спам"`` or ``"Не спам"``.
    """
    tokenizer = tokenizers["RUSpam/spam_deberta_v4"]
    model = model_instances["RUSpam/spam_deberta_v4"]

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # Move every tensor the tokenizer produced to the model's device.
    inputs = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class = torch.argmax(logits, dim=1).item()
    return SPAM_LABEL if predicted_class == 1 else NOT_SPAM_LABEL


def predict_spam_spamns(text: str) -> str:
    """Classify *text* with the ``RUSpam/spamNS_v1`` model.

    The text is normalized with :func:`clean_text`, tokenized with padding
    and truncation to 128 tokens, and the first logit of the first sample is
    passed through a sigmoid; a value of 0.5 or more is reported as spam.

    Returns:
        ``"Спам"`` or ``"Не спам"``.
    """
    tokenizer = tokenizers["RUSpam/spamNS_v1"]
    model = model_instances["RUSpam/spamNS_v1"]

    encoding = tokenizer(
        clean_text(text),
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    spam_score = torch.sigmoid(logits)[0][0].item()
    return SPAM_LABEL if spam_score >= 0.5 else NOT_SPAM_LABEL


# Model name -> function that classifies a text with that model.
_PREDICTORS = {
    "RUSpam/spam_deberta_v4": predict_spam_deberta,
    "RUSpam/spamNS_v1": predict_spam_spamns,
}


def predict_spam(text: str, model_choice: str) -> str:
    """Classify *text* with the model named by *model_choice*.

    Args:
        text: The message to classify.
        model_choice: One of the keys of ``models``.

    Returns:
        ``"Спам"`` or ``"Не спам"``.

    Raises:
        ValueError: If *model_choice* is not a known model name.
    """
    predictor = _PREDICTORS.get(model_choice)
    if predictor is None:
        # Fail loudly instead of silently returning None to the UI.
        raise ValueError(f"Unknown model: {model_choice!r}")
    return predictor(text)


# Build the Gradio interface.
# Input widgets: a multi-line text box for the message and a radio selector
# listing every configured model, with the DeBERTa model preselected.
text_input = gr.Textbox(lines=5, label="Введите текст")
model_selector = gr.Radio(
    choices=list(models),
    label="Выберите модель",
    value="RUSpam/spam_deberta_v4",
)

# Wire the widgets to predict_spam; its returned label is shown in a Label.
iface = gr.Interface(
    fn=predict_spam,
    inputs=[text_input, model_selector],
    outputs=gr.Label(label="Результат"),
    title="Определение спама в русскоязычных текстах",
)

iface.launch()