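"""Thai semantic-search demo over a scraped web page.

Embeds the page's passages with a multilingual sentence encoder, indexes them
in FAISS, and serves top-3 passage retrieval through a Gradio chat interface.
The WangchanBERTa checkpoints below provide an optional extractive-QA step
(see model_pipeline).
"""
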
import time
import urllib.parse

import numpy as np
import gradio as gr
import torch
import faiss
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
from unstructured.partition.html import partition_html

DEFAULT_MODEL = 'wangchanberta'
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
MODEL_DICT = {
    'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
    'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
}
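
# Both checkpoints are WangchanBERTa-based models loaded via
# AutoModelForQuestionAnswering, i.e. used here as extractive QA models.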


def load_model(model_name=DEFAULT_MODEL):
    # Load the extractive QA model and its matching tokenizer from the Hub.
    model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
    print('Load model done')
    return model, tokenizer


def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
    # Run the sentence encoder on GPU when one is available, otherwise on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_model = SentenceTransformer(model_name, device=device)
    print('Load sentence embedding model done')
    return embedding_model


def set_index(vector):
    # Exact (flat) L2 index; on L2-normalized vectors, L2 ranking matches cosine ranking.
    index = faiss.IndexFlatL2(vector.shape[1])
    if torch.cuda.is_available():
        # Move the index to GPU 0 before adding the vectors.
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(vector)
    return index


def get_embeddings(embedding_model, text_list):
    # encode() accepts either a single string or a list of strings.
    return embedding_model.encode(text_list)


def prepare_sentences_vector(encoded_list):
    # Stack the embeddings into one float32 matrix and L2-normalize each row,
    # which is the dtype and layout FAISS expects.
    encoded_list = [i.reshape(1, -1) for i in encoded_list]
    encoded_list = np.vstack(encoded_list).astype('float32')
    encoded_list = normalize(encoded_list)
    return encoded_list


def faiss_search(index, question_vector, k=1):
    # Return distances to, and row indices of, the k nearest stored vectors.
    distances, indices = index.search(question_vector, k)
    return distances, indices
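
# A minimal sketch of the retrieval round trip using the helpers above,
# assuming a toy two-passage corpus (the real corpus is built in __main__):
#
#     model = load_embedding_model()
#     corpus = ['first passage', 'second passage']
#     idx = set_index(prepare_sentences_vector(get_embeddings(model, corpus)))
#     q = prepare_sentences_vector([get_embeddings(model, 'a question')])
#     _, hits = faiss_search(idx, q, k=1)
#     print(corpus[hits[0][0]])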


def model_pipeline(model, tokenizer, question, context):
    # Extractive QA: decode the span between the most likely start and end tokens.
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    predict_answer_tokens = inputs.input_ids[0, answer_start_index:answer_end_index + 1]
    answer = tokenizer.decode(predict_answer_tokens)
    return answer.replace('<unk>', '@')
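
# load_model() and model_pipeline() are defined but never called in __main__ below,
# which only returns the retrieved passages. One possible way to add a span answer
# on top of retrieval (a sketch, not part of the original flow):
#
#     qa_model, qa_tokenizer = load_model()
#     top_passage = context[indices[0][0]]
#     span = model_pipeline(qa_model, qa_tokenizer, question, top_passage)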


def predict_test(embedding_model, context, question, index, url):
    t = time.time()
    question = question.strip()
    question_vector = get_embeddings(embedding_model, question)
    question_vector = prepare_sentences_vector([question_vector])
    # Retrieve the three passages closest to the question.
    distances, indices = faiss_search(index, question_vector, 3)
    most_similar_contexts = ''
    for i in range(3):
        most_sim_context = context[indices[0][i]].strip()
        # Percent-encode the passage so the text-fragment deep link stays valid.
        answer_url = f"{url}#:~:text={urllib.parse.quote(most_sim_context)}"
        most_similar_contexts += f'<a href="{answer_url}">[ {i+1} ]: {most_sim_context}</a>\n\n'
    print(most_similar_contexts)
    print(f"Inference time: {time.time() - t:.3f}s")
    return most_similar_contexts


if __name__ == "__main__":
    url = "https://www.dataxet.co/media-landscape/2024-th"
    # Scrape the page and keep only elements long enough to be real passages.
    elements = partition_html(url=url)
    context = [str(element) for element in elements if len(str(element)) > 60]
    embedding_model = load_embedding_model()
    index = set_index(prepare_sentences_vector(get_embeddings(embedding_model, context)))

    def chat_interface(question, history):
        return predict_test(embedding_model, context, question, index, url)

    # Example queries (Thai): "The Thai media landscape in 2024 is trending toward...",
    # "Fragmentation is...", "TikTok is...", "Report from the Reuters Institute".
    examples = ['ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
                'Fragmentation คือ',
                'ติ๊กต๊อก คือ',
                'รายงานจาก Reuters Institute']

    interface = gr.ChatInterface(fn=chat_interface, examples=examples)
    interface.launch()