dappyx's picture
Update app.py
61253c2 verified
import gradio as gr
from main import tokenizer, model, device
import torch
import pandas as pd
# Загружаем данные из CSV файла
df = pd.read_csv("QazSynt_train.csv")
def get_random_row():
random_row = df.sample(n=1)
return random_row.iloc[0]
def qa_pipeline(text, question):
# Подготовка входных данных для модели
inputs = tokenizer(question, text, return_tensors="pt")
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
# Выполнение предсказания
start_logits, end_logits, loss = model(batch)
start_index = torch.argmax(start_logits, dim=-1).item()
end_index = torch.argmax(end_logits, dim=-1).item()
# Нахождение индексов начала и конца ответа
start_index = torch.argmax(start_logits, dim=-1).item()
end_index = torch.argmax(end_logits, dim=-1).item()
# Извлечение и декодирование предсказанных токенов ответа
predict_answer_tokens = input_ids[0, start_index : end_index + 1]
return tokenizer.decode(predict_answer_tokens)
def answer_question(context, question):
result = qa_pipeline(context, question)
return result
def get_random_example():
random_row = get_random_row()
context = random_row['context']
question = random_row['question']
real_answer = random_row['answer']
predicted_answer = answer_question(context, question)
return context, question, real_answer, predicted_answer
# Интерфейс Gradio
with gr.Blocks() as iface:
with gr.Row():
with gr.Column():
context = gr.Textbox(lines=10, label="Context")
question = gr.Textbox(lines=2, label="Question")
real_answer = gr.Textbox(lines=2, label="Real Answer")
with gr.Column():
predicted_answer = gr.Textbox(lines=2, label="Predicted Answer")
generate_button = gr.Button("Get Random Example")
def update_example():
context_val, question_val, real_answer_val, predicted_answer_val = get_random_example()
return context_val, question_val, real_answer_val, predicted_answer_val
generate_button.click(update_example, outputs=[context, question, real_answer, predicted_answer])
iface.launch()