muhtasham commited on
Commit
6f7484c
·
verified ·
1 Parent(s): dd1f101

Upload 5 files

Browse files
Files changed (3) hide show
  1. .env +6 -0
  2. app.py +138 -0
  3. requirements.txt +7 -0
.env ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ OPENAI_API_KEY="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"
2
+ LANGCHAIN_TRACING_V2=true
3
+ LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
4
+ LANGCHAIN_API_KEY="ls__481915cb2eaa4a53876c4bcf592457b0"
5
+ LANGCHAIN_PROJECT="Beitrag POC"
6
+ ACCESS_TOKEN_SECRET="hpr;F3H678%H"
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ import logging
6
+
7
+ from operator import itemgetter
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_community.document_loaders import PyPDFLoader
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+ from langchain_community.vectorstores.chroma import Chroma
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from langchain.schema import AIMessage, HumanMessage
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
17
+ from langchain.chains.combine_documents import create_stuff_documents_chain
18
+ from langchain.chains import create_retrieval_chain
19
+ from langchain.globals import set_debug
20
+ from dotenv import load_dotenv
21
+
22
+ # configure logging
23
+ logging.basicConfig(level=logging.INFO)
24
+
25
+ set_debug(True)
26
+ load_dotenv()
27
+
28
+ openai_api_key = os.getenv("OPENAI_API_KEY")
29
+
30
+ persist_dir = "./chroma_db"
31
+ device='cuda:0'
32
+ model_name="all-mpnet-base-v2"
33
+ model_kwargs = {'device': device if torch.cuda.is_available() else 'cpu'}
34
+ logging.info(f"Using device {model_kwargs['device']}")
35
+ # Create embeddings and store in vectordb
36
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
37
+
38
+ def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
39
+ logging.info("Configuring retriever")
40
+
41
+ if not os.path.exists(persist_dir):
42
+ logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
43
+ # Read documents
44
+ docs = []
45
+ temp_dir = tempfile.TemporaryDirectory()
46
+ for filename in local_files:
47
+ logging.info(f"Reading file {filename}")
48
+ # Read the file once
49
+ if not os.path.exists(os.path.join("docs", filename)):
50
+ file_content = open(os.path.join(".", filename), "rb").read()
51
+ else:
52
+ file_content = open(os.path.join("docs", filename), "rb").read()
53
+ temp_filepath = os.path.join(temp_dir.name, filename)
54
+ with open(temp_filepath, "wb") as f:
55
+ f.write(file_content)
56
+ loader = PyPDFLoader(temp_filepath)
57
+ docs.extend(loader.load())
58
+
59
+ # Split documents
60
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
61
+ splits = text_splitter.split_documents(docs)
62
+
63
+ vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
64
+
65
+ # Define retriever
66
+ retriever = vectordb.as_retriever(
67
+ search_type="similarity_score_threshold",
68
+ search_kwargs={'score_threshold': 0.8}
69
+ )
70
+
71
+ return retriever
72
+ else:
73
+ logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
74
+ vectordb = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
75
+
76
+ # Define retriever
77
+ retriever = vectordb.as_retriever(
78
+ search_type="similarity_score_threshold",
79
+ search_kwargs={'score_threshold': 0.8}
80
+ )
81
+
82
+ return retriever
83
+
84
+ directory = "docs" if os.path.exists("docs") else "."
85
+ local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]
86
+
87
+ # Setup LLM
88
+ llm = ChatOpenAI(
89
+ model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0, streaming=True
90
+ )
91
+
92
+ retriever = configure_retriever(local_files)
93
+
94
+ template = """Answer the question based only on the following context:
95
+ {context}
96
+
97
+ Question: {question}
98
+
99
+ Answer in German language.
100
+ """
101
+
102
+ prompt = ChatPromptTemplate.from_template(template)
103
+
104
+ chain = (
105
+ {
106
+ "context": itemgetter("question") | retriever,
107
+ "question": itemgetter("question"),
108
+ }
109
+ | prompt
110
+ | llm
111
+ | StrOutputParser()
112
+ )
113
+
114
+ def predict(message, history):
115
+ message = f"Translate the following text to German: {message}"
116
+ history_langchain_format = []
117
+ for human, ai in history:
118
+ history_langchain_format.append(HumanMessage(content=human))
119
+ history_langchain_format.append(AIMessage(content=ai))
120
+ history_langchain_format.append(HumanMessage(content=message))
121
+ gpt_response = llm(history_langchain_format)
122
+ return chain.invoke({"question": gpt_response.content})
123
+
124
+ demo = gr.ChatInterface(
125
+ predict,
126
+ chatbot=gr.Chatbot(height=500, show_share_button=True),
127
+ textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
128
+ title="Beitrag Service",
129
+ description="Ich bin Ihr hilfreicher KI-Assistent",
130
+ theme="soft",
131
+ examples=["Hello"],
132
+ cache_examples=True,
133
+ retry_btn="Wiederholen",
134
+ undo_btn="Vorheriges löschen",
135
+ clear_btn="Löschen").launch(show_api= False)
136
+
137
+ if __name__ == "__main__":
138
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ openai==1.12.0
3
+ langchain==0.1.10
4
+ langchain-openai==0.0.8
5
+ pypdf==4.0.1
6
+ python-dotenv==1.0.1
7
+ chromadb==0.4.22