web-qa / app.py
Chananchida's picture
Update app.py
b2b7ea1 verified
raw
history blame
4.2 kB
import time
import numpy as np
import pandas as pd
import gradio as gr
import torch
import faiss
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer, util
from pythainlp import Tokenizer
import pickle
import re
from pythainlp.tokenize import sent_tokenize
from unstructured.partition.html import partition_html
# Short name of the question-answering checkpoint used when none is given.
DEFAULT_MODEL = 'wangchanberta'

# Sentence-embedding checkpoint used when none is given.
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'

# Short model name -> repository id passed to `from_pretrained`.
MODEL_DICT = {
    'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
    'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
}
def load_model(model_name=DEFAULT_MODEL):
    """Load a question-answering model together with its tokenizer.

    Args:
        model_name: Key into ``MODEL_DICT`` selecting the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the same checkpoint.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_DICT``.
    """
    checkpoint = MODEL_DICT[model_name]
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    print('Load model done')
    return qa_model, qa_tokenizer
def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
    """Instantiate the sentence-embedding model.

    Args:
        model_name: Checkpoint name handed to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer``. ``device='cuda'`` is passed
        explicitly when CUDA is available; otherwise no device argument is
        given and the library default applies.
    """
    # Only pass a device argument when CUDA is actually present.
    device_kwargs = {'device': 'cuda'} if torch.cuda.is_available() else {}
    encoder = SentenceTransformer(model_name, **device_kwargs)
    print('Load sentence embedding model done')
    return encoder
def set_index(vector):
    """Build a flat L2 FAISS index and add ``vector`` to it.

    Args:
        vector: Array with a ``shape`` attribute; ``shape[1]`` is used as the
            index dimensionality.

    Returns:
        The populated index. When CUDA is available the index is moved to
        GPU device 0 before the vectors are added; otherwise it stays on CPU.
    """
    flat_index = faiss.IndexFlatL2(vector.shape[1])

    if not torch.cuda.is_available():
        flat_index.add(vector)
        return flat_index

    gpu_resources = faiss.StandardGpuResources()
    gpu_index = faiss.index_cpu_to_gpu(gpu_resources, 0, flat_index)
    gpu_index.add(vector)
    return gpu_index
def get_embeddings(embedding_model, text_list):
    """Encode ``text_list`` with ``embedding_model``.

    Args:
        embedding_model: Object exposing an ``encode`` method.
        text_list: Input forwarded unchanged to ``encode`` (callers in this
            file pass either a single string or a list of strings).

    Returns:
        Whatever ``embedding_model.encode`` returns.
    """
    embeddings = embedding_model.encode(text_list)
    return embeddings
def prepare_sentences_vector(encoded_list):
    """Stack embeddings into one normalised float32 matrix.

    Args:
        encoded_list: Iterable of arrays, each exposing ``reshape``; every
            item is flattened to a single row.

    Returns:
        A 2-D ``float32`` array with one row per input item, passed through
        ``sklearn.preprocessing.normalize`` with its default arguments.
    """
    rows = [embedding.reshape(1, -1) for embedding in encoded_list]
    matrix = np.vstack(rows).astype('float32')
    return normalize(matrix)
def faiss_search(index, question_vector, k=1):
    """Query ``index`` for the ``k`` nearest neighbours of ``question_vector``.

    Args:
        index: Object exposing ``search(vectors, k)`` that yields a
            two-element result.
        question_vector: Query vectors forwarded to ``index.search``.
        k: Number of neighbours requested.

    Returns:
        A ``(distances, indices)`` tuple as produced by ``index.search``.
    """
    search_result = index.search(question_vector, k)
    neighbour_distances, neighbour_ids = search_result
    return neighbour_distances, neighbour_ids
def model_pipeline(model, tokenizer, question, context):
    """Extract an answer span for ``question`` from ``context``.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with ``start_logits`` and ``end_logits``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        question: Question text.
        context: Passage text the answer is taken from.

    Returns:
        The decoded tokens between the highest-scoring start and end
        positions (inclusive), with every ``<unk>`` replaced by ``@``.
    """
    encoded = tokenizer(question, context, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(**encoded)

    span_start = prediction.start_logits.argmax()
    span_end = prediction.end_logits.argmax()

    # The slice end is exclusive, hence the +1 to keep the end token.
    span_token_ids = encoded.input_ids[0, span_start: span_end + 1]
    decoded = tokenizer.decode(span_token_ids)
    return decoded.replace('<unk>', '@')
def predict_test(embedding_model, context, question, index, url, k=3):
    """Return HTML links to the passages most similar to ``question``.

    The question is embedded, the ``k`` nearest passages are looked up in
    ``index``, and each hit becomes an ``<a>`` element whose href is ``url``
    plus a ``#:~:text=`` fragment carrying the passage text.

    Args:
        embedding_model: Model handed to ``get_embeddings``.
        context: Sequence of passage strings, in the order they were indexed.
        question: User question; surrounding whitespace is stripped.
        index: Search index handed to ``faiss_search``.
        url: Page URL the passages came from.
        k: Number of passages to retrieve. Defaults to 3.

    Returns:
        One ``<a href=...>[ n ]: passage</a>`` entry per hit, each followed
        by a blank line, concatenated into a single string. The same string
        is also printed.
    """
    # Function-scope imports keep this block self-contained; neither module
    # is imported at the top of the file.
    import html
    from urllib.parse import quote

    question = question.strip()
    question_vector = get_embeddings(embedding_model, question)
    question_vector = prepare_sentences_vector([question_vector])
    _, indices = faiss_search(index, question_vector, k)

    links = []
    for hit in indices[0]:
        hit = int(hit)
        # FAISS pads the result with -1 when fewer than k neighbours exist;
        # without this guard context[-1] would silently return the last
        # passage.
        if hit < 0:
            continue
        passage = context[hit].strip()
        # Percent-encode the passage so quotes, '&', ',' and non-ASCII text
        # cannot break the href. quote() never encodes '-', which is
        # directive syntax in text fragments, so it is encoded by hand.
        fragment = quote(passage, safe='').replace('-', '%2D')
        answer_url = html.escape(f"{url}#:~:text={fragment}", quote=True)
        # Escape the visible text so markup in the passage is shown literally.
        label = html.escape(passage, quote=False)
        links.append(
            f'<a href="{answer_url}">[ {len(links) + 1} ]: {label}</a>\n\n'
        )

    most_similar_contexts = ''.join(links)
    print(most_similar_contexts)
    return most_similar_contexts
if __name__ == "__main__":
    # Fetch the source page and keep only elements whose text is longer than
    # 60 characters.
    url = "https://www.dataxet.co/media-landscape/2024-th"
    elements = partition_html(url=url)
    element_texts = (str(element) for element in elements)
    context = [text for text in element_texts if len(text) > 60]

    # Embed every kept passage and build the search index over them.
    embedding_model = load_embedding_model()
    context_vectors = prepare_sentences_vector(
        get_embeddings(embedding_model, context)
    )
    index = set_index(context_vectors)

    def chat_interface(question, history):
        """Gradio chat callback; ``history`` is accepted but not used."""
        return predict_test(embedding_model, context, question, index, url)

    # Sample prompts shown in the chat UI.
    examples = [
        'ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
        'Fragmentation คือ',
        'ติ๊กต๊อก คือ',
        'รายงานจาก Reuters Institute',
    ]

    interface = gr.ChatInterface(fn=chat_interface, examples=examples)
    interface.launch()