nrjvarshney commited on
Commit
555714b
1 Parent(s): 9aeb0e3

Adding app file

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wikipedia
2
+ import transformers
3
+ import spacy
4
+ from transformers import AutoModelWithLMHead, AutoTokenizer
5
+ import random
6
+ import gradio as gr
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
9
+ model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
10
+ nlp = spacy.load("en_core_web_sm")
11
+
12
+ def get_question(answer, context, max_length=64):
13
+ input_text = "answer: %s context: %s </s>" % (answer, context)
14
+ features = tokenizer([input_text], return_tensors='pt')
15
+
16
+ output = model.generate(input_ids=features['input_ids'],
17
+ attention_mask=features['attention_mask'],
18
+ max_length=max_length)
19
+
20
+ return tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
21
+
22
+ import gradio as gr
23
+
24
+ def greet(topic):
25
+ print("Entered topic: ", topic)
26
+ topics = wikipedia.search(topic)
27
+ random.shuffle(topics)
28
+ for topic in topics:
29
+ try:
30
+ summary = wikipedia.summary(topic)
31
+ except wikipedia.DisambiguationError as e:
32
+ # print(e.options)
33
+ s = random.choice(e.options)
34
+ summary = wikipedia.summary(s)
35
+ except wikipedia.PageError as e:
36
+ continue
37
+ break
38
+ print("Selected topic: ", topic)
39
+ print("Summary: ", summary)
40
+ summary = summary.replace("\n", "")
41
+ doc = nlp(summary)
42
+
43
+ answers = doc.ents
44
+ filtered_answers = []
45
+ for answer in answers:
46
+ if(answer.text in topic or topic in answer.text):
47
+ pass
48
+ else:
49
+ filtered_answers.append(answer)
50
+
51
+ answer_1 = random.choice(filtered_answers)
52
+ question_1 = get_question(answer_1, summary)
53
+ question_1 = question_1[9:]
54
+ print("Question: ", question_1)
55
+ print("Answer: ", answer_1)
56
+ return [question_1, gr.update(visible=True), gr.update(value=answer_1, visible=False)]
57
+
58
+
59
+ def get_answer(input_answer, gold_answer):
60
+ print("Entered Answer: ", input_answer)
61
+ return gr.update(value=gold_answer, visible=True)
62
+
63
+
64
+ with gr.Blocks() as demo:
65
+ # with gr.Row():
66
+ topic = gr.Textbox(label="Topic")
67
+ greet_btn = gr.Button("Ask a Question")
68
+ question = gr.Textbox(label="Question")
69
+ input_answer = gr.Textbox(label="Your Answer", visible=False)
70
+ answer_btn = gr.Button("Show Answer")
71
+ gold_answer = gr.Textbox(label="Correct Answer", visible=False)
72
+ greet_btn.click(fn=greet, inputs=topic, outputs=[question, input_answer, gold_answer])
73
+
74
+ # with gr.Row():
75
+
76
+ answer_btn.click(fn=get_answer, inputs=[input_answer,gold_answer], outputs=gold_answer)
77
+
78
+ demo.launch()
79
+ # demo.launch(share=True)