"""Gradio chat app: retrieval-augmented QA over a local FAISS index, using a
Vietnamese chat model served through the Hugging Face Hub."""

import os

import gradio as gr
from dotenv import load_dotenv
from langchain_community.llms import CTransformers  # only needed for the local GGUF path below
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "vilm/vinallama-2.7b-chat-GGUF"
# model_file_path = './models/vinallama-7b-chat_q5_0.gguf'
model_embedding_name = 'bkai-foundation-models/vietnamese-bi-encoder'
vectorDB_path = './db'

load_dotenv()
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")


# Alternative: run the GGUF model locally with CTransformers instead of the Hub API.
# def load_model(model_file_path,
#                model_type,
#                temperature=0.01,
#                context_length=1024,
#                max_new_tokens=1024):
#     llm = CTransformers(
#         model=model_file_path,
#         model_type=model_type,
#         max_new_tokens=max_new_tokens,
#         temperature=temperature,
#         config={'context_length': context_length},
#     )
#     return llm


def load_model(model_name, api_token, temperature=0.01, max_new_tokens=1024):
    """Load the chat model through the Hugging Face Hub inference API."""
    llm = HuggingFaceHub(
        repo_id=model_name,
        huggingfacehub_api_token=api_token,
        task='text-generation',
        model_kwargs={
            'temperature': temperature,
            'max_new_tokens': max_new_tokens,
        },
    )
    return llm


def load_db():
    """Load the pre-built FAISS index with the Vietnamese bi-encoder embeddings."""
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_embedding_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    db = FAISS.load_local(vectorDB_path, embeddings, allow_dangerous_deserialization=True)
    return db


def create_prompt(template):
    prompt = PromptTemplate(
        template=template,
        input_variables=['context', 'question'],
    )
    return prompt


def create_chain(llm, prompt, db, top_k_documents=3, return_source_documents=True):
    """Build a 'stuff' RetrievalQA chain over the top-k retrieved documents."""
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={"k": top_k_documents}),
        return_source_documents=return_source_documents,
        chain_type_kwargs={'prompt': prompt},
    )
    return chain


db = load_db()
# llm = load_model(
#     model_file_path=model_file_path,
#     model_type='llama',
#     context_length=2048,
# )
llm = load_model(
    model_name=model_name,
    api_token=huggingfacehub_api_token,
)

# vinallama chat-format prompt. The Vietnamese system message says: "Use the
# following information to answer the question. If you do not know the answer,
# say you do not know; do not make one up."
template = """<|im_start|>system
Sử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.
{context}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""

prompt = create_prompt(template=template)
llm_chain = create_chain(llm, prompt, db)


def respond(message,
            history: list[tuple[str, str]],
            max_tokens,
            temperature,
            top_k_documents,
            ):
    # The chain is built once at startup: max_tokens and temperature are fixed
    # at model-load time, so those sliders are not applied per request. The
    # top-k slider is applied by updating the retriever's search_kwargs.
    llm_chain.retriever.search_kwargs["k"] = int(top_k_documents)
    response = llm_chain.invoke({"query": message})
    history.append((message, response['result']))
    yield response['result']


demo = gr.ChatInterface(
    respond,
    title="Chatbot",
    additional_inputs=[
        # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=8, value=3, step=1, label="Top k documents to search for answers in"),
    ],
)

if __name__ == "__main__":
    demo.launch()
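
# ---------------------------------------------------------------------------
# Note: load_db() assumes a FAISS index already exists at ./db, built with the
# same bkai-foundation-models/vietnamese-bi-encoder embeddings. A minimal,
# hypothetical sketch of how such an index could be created (the source file
# path and chunk sizes below are illustrative assumptions, not part of this app):
#
#   from langchain_community.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   from langchain_community.vectorstores import FAISS
#   from langchain_huggingface import HuggingFaceEmbeddings
#
#   docs = PyPDFLoader("data/source.pdf").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50).split_documents(docs)
#   embeddings = HuggingFaceEmbeddings(model_name="bkai-foundation-models/vietnamese-bi-encoder")
#   FAISS.from_documents(chunks, embeddings).save_local("./db")
# ---------------------------------------------------------------------------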