import gradio as gr
import os
from dotenv import load_dotenv
# CTransformers is only needed for the commented-out local GGUF loader below.
from langchain_community.llms import CTransformers
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
model_name = "vilm/vinallama-2.7b-chat-GGUF"
# model_file_path = './models/vinallama-7b-chat_q5_0.gguf'
model_embedding_name = 'bkai-foundation-models/vietnamese-bi-encoder'
vectorDB_path = './db'
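# NOTE (assumption): GGUF repositories are generally not servable through the
# hosted Inference API's text-generation task, which would explain a runtime
# error at inference time; a standard (non-GGUF) chat checkpoint may be needed
# for the hosted path, with the GGUF file reserved for the local CTransformers
# loader below.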
load_dotenv()
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# def load_model(model_file_path,
#                model_type,
#                temperature=0.01,
#                context_length=1024,
#                max_new_tokens=1024):
#     llm = CTransformers(
#         model=model_file_path,
#         model_type=model_type,
#         max_new_tokens=max_new_tokens,
#         temperature=temperature,
#         config={
#             'context_length': context_length,
#         },
#     )
#     return llm
def load_model(model_name,
               api_token,
               temperature=0.01,
               max_new_tokens=1024):
    # Hosted inference through the Hugging Face Hub. HuggingFaceHub builds its
    # own InferenceClient internally, so no client is constructed here; the
    # generation settings are forwarded via model_kwargs. context_length only
    # applies to the local CTransformers path, so it is dropped from this loader.
    llm = HuggingFaceHub(
        repo_id=model_name,
        task='text-generation',
        huggingfacehub_api_token=api_token,
        model_kwargs={
            'temperature': temperature,
            'max_new_tokens': max_new_tokens,
        },
    )
    return llm
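# Quick smoke test (sketch, assumes a valid HUGGINGFACEHUB_API_TOKEN is set):
#   print(load_model(model_name, huggingfacehub_api_token).invoke("Xin chào"))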
def load_db():
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_embedding_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
    db = FAISS.load_local(vectorDB_path, embeddings, allow_dangerous_deserialization=True)
    return db
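# The FAISS index in ./db is assumed to have been built offline with the same
# embedding model, e.g. (sketch):
#   db = FAISS.from_documents(chunks, embeddings)
#   db.save_local(vectorDB_path)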
def create_prompt(template):
    prompt = PromptTemplate(
        template=template,
        input_variables=['context', 'question'],
    )
    return prompt
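# Example (sketch): prompt.format(context="...", question="...") renders the
# filled chat template that the 'stuff' chain sends to the LLM.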
def create_chain(llm,
                 prompt,
                 db,
                 top_k_documents=3,
                 return_source_documents=True):
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(
            search_kwargs={
                "k": top_k_documents
            }
        ),
        return_source_documents=return_source_documents,
        chain_type_kwargs={
            'prompt': prompt,
        },
    )
    return chain
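# chain.invoke({"query": ...}) returns a dict with the answer under 'result'
# and, because return_source_documents=True, the retrieved chunks under
# 'source_documents'.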
db = load_db()
# llm = load_model(
#     model_file_path=model_file_path,
#     model_type='llama',
#     context_length=2048
# )
llm = load_model(
    model_name=model_name,
    api_token=huggingfacehub_api_token
)
# System prompt (Vietnamese): "Use the following information to answer the
# question. If you don't know the answer, say you don't know; don't try to
# make up an answer."
template = """<|im_start|>system
Sử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
{context}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""
prompt = create_prompt(template=template)
llm_chain = create_chain(llm, prompt, db)
def respond(message,
            history: list[tuple[str, str]],
            max_tokens,
            temperature,
            top_k_documents,
            ):
    # The system-message Textbox is commented out in additional_inputs, so
    # ChatInterface passes only the three slider values here; a fourth
    # system_message parameter would raise a TypeError at call time. The
    # sliders are accepted but not yet wired into the chain, which is built
    # once at startup. ChatInterface manages history itself, so it is not
    # mutated here.
    response = llm_chain.invoke({"query": message})
    yield response['result']
demo = gr.ChatInterface(
    respond,
    title="Chatbot",
    additional_inputs=[
        # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=8, value=3, step=1, label="Top k documents to search for answers in"),
    ],
)
if __name__ == "__main__":
    demo.launch()