|
import os

import gradio as gr
import nltk
import torch
from duckduckgo_search import DDGS  # the old `ddg` helper was removed from duckduckgo_search
from transformers import DistilBertTokenizer

from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate  # used below but missing from the original imports
from langchain.vectorstores import Chroma
|
|
|
|
|
# NOTE: this checkpoint is a sentiment-classification fine-tune of DistilBERT,
# kept from the original script; a dedicated sentence-embedding model would
# normally give better retrieval quality.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
embedding_model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
|
# Wrap the hosted Zephyr model in a LangChain LLM so RetrievalQA can use it; the raw
# InferenceClient is not LangChain-compatible. Despite the variable name, the model
# is Zephyr, not Qwen. Assumes HUGGINGFACEHUB_API_TOKEN is set in the environment.
qwen_text_gen = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta")
|
|
|
|
|
def search_web(query):
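    """Fetch DuckDuckGo text results for `query` and concatenate their snippets."""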
|
    # DDGS().text() replaces the removed `ddg()` helper; cap results to keep the prompt small.
    results = DDGS().text(query, max_results=5)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body'] + '\n'
    return web_content
|
|
|
|
|
def init_knowledge_vector_store(file):
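    """Build and persist a Chroma vector store from an uploaded file."""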
|
    if file is None:
        return
    filepath = file.name
    distilbert_embedding = HuggingFaceEmbeddings(model_name=embedding_model_name)
    loader = UnstructuredFileLoader(filepath, mode="elements")
    docs = loader.load()
    # Persist the store so get_knowledge_vector_store() can reload it later.
    Chroma.from_documents(docs, distilbert_embedding, persist_directory="./vector_store")
|
|
|
|
|
def get_knowledge_vector_store():
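    """Reload the persisted Chroma vector store with the same embedding model."""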
|
    distilbert_embedding = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = Chroma(embedding_function=distilbert_embedding, persist_directory="./vector_store")
    return vector_store
|
|
|
|
|
def get_knowledge_based_answer(query, qwen_text_gen, vector_store, VECTOR_SEARCH_TOP_K, web_content):
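    """Answer `query` with a RetrievalQA chain, optionally mixing in web search results."""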
|
    if web_content:
        prompt_template = f"""Answer the user's question based on the following known information.
Known web search content: {web_content}""" + """
Known content:
{context}
Question:
{question}"""
    else:
        prompt_template = """Answer the user's question based on the known information.
Known content:
{context}
Question:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    knowledge_chain = RetrievalQA.from_llm(
        llm=qwen_text_gen,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt
    )

    # Feed each retrieved document to the chain as raw page content.
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}"
    )
    knowledge_chain.return_source_documents = True

    result = knowledge_chain.invoke({"query": query})
    return result['result']
|
|
|
|
|
def clear_session():
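    """Reset the chatbot display and the conversation state."""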
|
    return None, None
|
|
|
|
|
# NOTE: the LLM is used via the module-level `qwen_text_gen`; non-component values
# cannot be passed through a Gradio `inputs` list. The original, unused `key`
# argument also swallowed the session state, resetting the history on every turn.
def predict(input, VECTOR_SEARCH_TOP_K, use_web, history=None):
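    """Answer one user turn against the knowledge base and update the chat history."""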
|
    if history is None:
        history = []
    vector_store = get_knowledge_vector_store()
    # The Radio component returns the string "True"/"False", not a bool;
    # search_web() always returns a string, so no extra None check is needed.
    if use_web == 'True':
        web_content = search_web(query=input)
    else:
        web_content = ''

    resp = get_knowledge_based_answer(
        query=input,
        qwen_text_gen=qwen_text_gen,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
    )
    history.append((input, resp))
    return '', history, history
|
|
|
|
|
# Gradio UI: knowledge-base upload and search settings on the left, chat on the right.
block = gr.Blocks()
with block as demo:
    gr.Markdown("<h1><center>Ming History Knowledge Q&A Assistant</center></h1>")
|
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label='Please upload txt, md, docx type files', file_types=['.txt', '.md', '.docx'])
            get_vs = gr.Button("Generate Knowledge Base")
            get_vs.click(init_knowledge_vector_store, inputs=[file])

            use_web = gr.Radio(["True", "False"], label="Web Search", value="False")

            VECTOR_SEARCH_TOP_K = gr.Slider(1, 10, value=5, step=1, label="Vector search top k", interactive=True)

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='Ming History Knowledge Question and Answer Assistant', height=600)
            message = gr.Textbox(label='Please enter your question')
            state = gr.State()

            with gr.Row():
                clear_history = gr.Button("Clear history conversation")
                send = gr.Button("Send")

            # Only Gradio components may appear in `inputs`; the LLM itself is
            # picked up inside predict() from the module-level variable.
            send.click(predict,
                       inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)

            message.submit(predict,
                           inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                           outputs=[message, chatbot, state])

demo.queue().launch(share=False)
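
# To try it locally (a usage sketch, not part of the original script): set
# HUGGINGFACEHUB_API_TOKEN, run this file with Python, upload a .txt/.md/.docx
# file, click "Generate Knowledge Base", then ask a question in the textbox.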