Spaces:
Sleeping
Sleeping
import streamlit as st | |
from joblib import load | |
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from tensorflow.keras.models import load_model | |
import tensorflow as tf | |
from tensorflow.keras.preprocessing.text import Tokenizer | |
from tensorflow.keras.preprocessing.sequence import pad_sequences | |
import time | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
import textwrap | |
tok = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') | |
model_checkpoint = 'cointegrated/rubert-tiny-toxicity' | |
toxicity_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
toxicity_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
clf = load('my_model_filename.pkl') | |
vectorizer = load('tfidf_vectorizer.pkl') | |
scaler = load('scaler.joblib') | |
tukinazor = load('tokenizer.pkl') | |
rnn_model = load_model('path_to_my_model.h5') | |
bert_model = BertForSequenceClassification.from_pretrained('my_bert_model') | |
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased') | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
bert_model = bert_model.to(device) | |
model_finetuned = GPT2LMHeadModel.from_pretrained('GPT_2') | |
labels = ["не токсичный", "оскорбляющий", "непристойный", "угрожающий", "опасный"] | |
def text2toxicity(text, aggregate=True): | |
""" Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)""" | |
with torch.no_grad(): | |
inputs = toxicity_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(toxicity_model.device) | |
proba = torch.sigmoid(toxicity_model(**inputs).logits).cpu().numpy() | |
if isinstance(text, str): | |
proba = proba[0] | |
if aggregate: | |
return 1 - proba.T[0] * (1 - proba.T[-1]) | |
else: | |
result = {} | |
for label, prob in zip(labels, proba): | |
result[label] = prob | |
return result | |
def predict_text(text): | |
sequences = tukinazor.texts_to_sequences([text]) | |
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post') | |
predictions = rnn_model.predict(padded_sequences) | |
predicted_class = tf.argmax(predictions, axis=-1).numpy()[0] | |
return predicted_class | |
def generate_text(model, prompt, max_length=150, temperature=1.0, num_beams=10, top_k=600, top_p=0.75, no_repeat_ngram_size=1, num_return_sequences=1): | |
input_ids = tok.encode(prompt, return_tensors='pt').to(device) | |
with torch.inference_mode(): | |
output = model.generate( | |
input_ids=input_ids, | |
max_length=max_length, | |
num_beams=num_beams, | |
do_sample=True, | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
no_repeat_ngram_size=no_repeat_ngram_size, | |
num_return_sequences=num_return_sequences | |
) | |
texts = [textwrap.fill(tok.decode(out), 60) for out in output] | |
return "\n------------------\n".join(texts) | |
def page_reviews_classification(): | |
st.title("Модель классификации отзывов") | |
st.image("ramsey.jpg", use_column_width=True) | |
user_input = st.text_area("Введите текст отзыва:") | |
if st.button("Классифицировать"): | |
start_time = time.time() | |
user_input_vec = vectorizer.transform([user_input]) | |
sentence_vector_scaled = scaler.transform(user_input_vec) | |
prediction = clf.predict( | |
sentence_vector_scaled) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс: {prediction[0]}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
user_input_rnn = st.text_area("Введите текст отзыва для RNN модели:") | |
if st.button("Классифицировать с RNN"): | |
start_time = time.time() | |
prediction_rnn = predict_text(user_input_rnn) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
user_input_bert = st.text_area("Введите текст отзыва для BERT:") | |
if st.button("Классифицировать (BERT)"): | |
start_time = time.time() | |
encoding = tokenizer.encode_plus( | |
user_input_bert, | |
add_special_tokens=True, | |
max_length=200, | |
return_token_type_ids=False, | |
padding='max_length', | |
truncation=True, | |
return_attention_mask=True, | |
return_tensors='pt' | |
) | |
input_ids = encoding['input_ids'].to(device) | |
attention_mask = encoding['attention_mask'].to(device) | |
with torch.no_grad(): | |
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask) | |
predictions = torch.argmax(outputs.logits, dim=1) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
def page_toxicity_analysis(): | |
st.title("Оценка текста на токсичность") | |
st.image("scale_1200.webp", use_column_width=True) | |
user_input_toxicity = st.text_area("Введите текст для оценки токсичности:") | |
if st.button("Оценить токсичность"): | |
start_time = time.time() | |
probs = text2toxicity(user_input_toxicity, aggregate=False) | |
elapsed_time = time.time() - start_time | |
for label, prob in probs.items(): | |
st.write(f"Вероятность того что комментарий {label}: {prob:.4f}") | |
def page_gpt_generation(): | |
st.title("Генерация текста с помощью GPT-модели") | |
st.image("noize-mc-2_50163340_orig_.jpg", use_column_width=True) | |
user_prompt = st.text_area("Введите ваш текст:") | |
sequence_length = st.slider("Длина последовательности:", min_value=10, max_value=1000, value=150, step=10) | |
num_generations = st.slider("Число генераций:", min_value=1, max_value=10, value=1) | |
temperature = st.slider("Температура:", min_value=0.1, max_value=3.0, value=1.0, step=0.1) | |
if st.button("Генерировать"): | |
for _ in range(num_generations): | |
generated_text = generate_text(model_finetuned, user_prompt, sequence_length, temperature) | |
st.text(generated_text) | |
def main(): | |
page_selection = st.sidebar.selectbox("Выберите страницу:", ["Классификация отзывов", "Анализ токсичности","Генерация текста Noize MC"]) | |
if page_selection == "Классификация отзывов": | |
page_reviews_classification() | |
elif page_selection == "Анализ токсичности": | |
page_toxicity_analysis() | |
elif page_selection == "Генерация текста Noize MC": | |
page_gpt_generation() | |
if __name__ == "__main__": | |
main() |