Update app.py
app.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 import logging
 
 from operator import itemgetter
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_core.prompts import ChatPromptTemplate
@@ -13,9 +13,6 @@ from langchain_community.vectorstores.chroma import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import AIMessage, HumanMessage
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.chains import create_retrieval_chain
 from langchain.globals import set_debug
 from dotenv import load_dotenv
 
@@ -26,16 +23,27 @@ set_debug(True)
 load_dotenv()
 
 openai_api_key = os.getenv("OPENAI_API_KEY")
+langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
+langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
+langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
+access_key = os.getenv("ACCESS_TOKEN_SECRET")
 
 persist_dir = "./chroma_db"
-device='cuda:0'
-model_name="all-mpnet-base-v2"
-model_kwargs = {'device': device if torch.cuda.is_available() else "cpu"}
+device = 'cuda:0'
+model_name = "all-mpnet-base-v2"
+model_kwargs = {'device': device if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"}
 logging.info(f"Using device {model_kwargs['device']}")
-
-embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
+embed_money = False
 
-def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
+# Create embeddings and store in vectordb
+if embed_money:
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
+    logging.info(f"Using OpenAI embeddings")
+else:
+    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
+    logging.info(f"Using HuggingFace embeddings")
+
+def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
     logging.info("Configuring retriever")
 
     if not os.path.exists(persist_dir):
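Note: the device string now falls back from CUDA to Apple's MPS backend to CPU, and the new embed_money flag switches between paid OpenAI embeddings and a free local sentence-transformers model. A minimal sketch of the same selection logic in isolation (the helper name pick_device is illustrative, not from the commit):

    # Sketch: the diff's device-fallback expression as a standalone helper.
    import torch

    def pick_device() -> str:
        if torch.cuda.is_available():
            return "cuda:0"   # NVIDIA GPU
        if torch.backends.mps.is_available():
            return "mps"      # Apple Silicon
        return "cpu"          # portable fallback

    # Both embedding classes expose the same interface, which is what makes
    # the embed_money toggle safe:
    #   embeddings.embed_query("text")          -> list[float]
    #   embeddings.embed_documents(["a", "b"])  -> list[list[float]]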
@@ -63,10 +71,8 @@ def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
         vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
 
         # Define retriever
-        retriever = vectordb.as_retriever(
-            search_type="similarity_score_threshold",
-            search_kwargs={'score_threshold': 0.8}
-        )
+        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
+
 
         return retriever
     else:
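Note: the retriever switches from a similarity-score threshold to maximal marginal relevance (MMR), which re-ranks candidates to balance relevance against diversity; k=6 caps the returned documents and lambda_mult=0.25 weights diversity heavily (1.0 would be pure similarity). A usage sketch, assuming vectordb is the Chroma store built above (the query string is illustrative):

    retriever = vectordb.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "lambda_mult": 0.25},  # low lambda -> diverse results
    )
    docs = retriever.get_relevant_documents("Change Management controls")
    for d in docs:
        print(d.metadata.get("source"), d.page_content[:80])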
@@ -74,10 +80,7 @@ def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
         vectordb = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
 
         # Define retriever
-        retriever = vectordb.as_retriever(
-            search_type="similarity_score_threshold",
-            search_kwargs={'score_threshold': 0.8}
-        )
+        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
 
         return retriever
 
@@ -86,7 +89,11 @@ local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]
|
|
86 |
|
87 |
# Setup LLM
|
88 |
llm = ChatOpenAI(
|
89 |
-
model_name="gpt-
|
|
|
|
|
|
|
|
|
90 |
)
|
91 |
|
92 |
retriever = configure_retriever(local_files)
|
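Note: two separately configured models now exist: gpt-4-0125-preview answers with streaming=True so tokens can be forwarded to the UI as they arrive, while gpt-3.5-turbo handles translation at temperature 0 for deterministic output. A minimal sketch of what streaming=True enables (the prompt text is illustrative):

    # Sketch: .stream() yields incremental AIMessageChunk objects.
    for chunk in llm.stream("Nenne drei Prüfungsfragen zum Change Management."):
        print(chunk.content, end="", flush=True)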
@@ -96,7 +103,7 @@ template = """Answer the question based only on the following context:
|
|
96 |
|
97 |
Question: {question}
|
98 |
|
99 |
-
Answer in German
|
100 |
"""
|
101 |
|
102 |
prompt = ChatPromptTemplate.from_template(template)
|
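Note: the stricter instruction keeps answers grounded: off-context questions now get a German "I don't know" rather than a guessed answer. How the template renders can be checked directly (the argument values are illustrative):

    msgs = prompt.format_messages(context="<retrieved chunks>", question="Was ist Change Management?")
    print(msgs[0].content)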
@@ -111,28 +118,44 @@ chain = (
     | StrOutputParser()
 )
 
+chain_translate = (llm_translate
+    | StrOutputParser()
+)
+
+
 def predict(message, history):
-    message = f"Translate
+    message = chain_translate.invoke(f"Translate this sentence to English: {message}")
     history_langchain_format = []
     for human, ai in history:
         history_langchain_format.append(HumanMessage(content=human))
         history_langchain_format.append(AIMessage(content=ai))
     history_langchain_format.append(HumanMessage(content=message))
     gpt_response = llm(history_langchain_format)
-
+    for chunk in chain.stream({"question": gpt_response.content}): # Stream the response
+        yield chunk
+
+
+image_path = "./ui/logo.png" if os.path.exists("./ui/logo.png") else "./logo.png"
 
-demo = gr.ChatInterface(
+with gr.Blocks() as demo:
+    gr.Image(image_path)
+    gr.ChatInterface(
         predict,
         chatbot=gr.Chatbot(height=500, show_share_button=True),
         textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
         title="Beitrag Service",
         description="Ich bin Ihr hilfreicher KI-Assistent",
         theme="soft",
-        examples=[
+        examples=[
+            "Generate auditing questions about Change Management",
+            "Generate auditing questions about Software Maintenance",
+            "Generate auditing questions about Data Protection",
+            "Generate auditing questions about IT",
+            "Generate auditing questions about control systems",
+            "Generate auditing questions about GDPR compliance",
+        ],
         cache_examples=True,
-
-        undo_btn="Vorheriges löschen",
-        clear_btn="Löschen").launch(show_api= False)
+    ).launch(show_api= False)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
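Note: predict is now a generator, and Gradio's ChatInterface replaces the displayed bot message with each yielded value. Because chain.stream yields string fragments (the chain ends in StrOutputParser), yielding raw chunks shows only the latest fragment; accumulating them displays the growing answer. A sketch of that variant, assuming chain and chain_translate as defined in the diff:

    def predict(message, history):
        message = chain_translate.invoke(f"Translate this sentence to English: {message}")
        partial = ""
        for chunk in chain.stream({"question": message}):
            partial += chunk   # accumulate so the UI shows the full answer so far
            yield partial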