Spaces:
Sleeping
Sleeping
File size: 2,461 Bytes
7756b77 21061d1 7756b77 21061d1 31a60dc 7756b77 21061d1 7756b77 21061d1 4634935 21061d1 7756b77 21061d1 61253c2 21061d1 61253c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
from main import tokenizer, model, device
import torch
import pandas as pd
# Загружаем данные из CSV файла
df = pd.read_csv("QazSynt_train.csv")
def get_random_row():
random_row = df.sample(n=1)
return random_row.iloc[0]
def qa_pipeline(text, question):
# Подготовка входных данных для модели
inputs = tokenizer(question, text, return_tensors="pt")
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
# Выполнение предсказания
start_logits, end_logits, loss = model(batch)
start_index = torch.argmax(start_logits, dim=-1).item()
end_index = torch.argmax(end_logits, dim=-1).item()
# Нахождение индексов начала и конца ответа
start_index = torch.argmax(start_logits, dim=-1).item()
end_index = torch.argmax(end_logits, dim=-1).item()
# Извлечение и декодирование предсказанных токенов ответа
predict_answer_tokens = input_ids[0, start_index : end_index + 1]
return tokenizer.decode(predict_answer_tokens)
def answer_question(context, question):
result = qa_pipeline(context, question)
return result
def get_random_example():
random_row = get_random_row()
context = random_row['context']
question = random_row['question']
real_answer = random_row['answer']
predicted_answer = answer_question(context, question)
return context, question, real_answer, predicted_answer
# Интерфейс Gradio
with gr.Blocks() as iface:
with gr.Row():
with gr.Column():
context = gr.Textbox(lines=10, label="Context")
question = gr.Textbox(lines=2, label="Question")
real_answer = gr.Textbox(lines=2, label="Real Answer")
with gr.Column():
predicted_answer = gr.Textbox(lines=2, label="Predicted Answer")
generate_button = gr.Button("Get Random Example")
def update_example():
context_val, question_val, real_answer_val, predicted_answer_val = get_random_example()
return context_val, question_val, real_answer_val, predicted_answer_val
generate_button.click(update_example, outputs=[context, question, real_answer, predicted_answer])
iface.launch()
|