Commit 2b5ef14
Parent(s): fc5f262
Create app.py

app.py ADDED
@@ -0,0 +1,164 @@
import pickle

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes import TfidfRetriever, FARMReader

# The original notebook loaded the graph from Google Drive; google.colab is
# only importable inside Colab, so fall back to a local file elsewhere
# (e.g. when running as a Space).
try:
    from google.colab import drive
    drive.mount('/content/drive')
    pickle_file = '/content/drive/MyDrive/Group13_NLP_Project/knowledge_graph.pickle'
except ImportError:
    pickle_file = 'knowledge_graph.pickle'  # assumed local copy of the same pickle

# Load the knowledge graph from the pickle file
with open(pickle_file, 'rb') as f:
    knowledge_graph = pickle.load(f)

print("Knowledge graph loaded from", pickle_file)
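
# The commit does not include the pickle, but the loop below calls .nodes()
# and .edges() with edge[0] as parent and edge[1] as child, so the graph is
# presumably a networkx-style directed graph with string nodes. A minimal
# stand-in for local testing (hypothetical data) could be:
#
#   import networkx as nx
#   knowledge_graph = nx.DiGraph()
#   knowledge_graph.add_edge("haiwan", "kucing")        # parent -> child
#   knowledge_graph.add_edge("kucing", "anak kucing")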

document_store = InMemoryDocumentStore()
node_sentences = {}
documents = []
nodes = [node for node in knowledge_graph.nodes() if node is not None]

for node in nodes:
    # Get all the edges related to the current node
    related_edges = [edge for edge in knowledge_graph.edges() if edge[0] == node or edge[1] == node]

    # Get the parents and grandparents of the current node
    # (grandparent edges are not incident to `node`, so scan the full edge
    # list rather than `related_edges`, which only touches `node`)
    parents = [edge[0] for edge in related_edges if edge[1] == node]
    grandparents = []
    for parent in parents:
        grandparents.extend([edge[0] for edge in knowledge_graph.edges() if edge[1] == parent])

    # Get the children and grandchildren of the current node
    children = [edge[1] for edge in related_edges if edge[0] == node]
    grandchildren = []
    for child in children:
        grandchildren.extend([edge[1] for edge in knowledge_graph.edges() if edge[0] == child])

    # Create the sentence by combining all the related nodes
    sentence_parts = grandparents + parents + [node] + children + grandchildren
    sentence = ' '.join(sentence_parts)

    # Store the sentence for the current node
    node_sentences[node] = sentence

    # Create the document with the sentence as content; keep the node itself
    # in meta ('text' is not a document field in Haystack v1)
    documents.append({'content': sentence, 'meta': {'node': node}})

document_store.write_documents(documents)
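
# Worked example with the hypothetical toy graph sketched above: for the node
# "kucing", parents = ["haiwan"], children = ["anak kucing"], and there are no
# grandparents or grandchildren, so the stored document would be
#   {'content': 'haiwan kucing anak kucing', 'meta': {'node': 'kucing'}}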

# Initialize the retriever
retriever = TfidfRetriever(document_store=document_store)

# Initialize the reader
model_name = "primasr/multilingualbert-for-eqa-finetuned"
reader = FARMReader(model_name_or_path=model_name, use_gpu=False)

# Create the pipeline from the retriever and reader components
from haystack.pipelines import Pipeline
pipeline = Pipeline()
pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
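
# Illustrative direct query, mirroring the call made in chatbot_interface
# further below (the query text is a made-up example):
#   result = pipeline.run(query="apakah itu haiwan?",
#                         params={"Retriever": {"top_k": 5}, "Reader": {"top_k": 5}})
#   print(result['answers'][0].answer)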

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# English -> Malay translation for incoming queries. Note: the Helsinki-NLP
# opus-mt-en-id pair actually targets Indonesian, which this project uses as
# a close stand-in for Bahasa Malaysia.
en_id_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-id")
en_id_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-id")

# Malay -> English translation for outgoing answers
id_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-id-en")
id_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-id-en")
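
# Illustrative round trip with the translation pair (the example sentence is
# made up; same call pattern as in chatbot_interface below):
#   batch = en_id_tokenizer(["What is an animal?"], return_tensors="pt", padding=True)
#   out = en_id_model.generate(**batch)
#   print(en_id_tokenizer.batch_decode(out, skip_special_tokens=True)[0])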

# Canned query/response pairs for the chatbot, in English and Malay
# (format: [trigger query, retry prefix, give-up message])
pairs = [
    [
        "your answer is wrong",
        "Sorry for providing a wrong answer, here is the newest answer:\n\n",
        "I am sorry that I can't actually answer your question =("
    ],
    [   # Malay versions of the pair above
        "jawapan anda adalah salah",  # "your answer is wrong"
        "Maaf kerana memberikan jawapan yang salah. Berikut ialah jawapan yang baru:\n\n",
        "Minta maaf, saya tidak dapat menjawab soalan anda =("
    ]]

# Check whether the user is asking for another answer to the same question,
# and return the index of the language-appropriate response pair
def checkReiterateQuery(query, lang):
    if query in [pairs[0][0], pairs[1][0]]:
        if lang == 'en':
            j = 0
        else:
            j = 1
        return True, j
    else:
        return False, 3  # second value is unused in this branch
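
# Expected behaviour (derived from the pairs above):
#   checkReiterateQuery("your answer is wrong", "en")       -> (True, 0)
#   checkReiterateQuery("jawapan anda adalah salah", "ms")  -> (True, 1)
#   checkReiterateQuery("what is an animal?", "en")         -> (False, 3)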

import gradio as gr
from langdetect import detect
import warnings
warnings.filterwarnings('ignore')

chat_history = []
answer_counter = 0
result = None  # last pipeline result; reused when the user asks for another answer

def chatbot_interface(message):
    global answer_counter
    global result

    # Append the current message to the chat history
    chat_history.append(message)
    lang = detect(message)
    reiterate, j = checkReiterateQuery(message, lang)

    # The user wants another answer to the same question (this needs a
    # previous result to draw from; otherwise fall through to a fresh query)
    if reiterate and result is not None:
        answer_counter = answer_counter + 1
        if answer_counter < 5:
            retrieved_main_answer = pairs[j][1] + result['answers'][answer_counter].answer
            retrieved_main_context = result['answers'][answer_counter].context
        else:
            retrieved_main_answer = pairs[j][2]
            retrieved_main_context = ""
    else:
        answer_counter = 0
        # If the question is in English, translate it to Malay first
        # (the tokenizer is called directly; prepare_seq2seq_batch is
        # deprecated in recent transformers releases)
        if lang == "en":
            tokenized_text = en_id_tokenizer([message], return_tensors='pt', padding=True)
            translation = en_id_model.generate(**tokenized_text)
            message = en_id_tokenizer.batch_decode(translation, skip_special_tokens=True)[0]

        result = pipeline.run(query=message.lower(), params={
            "Retriever": {"top_k": 5},
            "Reader": {"top_k": 5}})
        retrieved_main_answer = result['answers'][answer_counter].answer
        retrieved_main_context = result['answers'][answer_counter].context

    response = retrieved_main_answer + ", " + retrieved_main_context

    # Translate the response back to English if the question was in English
    if lang == "en":
        tokenized_text = id_en_tokenizer([response.lower()], return_tensors='pt', padding=True)
        translation = id_en_model.generate(**tokenized_text)
        response = id_en_tokenizer.batch_decode(translation, skip_special_tokens=True)[0]

    # Append the response to the chat history
    chat_history.append(response)

    # Join the chat history with blank lines between turns
    chat_history_text = "\n\n".join(chat_history)

    return response, chat_history_text

# Create the Gradio interface (gr.inputs/gr.outputs are deprecated namespaces;
# the components are used directly here)
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=gr.Textbox(label="Please Type Your Question Here: "),
    outputs=[gr.Textbox(label="Answers"), gr.Textbox(label="Chat History")],
    description="## Question Answering system\n\nIt supports **English** and **Bahasa Malaysia**.",
    allow_flagging="never"
)

# Launch the chatbot demo
iface.launch(inline=False)
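
# Runtime dependencies inferred from the imports above (none are pinned in
# this commit, and networkx is only assumed for the pickled graph):
# farm-haystack, transformers, torch, gradio, langdetect, networkx.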