# AMAbot / app.py
import os
import gradio as gr
from huggingface_hub import InferenceClient
from cryptography.fernet import Fernet
# --- LangChain / RAG Imports ---
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory  # alternative: ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
def load_decrypted_preprompt(file_path="pre_prompt.enc"):
"""
Load and decrypt the pre-prompt from the encrypted file using the key
    stored in the environment variable 'KEY'.
"""
# Retrieve the encryption key from the environment
key_str = os.getenv("KEY", "")
if not key_str:
raise ValueError("Missing ENCRYPTION_KEY environment variable!")
key = key_str.encode() # Key must be in bytes
fernet = Fernet(key)
# Read the encrypted pre-prompt
with open(file_path, "rb") as file:
encrypted_text = file.read()
# Decrypt and decode the text
decrypted_text = fernet.decrypt(encrypted_text)
return decrypted_text.decode("utf-8")
# Instead of hardcoding, load the pre-prompt dynamically.
PRE_PROMPT = load_decrypted_preprompt()
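# A rough sketch (assumption, not part of this repo) of how pre_prompt.enc could be
# produced offline with the same Fernet key that is stored in the KEY secret:
#   key = Fernet.generate_key()                       # save this value as the KEY secret
#   token = Fernet(key).encrypt(open("pre_prompt.txt", "rb").read())
#   open("pre_prompt.enc", "wb").write(token)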
# Default parameters for the QA chain
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TOP_K = 10
DEFAULT_TOP_P = 0.95
def load_vector_db(index_path="faiss_index", model_name="sentence-transformers/all-MiniLM-L6-v2"):
"""
Load the FAISS vector database from disk, allowing dangerous deserialization.
"""
embeddings = HuggingFaceEmbeddings(model_name=model_name)
vector_db = FAISS.load_local(
index_path,
embeddings,
allow_dangerous_deserialization=True # Only set this to True if you trust your data source!
)
return vector_db
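# A rough sketch (assumption) of how the faiss_index directory would typically be built
# ahead of time with the same embedding model, e.g.:
#   docs = [...]  # LangChain Document objects describing Christopher
#   FAISS.from_documents(docs, HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")).save_local("faiss_index")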
def initialize_qa_chain(temperature, max_tokens, top_k, vector_db):
"""
Initialize the retrieval-augmented QA chain using the pre-built vector database.
"""
if vector_db is None:
return None
    HF_TOKEN = os.getenv("AMAbot_r", "")  # Hugging Face token stored in the "AMAbot_r" secret (used for publishing)
    if not HF_TOKEN:
        raise ValueError("Missing AMAbot_r environment variable (Hugging Face token)!")
llm = HuggingFaceEndpoint(
# repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
repo_id="Qwen/Qwen2.5-1.5B-Instruct",
# repo_id="google/gemma-2b-it",
huggingfacehub_api_token=HF_TOKEN, # Only needed if the model endpoint requires authentication
temperature=temperature,
max_new_tokens=max_tokens,
top_k=top_k,
task="text-generation"
)
memory = ConversationSummaryMemory(
llm=llm,
max_token_limit=500, # Adjust this to control the summary size
memory_key="chat_history",
return_messages=True
)
retriever = vector_db.as_retriever()
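    # ConversationalRetrievalChain condenses the follow-up question together with the
    # (summarized) chat history, retrieves the most relevant chunks, and "stuffs" them
    # into a single prompt for the LLM (chain_type="stuff").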
qa_chain = ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
chain_type="stuff",
memory=memory,
return_source_documents=False, # Do not return source documents
verbose=False,
)
return qa_chain
def format_chat_history(history):
"""
Format chat history (a list of dictionaries) into a list of strings for the QA chain.
Each entry is prefixed with "User:" or "Assistant:" accordingly.
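    Example: [{"role": "user", "content": "Hi"}] -> ["User: Hi"]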
"""
formatted = []
for message in history:
if message["role"] == "user":
formatted.append(f"User: {message['content']}")
elif message["role"] == "assistant":
formatted.append(f"Assistant: {message['content']}")
return formatted
def update_chat(message, history):
"""
Append the user's message to the chat history and clear the input box.
Returns:
- Updated chat history (for the Chatbot)
- The user message (to be used as input for the next function)
- An empty string to clear the textbox.
"""
if history is None:
history = []
history = history.copy()
history.append({"role": "user", "content": message})
return history, message, ""
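# get_assistant_response first tries the RAG chain; if it returns an empty answer it
# falls back to a plain chat completion with a slightly higher temperature, and if the
# vector DB never loaded it goes straight to plain chat via the global InferenceClient.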
def get_assistant_response(message, history, max_tokens, temperature, top_p, qa_chain_state_dict):
qa_chain = qa_chain_state_dict.get("qa_chain")
if qa_chain is not None:
# Format chat history to the plain-text format expected by the QA chain.
formatted_history = format_chat_history(history)
# Update the pre-prompt to encourage speculative responses.
speculative_pre_prompt = PRE_PROMPT + "\nIf you're not completely sure, please provide your best guess and mention that it is speculative."
combined_question = speculative_pre_prompt + "\n" + message
# Try retrieving an answer via the QA chain.
response = qa_chain.invoke({"question": combined_question, "chat_history": formatted_history})
answer = response.get("answer", "").strip()
# If no answer is returned, try the fallback plain chat mode with adjusted parameters.
if not answer:
# Increase temperature and optionally max_tokens for fallback.
increased_temperature = min(temperature + 0.2, 1.0) # Cap temperature at 1.0
increased_max_tokens = max_tokens + 128 # Increase max tokens for a longer response if needed
speculative_prompt = speculative_pre_prompt + "\n" + message
messages = [{"role": "system", "content": speculative_prompt}] + history
response = ""
result = client.chat_completion(
messages,
max_tokens=increased_max_tokens,
stream=False,
temperature=increased_temperature,
top_p=top_p,
)
for token_message in result:
token = token_message.choices[0].delta.content
response += token
answer = response.strip()
# Final fallback if still empty.
if not answer:
answer = ("I'm sorry, I couldn't retrieve a clear answer. "
"However, based on the available context, here is my best guess: "
"[speculative answer].")
history.append({"role": "assistant", "content": answer})
return history, {"qa_chain": qa_chain}
# Fallback: Plain Chat Mode using the InferenceClient when no QA chain is available.
messages = [{"role": "system", "content": PRE_PROMPT}] + history
response = ""
result = client.chat_completion(
messages,
max_tokens=max_tokens,
stream=False,
temperature=temperature,
top_p=top_p,
)
    # With stream=False the full message content is available directly on the result.
    response = result.choices[0].message.content.strip()
if not response:
response = ("I'm sorry, I couldn't generate a response. Please try asking in a different way. "
"Alternatively, consider contacting Christopher directly: https://gcmarais.com/contact/")
history.append({"role": "assistant", "content": response})
return history, {"qa_chain": qa_chain}
HF_TOKEN = os.getenv("AMAbot_r", "")  # Hugging Face token stored in the "AMAbot_r" secret (used for publishing)
if not HF_TOKEN:
    raise ValueError("Missing AMAbot_r environment variable (Hugging Face token)!")
# Global InferenceClient for plain chat (fallback)
client = InferenceClient(
# "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"Qwen/Qwen2.5-1.5B-Instruct",
# "google/gemma-2b-it",
token=HF_TOKEN)
# --- Auto-load vector database and initialize QA chain at startup ---
try:
vector_db = load_vector_db("faiss_index")
db_status_msg = "Vector DB loaded successfully."
except Exception as e:
vector_db = None
db_status_msg = f"Failed to load Vector DB: {e}"
if vector_db:
qa_chain = initialize_qa_chain(DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, DEFAULT_TOP_K, vector_db)
else:
qa_chain = None
qa_chain_state_initial = {"qa_chain": qa_chain}
# New function to immediately send an example query:
def send_example(example_text, history, max_tokens, temperature, top_p, qa_chain_state):
if history is None:
history = []
# Simulate appending the user's message.
history, _, _ = update_chat(example_text, history)
# Get the assistant's response.
history, qa_chain_state = get_assistant_response(example_text, history, max_tokens, temperature, top_p, qa_chain_state)
# Also hide the examples row.
return history, qa_chain_state, gr.update(visible=False)
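# Note: send_example is not wired to any event below; the example buttons use the
# inline lambdas defined in the Blocks layout instead.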
# ---------------------------
# Gradio Interface Layout
# ---------------------------
# A prebuilt Gradio theme (primary_hue="sky") is passed to gr.Blocks below.
# Custom CSS that forces light mode regardless of browser settings.
custom_css = """
:root {
--primary-200: transparent !important;
color-scheme: light !important;
background-color: #fff !important;
color: #333 !important;
}
/* Override the background color for user messages in the Chatbot */
#chatbot .message.user {
background-color: #ccc !important; /* Grey background */
color: #222 !important;
}
.gradio-container footer {
display: none !important;
}
.gradio-container {
width: 100% !important;
max-width: none !important;
margin: 0;
}
.gradio-container .fillable {
width: 100% !important;
max-width: unset !important;
margin: 0;
}
.hf-chat-input textarea:focus {
outline: none !important;
box-shadow: none !important;
border-color: #c2c2c2 !important;
}
.hf-chat-input:focus {
outline: none !important;
box-shadow: none !important;
border-color: #c2c2c2 !important; /* or use your preferred grey */
}
.block-container {
width: 100% !important;
max-width: none !important;
}
"""
with gr.Blocks(fill_width=True, css=custom_css, theme=gr.themes.Default(primary_hue="sky")) as demo:
    # Inject a script that forces light mode, plus additional CSS for the layout:
gr.HTML("""
<script>
window.addEventListener("load", () => {
document.documentElement.setAttribute("data-theme", "light");
});
</script>
<style>
:root {
color-scheme: light !important;
background-color: #fff !important;
color: #333 !important;
}
body .gradio-container .chatbot .hf-chat-input button .textbox textarea {
background-color: #fff !important;
color: #333 !important;
}
.example-row {
flex-grow: 1 !important;
width: 100% !important;
display: flex;
flex-direction: row;
flex-wrap: wrap; /* Will wrap to vertical if there's not enough space */
justify-content: center; /* or flex-start, depending on your layout preference */
gap: 10px; /* optional: add spacing between buttons */
}
/* Container for the input box and embedded send button */
.input-container {
position: relative;
width: 100%;
}
/* Style for the input text to mimic Hugging Face Chat UI */
.hf-chat-input {
background-color: #f9f9f9;
border: 1px solid #e0e0e0;
border-radius: 20px;
padding: 10px 50px 10px 20px; /* extra right padding to make room for the send button */
font-size: 16px;
width: 100%;
box-sizing: border-box;
transition: border-color 0.2s ease;
}
.hf-chat-input:focus {
outline: none;
border-color: #c2c2c2;
}
/* Style for the embedded send button */
.send-button {
position: absolute;
right: 10px; /* adjust as needed */
top: 50%;
transform: translateY(-50%);
width: 15px !important; /* desired width */
height: 30px !important; /* desired height */
padding: 0;
background: #fff;
border: none;
border-radius: 50%;
cursor: pointer;
transition: background-color 0.2s ease;
display: flex; /* use flexbox for centering */
align-items: center;
justify-content: center;
font-size: 16px; /* ensure consistent text size */
line-height: 1;
}
.send-button:hover,
.send-button:focus,
.send-button:active {
background-color: #f0f0f0;
outline: none; /* remove focus outline */
top: 50% !important;
transform: translateY(-50%) !important;
}
/* Overall input row styling */
.input-row {
display: flex;
align-items: center;
width: 100%;
gap: 10px;
}
</style>
""")
# Keep the QA chain state in Gradio
qa_chain_state = gr.State(value=qa_chain_state_initial)
# Hidden state to temporarily hold the user message for processing
user_message_state = gr.State()
# Chat window using dictionary message format; initially hidden
chatbot = gr.Chatbot(label="AMAbot", show_label=True, elem_id="chatbot", height=250, type="messages", visible=False)
# ---------------------------
# Example Inputs Row (clickable examples)
# ---------------------------
with gr.Row(elem_classes="example-row", visible=True) as examples_container:
ex1 = gr.Button("Who?")
ex2 = gr.Button("Where?")
ex3 = gr.Button("What?")
# Immediately show the chatbot when an example button is clicked (non-blocking)
ex1.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
ex2.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
ex3.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
# Input row: Embed the send button inside the text input box container.
with gr.Row(elem_classes="input-row"):
with gr.Column(elem_classes="input-container"):
user_input = gr.Textbox(
show_label=False,
placeholder="Ask AMAbot anything about Christopher",
container=False,
elem_classes="hf-chat-input"
)
send_btn = gr.Button("❯❯", elem_classes="send-button")
# Hidden inputs for fixed parameters
max_tokens_input = gr.Number(value=DEFAULT_MAX_TOKENS, visible=False)
temperature_input = gr.Number(value=DEFAULT_TEMPERATURE, visible=False)
top_p_input = gr.Number(value=DEFAULT_TOP_P, visible=False)
# Immediately show the chatbot when the send button is clicked or Enter is pressed
user_input.submit(lambda: gr.update(visible=True), None, chatbot, queue=False)
send_btn.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
# ---------------------------
# Bind events for manual text submission.
# ---------------------------
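    # Each submission runs two chained steps: update_chat appends the user turn and
    # clears the textbox, then get_assistant_response generates the reply.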
user_input.submit(
update_chat,
inputs=[user_input, chatbot],
outputs=[chatbot, user_message_state, user_input]
).then(
get_assistant_response,
inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
outputs=[chatbot, qa_chain_state]
)
send_btn.click(
update_chat,
inputs=[user_input, chatbot],
outputs=[chatbot, user_message_state, user_input]
).then(
get_assistant_response,
inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
outputs=[chatbot, qa_chain_state]
)
# ---------------------------
# Bind events for example buttons.
# ---------------------------
ex1.click(
lambda history: update_chat("Who is Christopher?", history)[:2],
inputs=[chatbot],
outputs=[chatbot, user_message_state]
).then(
get_assistant_response,
inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
outputs=[chatbot, qa_chain_state]
)
ex2.click(
lambda history: update_chat("Where is Christopher from?", history)[:2],
inputs=[chatbot],
outputs=[chatbot, user_message_state]
).then(
get_assistant_response,
inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
outputs=[chatbot, qa_chain_state]
)
ex3.click(
lambda history: update_chat("What degrees does Christopher have, and what job titles has he held?", history)[:2],
inputs=[chatbot],
outputs=[chatbot, user_message_state]
).then(
get_assistant_response,
inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
outputs=[chatbot, qa_chain_state]
)
if __name__ == "__main__":
demo.queue().launch(show_api=False)