import os
import tempfile
import gradio as gr
import torch
import logging
import base64
from operator import itemgetter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_debug
from dotenv import load_dotenv
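# Read an image from disk and return it base64-encoded (used to inline the logo into the CSS header).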
def image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return encoded_string
# Configure logging
logging.basicConfig(level=logging.INFO)
set_debug(True)

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
access_key = os.getenv("ACCESS_TOKEN_SECRET")

persist_dir = "./chroma_db"

# Prefer CUDA, then Apple Silicon (MPS), then CPU for the embedding model
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
model_kwargs = {'device': device if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"}
logging.info(f"Using device {model_kwargs['device']}")

embed_money = False
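# Set embed_money = True to use the paid OpenAI embeddings API instead of the free local model.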
# Create embeddings and store in vectordb
if embed_money:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    logging.info("Using OpenAI embeddings")
else:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
    logging.info("Using HuggingFace embeddings")
def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
    logging.info("Configuring retriever")
    if not os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
        # Read documents
        docs = []
        temp_dir = tempfile.TemporaryDirectory()
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            # Read the file once, preferring the docs/ directory when it holds the file
            source_dir = "docs" if os.path.exists(os.path.join("docs", filename)) else "."
            with open(os.path.join(source_dir, filename), "rb") as src:
                file_content = src.read()
            # Copy into a temp dir so PyPDFLoader gets a stable path
            temp_filepath = os.path.join(temp_dir.name, filename)
            with open(temp_filepath, "wb") as f:
                f.write(file_content)
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())
        # Split documents into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)
        vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    else:
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    # Define retriever: MMR returns diverse chunks (k=6, lambda_mult=0.25 weights diversity)
    retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
    return retriever
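# Collect all PDF files, preferring a ./docs directory when present.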
directory = "docs" if os.path.exists("docs") else "."
local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]
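# Build two LCEL chains: a bare LLM chain used for translation and a RAG chain that
# retrieves context for the (translated) question before prompting GPT-4o.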
def setup_llm(system_message):
    # Setup LLM
    llm = ChatOpenAI(
        model_name="gpt-4o", openai_api_key=openai_api_key, temperature=0.1, streaming=True
    )
    retriever = configure_retriever(local_files)
    template = system_message + """
    Answer the question, in its original language, based only on the following context.
    {context}
    Question: {question}
    Original Message: {original_msg}
    Chat History: {history}
    If the question is not related to the context, answer with "I don't know" in the original language.
    If the user asks for follow-up questions on the same topic, generate different questions than those you have already answered.
    If the user asks you to explain or expand on the context, provide the explanation in the original language.
    """
    prompt = ChatPromptTemplate.from_template(template)
    # Plain chain used for query translation (no retrieval)
    chain_translate = (
        llm
        | StrOutputParser()
    )
    # RAG chain: route the (translated) question through the retriever, then prompt the LLM
    chain_rag = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
            "original_msg": itemgetter("original_msg"),
            "history": itemgetter("history")
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain_rag, chain_translate
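# Gradio streaming callback: translate the query to English for retrieval, keep the
# original message so the answer stays in the user's language, and stream the response.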
def predict(message, history, system_message):
    logging.info(system_message)
    chain_rag, chain_translate = setup_llm(system_message)
    message_translated = chain_translate.invoke(f"Translate this query to English if it is in German; otherwise return the original content: {message}")
    history_langchain_format = []
    partial_message = ""
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))
    # Stream the answer token-by-token back to the Gradio UI
    for response in chain_rag.stream({"question": message_translated, "original_msg": message, "history": history_langchain_format}):
        partial_message += response
        yield partial_message
image_path = "./ui/logo.png" if os.path.exists("./ui/logo.png") else "./logo.png"
logo_base64 = image_to_base64(image_path)
# CSS with the Base64-encoded image
css = f"""
body::before {{
    content: '';
    display: block;
    height: 150px !important; /* Adjust based on your logo's size */
    background: url('data:image/png;base64,{logo_base64}') no-repeat center center !important;
    background-size: contain !important; /* This makes sure the logo fits well in the header */
}}
#q-output {{
    max-height: 60vh !important;
    overflow: auto !important;
}}
"""
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(likeable=True, show_share_button=False, show_copy_button=True),
    textbox=gr.Textbox(placeholder="Stell mir Fragen", scale=7),
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    submit_btn="Senden",
    retry_btn="🔄 Wiederholen",
    undo_btn="⏪ Rückgängig",
    clear_btn="🗑️ Löschen",
    additional_inputs=[
        gr.Textbox("You are an auditor with many years of professional experience and are to develop a questionnaire on the topic of home office in the form of a self-assessment for me. As a basis for the questionnaire, you use standards and best practices (for example, from ISO 27001 and COBIT). The questionnaire should not exceed 20 questions.", label="System Prompt")
    ],
    cache_examples=False,
    fill_height=True,
    css=css,
).launch(show_api=False)
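# Run locally (assuming this file is saved as app.py and a .env file provides OPENAI_API_KEY):
#   python app.py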