dappyx commited on
Commit
21061d1
·
verified ·
1 Parent(s): 7756b77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -41
app.py CHANGED
@@ -1,53 +1,68 @@
1
  import gradio as gr
2
  from main import tokenizer, model, device
3
  import torch
 
4
 
5
- def qa_pipeline(text,question):
6
- inputs = tokenizer(question, text, return_tensors="pt")
7
- input_ids = inputs['input_ids'].to(device)
8
- attention_mask = inputs['attention_mask'].to(device)
9
- batch = {
10
- "input_ids": input_ids,
11
- "attention_mask": attention_mask
12
- }
13
- start_logits, end_logits, loss = model(batch)
14
 
15
- start_index = torch.argmax(start_logits, dim=-1).item()
16
- end_index = torch.argmax(end_logits, dim=-1).item()
 
17
 
18
- predict_answer_tokens = inputs.input_ids[0, start_index : end_index + 1]
19
- return tokenizer.decode(predict_answer_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def answer_question(context, question):
22
  result = qa_pipeline(context, question)
23
  return result
24
 
25
- example_contexts = [
26
- "Қазақстанның ұлттық құрамы алуан түрлі. Халықтың басым бөлігін тұрғылықты қазақ халқы құрайды, пайыздық үлесі — 70,18%[10], орыстар — 18,42%, өзбектер — 3,29%, украиндар — 1,36%, ұйғырлар — 1,48%, татарлар — 1,06%, басқа халықтар 5,38%.[11] Халықтың 75% астамын мұсылмандар құрайды, православты христиандар — 21%, қалғаны басқа да дін өкілдері.[12]",
27
- "Қазақстан бес мемлекетпен шекаралас, соның ішінде әлемдегі құрлықтағы ең ұзын шекара, солтүстігінде және батысында Ресеймен — 7591 км құрайды. Оңтүстігінде: Түрікменстан — 426 км, Өзбекстан — 2354 км және Қырғызстан — 1241 км, ал шығысында: Қытаймен — 1782 км шектеседі. Жалпы құрлық шекарасының ұзындығы — 13394 км. Батыста Каспий көлімен (2000 км), оңтүстік батыста Арал теңізімен шайылады.[9] 2024 жылдың 1 наурыздағы елдегі тұрғындар саны — 20 075 271[4], бұл әлем бойынша 64-орын. Жер көлемі жағынан әлем елдерінің ішінде 9-орын алады (2 724 902 км²).",
28
- "Қазақстан 1995 жылғы 30 тамыздағы республикалық референдумда қабылданған Конституция бойынша — өзін демократиялы, зайырлы, құқықты және әлеуметті мемлекет ретінде орнықтырды. Қазақстан Республикасы – президенттік басқару формасындағы біртұтас мемлекет. Республиканың ең жоғарғы өкілді органы — Парламент. Ол республиканың заң шығару құзіретін жүзеге асырады."
29
- ]
30
- example_questions = [
31
- "Қазақстанның халқы неше пайызды қазақтар құрайды?",
32
- "Қазақстан нешеу мемлекетпен шекаралас?",
33
- "Қазақстандағы басқару формасы қандай?",
34
- ]
35
-
36
-
37
- examples = [[context, question] for context, question in zip(example_contexts, example_questions)]
38
-
39
- # Создаем интерфейс
40
- iface = gr.Interface(
41
- fn=answer_question,
42
- inputs=[
43
- gr.Textbox(lines=10, label="Context"),
44
- gr.Textbox(lines=2, label="Question")
45
- ],
46
- outputs="text",
47
- title="Question Answering Model",
48
- description="Введите контекст и задайте вопрос, чтобы получить ответ.",
49
- examples=examples
50
- )
51
-
52
- # Запускаем интерфейс
53
  iface.launch()
 
1
  import gradio as gr
2
  from main import tokenizer, model, device
3
  import torch
4
+ import pandas as pd
5
 
6
+ # Загружаем данные из CSV файла
7
+ df = pd.read_csv("train.csv")
 
 
 
 
 
 
 
8
 
9
+ def get_random_row():
10
+ random_row = df.sample(n=1)
11
+ return random_row.iloc[0]
12
 
13
+ def qa_pipeline(text, question):
14
+ # Подготовка входных данных для модели
15
+ inputs = tokenizer(question, text, return_tensors="pt")
16
+ input_ids = inputs['input_ids'].to(device)
17
+ attention_mask = inputs['attention_mask'].to(device)
18
+ batch = {
19
+ "input_ids": input_ids,
20
+ "attention_mask": attention_mask
21
+ }
22
+
23
+ # Выполнение предсказания
24
+ with torch.no_grad():
25
+ outputs = model(**batch)
26
+
27
+ # Извлечение логитов начала и конца ответа
28
+ start_logits = outputs.start_logits
29
+ end_logits = outputs.end_logits
30
+
31
+ # Нахождение индексов начала и конца ответа
32
+ start_index = torch.argmax(start_logits, dim=-1).item()
33
+ end_index = torch.argmax(end_logits, dim=-1).item()
34
+
35
+ # Извлечение и декодирование предсказанных токенов ответа
36
+ predict_answer_tokens = input_ids[0, start_index : end_index + 1]
37
+ return tokenizer.decode(predict_answer_tokens)
38
 
39
  def answer_question(context, question):
40
  result = qa_pipeline(context, question)
41
  return result
42
 
43
+ def get_random_example():
44
+ random_row = get_random_row()
45
+ context = random_row['context']
46
+ question = random_row['question']
47
+ real_answer = random_row['answer']
48
+ predicted_answer = answer_question(context, question)
49
+ return context, question, real_answer, predicted_answer
50
+
51
+ # Интерфейс Gradio
52
+ with gr.Blocks() as iface:
53
+ with gr.Row():
54
+ with gr.Column():
55
+ context = gr.Textbox(lines=10, label="Context")
56
+ question = gr.Textbox(lines=2, label="Question")
57
+ real_answer = gr.Textbox(lines=2, label="Real Answer")
58
+ with gr.Column():
59
+ predicted_answer = gr.Textbox(lines=2, label="Predicted Answer")
60
+ generate_button = gr.Button("Get Random Example")
61
+
62
+ def update_example():
63
+ context_val, question_val, real_answer_val, predicted_answer_val = get_random_example()
64
+ return context_val, question_val, real_answer_val, predicted_answer_val
65
+
66
+ generate_button.click(update_example, outputs=[context, question, real_answer, predicted_answer])
67
+
 
 
 
68
  iface.launch()