leolaish committed
Commit 4e0fb6d · verified · 1 Parent(s): bddb19b

Update app.py

Files changed (1)
  1. app.py +126 -52
app.py CHANGED
@@ -1,64 +1,138 @@
- import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-     messages.append({"role": "user", "content": message})
-     response = ""
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-         response += token
-         yield response

- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- textbox_css = {"border": "2px solid blue"}  # Customize the border style here
- demo = gr.ChatInterface(
-     respond,
-     title="MediPro",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="MediPro",
-         ),
-     ],
- )

- if __name__ == "__main__":
-     demo.launch()
+ import os
  from huggingface_hub import InferenceClient
+ import gradio as gr
+ import nltk
+ import torch
+ from transformers import DistilBertTokenizer, DistilBertModel
+ from duckduckgo_search import ddg
+ from langchain.chains import RetrievalQA
+ from langchain.document_loaders import UnstructuredFileLoader
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
+ from langchain.prompts import PromptTemplate
+ from langchain.vectorstores import Chroma
+
+ # Initialize the tokenizer and record the embedding model name
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+ embedding_model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
+ DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+ # Hosted text-generation client (HuggingFaceH4/zephyr-7b-beta via the Inference API)
+ qwen_text_gen = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+ # Search the web and concatenate the result snippets
+ def search_web(query):
+     results = ddg(query)
+     web_content = ''
+     if results:
+         for result in results:
+             web_content += result['body']
+     return web_content
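[Note] Newer releases of duckduckgo_search dropped the module-level ddg() helper in favour of the DDGS class, so the call above fails on current versions. A minimal sketch of the same lookup with DDGS (assuming duckduckgo_search >= 4, where each result dict carries "title", "href" and "body" keys):

    from duckduckgo_search import DDGS

    def search_web(query, max_results=5):
        # Concatenate the body text of the top search results.
        with DDGS() as ddgs:
            results = ddgs.text(query, max_results=max_results)
        return ''.join(r['body'] for r in (results or []))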
+ # Build and persist the knowledge vector store from an uploaded file
+ def init_knowledge_vector_store(file):
+     if file is None:
+         return
+     filepath = file.name
+     distilbert_embedding = HuggingFaceBgeEmbeddings(model_name=embedding_model_name)
+     loader = UnstructuredFileLoader(filepath, mode="elements")
+     docs = loader.load()
+     vector_store = Chroma.from_documents(docs, distilbert_embedding, persist_directory="./vector_store")
+     vector_store.persist()  # flush to disk so get_knowledge_vector_store() sees the collection
+
+ # Reopen the persisted knowledge vector store
+ def get_knowledge_vector_store():
+     distilbert_embedding = HuggingFaceBgeEmbeddings(model_name=embedding_model_name)
+     vector_store = Chroma(embedding_function=distilbert_embedding, persist_directory="./vector_store")
+     return vector_store
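[Note] distilbert-base-uncased-finetuned-sst-2-english is a sentiment classifier, an unusual backbone for retrieval embeddings, and HuggingFaceBgeEmbeddings is written for BGE models. A sketch of the more conventional pairing, assuming any BGE checkpoint is acceptable here:

    from langchain.embeddings import HuggingFaceBgeEmbeddings

    # BGE checkpoints are trained specifically for dense retrieval.
    embedding = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")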
+ # Answer a query from the vector store, optionally augmented with web results
+ def get_knowledge_based_answer(query, qwen_text_gen, vector_store, VECTOR_SEARCH_TOP_K, web_content):
+     if web_content:
+         prompt_template = f"""Answer the user's question based on the following known information.
+ Known web search content: {web_content} """ + """
+ Known Content:
+ {context}
+ question:
+ {question}"""
+     else:
+         prompt_template = """Answer the user's question based on the known information.
+ Known Content:
+ {context}
+ question:
+ {question}"""
+     prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+
+     knowledge_chain = RetrievalQA.from_llm(
+         llm=qwen_text_gen,
+         retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
+         prompt=prompt
+     )
+
+     knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+         input_variables=["page_content"],
+         template="{page_content}"
+     )
+
+     knowledge_chain.return_source_documents = True
+
+     result = knowledge_chain.invoke({"query": query})
+
+     return result['result']
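[Note] RetrievalQA.from_llm expects a LangChain LLM, while qwen_text_gen is a raw huggingface_hub.InferenceClient, so the chain above will reject it as written. A minimal adapter sketch, assuming the legacy langchain.llms.base.LLM interface; the class name InferenceClientLLM is hypothetical:

    from typing import Any, List, Optional
    from langchain.llms.base import LLM

    class InferenceClientLLM(LLM):
        client: Any  # a huggingface_hub.InferenceClient
        max_new_tokens: int = 512

        @property
        def _llm_type(self) -> str:
            return "hf_inference_client"

        def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
            # Hand the fully formatted RetrievalQA prompt to the hosted model.
            return self.client.text_generation(prompt, max_new_tokens=self.max_new_tokens)

    # Usage: qwen_text_gen = InferenceClientLLM(client=InferenceClient("HuggingFaceH4/zephyr-7b-beta"))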
+ # Clear the chat window and the conversation state
+ def clear_session():
+     return [], None
+
+ # Run one question/answer turn
+ def predict(input, VECTOR_SEARCH_TOP_K, use_web, history=None):
+     if history is None:
+         history = []
+     vector_store = get_knowledge_vector_store()
+     if use_web == 'True':
+         web_content = search_web(query=input) or ''
+     else:
+         web_content = ''
+
+     resp = get_knowledge_based_answer(
+         query=input,
+         qwen_text_gen=qwen_text_gen,  # module-level client; not a Gradio event input
+         vector_store=vector_store,
+         VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
+         web_content=web_content,
+     )
+     history.append((input, resp))
+     return '', history, history
+ # Gradio interface setup
+ block = gr.Blocks()
+ with block as demo:
+     gr.Markdown("<h1><center>Chat History</center></h1>")
+     with gr.Row():
+         with gr.Column(scale=1):
+             file = gr.File(label='Please upload txt, md, docx type files', file_types=['.txt', '.md', '.docx'])
+             get_vs = gr.Button("Generate Knowledge Base")
+             get_vs.click(init_knowledge_vector_store, inputs=[file])
+
+             use_web = gr.Radio(["True", "False"], label="Web Search", value="False")
+
+             VECTOR_SEARCH_TOP_K = gr.Slider(1, 10, value=5, step=1, label="vector search top k", interactive=True)
+
+         with gr.Column(scale=4):
+             chatbot = gr.Chatbot(label='Ming History Knowledge Question and Answer Assistant', height=600)
+             message = gr.Textbox(label='Please enter your question')
+             state = gr.State()
+
+             with gr.Row():
+                 clear_history = gr.Button("Clear history conversation")
+                 send = gr.Button("Send")
+
+             # Only Gradio components are passed as event inputs; predict reads the client from module scope.
+             send.click(predict,
+                        inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
+                        outputs=[message, chatbot, state])
+             clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)
+
+             message.submit(predict,
+                            inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
+                            outputs=[message, chatbot, state])
+
+ demo.queue().launch(share=False)
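[Note] A hypothetical smoke test for the retrieval path, assuming a knowledge base has already been persisted to ./vector_store; run it in a separate script or before the launch() call, since launch() blocks:

    store = get_knowledge_vector_store()
    answer = get_knowledge_based_answer(
        query="Summarise the uploaded document.",
        qwen_text_gen=qwen_text_gen,
        vector_store=store,
        VECTOR_SEARCH_TOP_K=5,
        web_content="",
    )
    print(answer)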