import gradio as gr from main import tokenizer, model, device import torch import pandas as pd # Загружаем данные из CSV файла df = pd.read_csv("QazSynt_train.csv") def get_random_row(): random_row = df.sample(n=1) return random_row.iloc[0] def qa_pipeline(text, question): # Подготовка входных данных для модели inputs = tokenizer(question, text, return_tensors="pt") input_ids = inputs['input_ids'].to(device) attention_mask = inputs['attention_mask'].to(device) batch = { "input_ids": input_ids, "attention_mask": attention_mask } # Выполнение предсказания start_logits, end_logits, loss = model(batch) start_index = torch.argmax(start_logits, dim=-1).item() end_index = torch.argmax(end_logits, dim=-1).item() # Нахождение индексов начала и конца ответа start_index = torch.argmax(start_logits, dim=-1).item() end_index = torch.argmax(end_logits, dim=-1).item() # Извлечение и декодирование предсказанных токенов ответа predict_answer_tokens = input_ids[0, start_index : end_index + 1] return tokenizer.decode(predict_answer_tokens) def answer_question(context, question): result = qa_pipeline(context, question) return result def get_random_example(): random_row = get_random_row() context = random_row['context'] question = random_row['question'] real_answer = random_row['answer'] predicted_answer = answer_question(context, question) return context, question, real_answer, predicted_answer # Интерфейс Gradio with gr.Blocks() as iface: with gr.Row(): with gr.Column(): context = gr.Textbox(lines=10, label="Context") question = gr.Textbox(lines=2, label="Question") real_answer = gr.Textbox(lines=2, label="Real Answer") with gr.Column(): predicted_answer = gr.Textbox(lines=2, label="Predicted Answer") generate_button = gr.Button("Get Random Example") def update_example(): context_val, question_val, real_answer_val, predicted_answer_val = get_random_example() return context_val, question_val, real_answer_val, predicted_answer_val generate_button.click(update_example, outputs=[context, question, real_answer, predicted_answer]) iface.launch()