Spaces:
Sleeping
Sleeping
import gradio as gr | |
from main import tokenizer, model, device | |
import torch | |
import pandas as pd | |
# Загружаем данные из CSV файла | |
df = pd.read_csv("QazSynt_train.csv") | |
def get_random_row(): | |
random_row = df.sample(n=1) | |
return random_row.iloc[0] | |
def qa_pipeline(text, question): | |
# Подготовка входных данных для модели | |
inputs = tokenizer(question, text, return_tensors="pt") | |
input_ids = inputs['input_ids'].to(device) | |
attention_mask = inputs['attention_mask'].to(device) | |
batch = { | |
"input_ids": input_ids, | |
"attention_mask": attention_mask | |
} | |
# Выполнение предсказания | |
start_logits, end_logits, loss = model(batch) | |
start_index = torch.argmax(start_logits, dim=-1).item() | |
end_index = torch.argmax(end_logits, dim=-1).item() | |
# Нахождение индексов начала и конца ответа | |
start_index = torch.argmax(start_logits, dim=-1).item() | |
end_index = torch.argmax(end_logits, dim=-1).item() | |
# Извлечение и декодирование предсказанных токенов ответа | |
predict_answer_tokens = input_ids[0, start_index : end_index + 1] | |
return tokenizer.decode(predict_answer_tokens) | |
def answer_question(context, question): | |
result = qa_pipeline(context, question) | |
return result | |
def get_random_example(): | |
random_row = get_random_row() | |
context = random_row['context'] | |
question = random_row['question'] | |
real_answer = random_row['answer'] | |
predicted_answer = answer_question(context, question) | |
return context, question, real_answer, predicted_answer | |
# Интерфейс Gradio | |
with gr.Blocks() as iface: | |
with gr.Row(): | |
with gr.Column(): | |
context = gr.Textbox(lines=10, label="Context") | |
question = gr.Textbox(lines=2, label="Question") | |
real_answer = gr.Textbox(lines=2, label="Real Answer") | |
with gr.Column(): | |
predicted_answer = gr.Textbox(lines=2, label="Predicted Answer") | |
generate_button = gr.Button("Get Random Example") | |
def update_example(): | |
context_val, question_val, real_answer_val, predicted_answer_val = get_random_example() | |
return context_val, question_val, real_answer_val, predicted_answer_val | |
generate_button.click(update_example, outputs=[context, question, real_answer, predicted_answer]) | |
iface.launch() | |