Chananchida committed
Commit
d2ecb95
1 Parent(s): e77b6f6

Upload 4 files

Files changed (4)
  1. app.py +219 -0
  2. data/dataset.xlsx +0 -0
  3. data/embeddings.pkl +3 -0
  4. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,219 @@
+ # @title web interface demo
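+ # Retrieval + extractive QA demo: FAISS retrieves the stored question closest
+ # to the user's query, then a WangchanBERTa QA model extracts the answer span
+ # from the associated context. Served as a Gradio chat interface.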
+ import random
+ import gradio as gr
+ import time
+ import numpy as np
+ import pandas as pd
+ import torch
+ import faiss
+ from sklearn.preprocessing import normalize
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+ from sentence_transformers import SentenceTransformer, util
+ from pythainlp import Tokenizer
+ import pickle
+ import evaluate
+ import re
+ from pythainlp.tokenize import sent_tokenize
+
+ DEFAULT_MODEL = 'wangchanberta'
+ DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
+
+ MODEL_DICT = {
+     'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
+     'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
+ }
+
+ EMBEDDINGS_PATH = 'data/embeddings.pkl'
+ DATA_PATH = 'data/dataset.xlsx'
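+ # dataset.xlsx holds the question/answer data (with contexts on its 'mdeberta'
+ # sheet); embeddings.pkl is a pickled cache of the question embeddings.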
+
+
+ def load_data(path=DATA_PATH):
+     df = pd.read_excel(path, sheet_name='Default')
+     df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']
+     print('Load data done')
+     return df
+
+
+ def load_model(model_name=DEFAULT_MODEL):
+     model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
+     print('Load model done')
+     return model, tokenizer
+
+ def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
+     if torch.cuda.is_available():
+         embedding_model = SentenceTransformer(model_name, device='cuda')
+     else:
+         embedding_model = SentenceTransformer(model_name)
+     print('Load sentence embedding model done')
+     return embedding_model
+
+
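+ # builds an exact-search L2 index over the given vectors, placed on the GPU
+ # when one is available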
+ def set_index(vector):
+     if torch.cuda.is_available():
+         res = faiss.StandardGpuResources()
+         index = faiss.IndexFlatL2(vector.shape[1])
+         gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index)
+         gpu_index_flat.add(vector)
+         index = gpu_index_flat
+     else:
+         index = faiss.IndexFlatL2(vector.shape[1])
+         index.add(vector)
+     return index
+
+
+ def get_embeddings(embedding_model, text_list):
+     return embedding_model.encode(text_list)
+
+
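+ # stacks the per-text vectors into one float32 matrix (the layout FAISS
+ # expects) and L2-normalizes each row so L2 distance ranks like cosine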
+ def prepare_sentences_vector(encoded_list):
+     encoded_list = [i.reshape(1, -1) for i in encoded_list]
+     encoded_list = np.vstack(encoded_list).astype('float32')
+     encoded_list = normalize(encoded_list)
+     return encoded_list
+
+
+ def store_embeddings(df, embeddings):
+     with open('embeddings.pkl', "wb") as fOut:
+         pickle.dump({'sentences': df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
+     print('Store embeddings done')
+
+
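+ # counterpart of store_embeddings(): reads the {'sentences', 'embeddings'}
+ # pickle back and returns just the embedding matrix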
+ def load_embeddings(file_path=EMBEDDINGS_PATH):
+     with open(file_path, "rb") as fIn:
+         stored_data = pickle.load(fIn)
+         stored_sentences = stored_data['sentences']
+         stored_embeddings = stored_data['embeddings']
+     print('Load (questions) embeddings done')
+     return stored_embeddings
+
+
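+ # standard extractive-QA decoding: take the argmax start and end logits and
+ # decode the tokens in between as the answer span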
+ def model_pipeline(model, tokenizer, question, similar_context):
+     inputs = tokenizer(question, similar_context, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     answer_start_index = outputs.start_logits.argmax()
+     answer_end_index = outputs.end_logits.argmax()
+     predict_answer_tokens = inputs.input_ids[0, answer_start_index: answer_end_index + 1]
+     Answer = tokenizer.decode(predict_answer_tokens)
+     return Answer.replace('<unk>', '@')
+
+
+ def faiss_search(index, question_vector, k=1):
+     distances, indices = index.search(question_vector, k)
+     return distances, indices
+
+
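+ # Three pipeline variants: predict_faiss() returns the stored answer of the
+ # nearest question as-is, predict() runs the QA model over the nearest
+ # context, and predict_test() first narrows that context to its most similar
+ # sentence before extracting the span.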
+ def predict_faiss(model, tokenizer, embedding_model, df, question, index):
+     t = time.time()
+     question = question.strip()
+     question_vector = get_embeddings(embedding_model, question)
+     question_vector = prepare_sentences_vector([question_vector])
+     distances, indices = faiss_search(index, question_vector)
+     Answers = [df['Answer'][i] for i in indices[0]]
+     _time = time.time() - t
+     output = {
+         "user_question": question,
+         "answer": Answers[0],
+         "totaltime": round(_time, 3),
+         "score": round(distances[0][0], 4)
+     }
+     return output
+
+ def predict(model, tokenizer, embedding_model, df, question, index):
+     t = time.time()
+     question = question.strip()
+     question_vector = get_embeddings(embedding_model, question)
+     question_vector = prepare_sentences_vector([question_vector])
+     distances, indices = faiss_search(index, question_vector)
+
+     # Answer = model_pipeline(model, tokenizer, df['Question'][indices[0][0]], df['Context'][indices[0][0]])
+     Answer = model_pipeline(model, tokenizer, question, df['Context'][indices[0][0]])
+     _time = time.time() - t
+     # assembled for debugging/inspection only; the chat interface needs the
+     # plain string returned below
+     output = {
+         "user_question": question,
+         "answer": Answer,
+         "totaltime": round(_time, 3),
+         "distance": round(distances[0][0], 4)
+     }
+     return Answer
+ def predict_test(model, tokenizer, embedding_model, df, question, index):  # sent_tokenize pythainlp
+     t = time.time()
+     question = question.strip()
+     question_vector = get_embeddings(embedding_model, question)
+     question_vector = prepare_sentences_vector([question_vector])
+     distances, indices = faiss_search(index, question_vector)
+
+     # keep only the text after a long (10+ char) whitespace run, if any,
+     # then collapse the remaining whitespace
+     mostSimContext = df['Context'][indices[0][0]]
+     pattern = r'(?<=\s{10}).*'
+     matches = re.search(pattern, mostSimContext, flags=re.DOTALL)
+
+     if matches:
+         mostSimContext = matches.group(0)
+
+     mostSimContext = mostSimContext.strip()
+     mostSimContext = re.sub(r'\s+', ' ', mostSimContext)
+
+     # second retrieval pass: split the context into sentences (crfcut) and
+     # pick the segment closest to the question
+     segments = sent_tokenize(mostSimContext, engine="crfcut")
+     segments_index = set_index(get_embeddings(embedding_model, segments))
+     _distances, _indices = faiss_search(segments_index, question_vector)
+     mostSimSegment = segments[_indices[0][0]]
+
+     Answer = model_pipeline(model, tokenizer, question, mostSimSegment)
+
+     # locate the predicted answer span inside the cleaned context for the
+     # highlight (str.find returns -1 when the span is not present verbatim)
+     start_index = mostSimContext.find(Answer)
+     end_index = start_index + len(Answer)
+     _time = time.time() - t
+     output = {
+         "user_question": question,
+         "answer": df['Answer'][indices[0][0]],
+         "totaltime": round(_time, 3),
+         "distance": round(distances[0][0], 4),
+         "highlight_start": start_index,
+         "highlight_end": end_index
+     }
+     return output
+
+ def highlight_text(text, start_index, end_index):
+     highlighted_text = ""
+     for i, char in enumerate(text):
+         if i == start_index:
+             highlighted_text += "<mark>"
+         highlighted_text += char
+         if i == end_index - 1:
+             highlighted_text += "</mark>"
+     return highlighted_text
+
+ def chat_interface_before(question, history):
+     response = predict(model, tokenizer, embedding_model, df, question, index)
+     return response
+
+ def chat_interface_after(question, history):
+     response = predict_test(model, tokenizer, embedding_model, df, question, index)
+     # note: the highlight offsets were computed against the retrieved context,
+     # so they only bracket the stored answer correctly when the two align
+     highlighted_answer = highlight_text(response["answer"], response["highlight_start"], response["highlight_end"])
+     return highlighted_answer
+
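+ # sample queries (Thai): the data-collection frequency of DXT360 per platform
+ # and on Twitter, plus where to follow news updates (including via Line)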
+ examples = [
+     'อยากทราบความถี่ในการดึงข้อมูลของ DXT360 ในแต่ละแพลตฟอร์ม',
+     'อยากทราบความถี่ในการดึงข้อมูลของ DXT360 บน Twitter',
+     'ช่องทางติดตามข่าวสารของเรา',
+     'ขอช่องทางติดตามข่าวสารทาง Line หน่อย'
+ ]
+ demo_before = gr.ChatInterface(fn=chat_interface_before,
+                                examples=examples)
+
+ demo_after = gr.ChatInterface(fn=chat_interface_after,
+                               examples=examples)
+
+ interface = gr.TabbedInterface([demo_before, demo_after], ["Before", "After"])
+
+ if __name__ == "__main__":
+     # load the models, data, and prebuilt question-embedding index once at startup
+     model, tokenizer = load_model('wangchanberta-hyp')
+     embedding_model = load_embedding_model()
+     df = load_data()
+     index = set_index(prepare_sentences_vector(load_embeddings(EMBEDDINGS_PATH)))
+     interface.launch()
data/dataset.xlsx ADDED
Binary file (330 kB)
 
data/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6eb3bfadbf8444133238c887c871b8f3dda10d9db57a236868e67dc81bd0dc2c
+ size 2380335
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ pythainlp
+ datasets
+ accelerate
+ faiss-gpu
+ sentence-transformers
+ python-crfsuite
+ numpy
+ pandas
+ torch
+ transformers
+ gensim==4.3.2
+ safetensors==0.4.2
+ scikit-learn==1.2.2
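+ evaluate  # assumed addition: app.py imports evaluate, which is not pinned above
+ openpyxl  # assumed addition: pandas needs it to read the .xlsx dataset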