import os
import gradio as gr
from huggingface_hub import InferenceClient
from cryptography.fernet import Fernet

# --- LangChain / RAG Imports ---
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory  # ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint

def load_decrypted_preprompt(file_path="pre_prompt.enc"):
    """
    Load and decrypt the pre-prompt from the encrypted file using the Fernet key
    stored in the environment variable 'KEY'.
    """
    # Retrieve the encryption key from the environment
    key_str = os.getenv("KEY", "")
    if not key_str:
        raise ValueError("Missing KEY environment variable!")
    key = key_str.encode()  # Fernet expects the key as bytes
    fernet = Fernet(key)
    # Read the encrypted pre-prompt
    with open(file_path, "rb") as file:
        encrypted_text = file.read()
    # Decrypt and decode the text
    decrypted_text = fernet.decrypt(encrypted_text)
    return decrypted_text.decode("utf-8")
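
# One-off helper sketch (not part of this app; the plain-text filename is assumed) for
# producing pre_prompt.enc and a matching Fernet key. Run it locally once, then store
# the printed key as the Space secret "KEY":
#
#   from cryptography.fernet import Fernet
#   key = Fernet.generate_key()                      # save this value as the "KEY" secret
#   with open("pre_prompt.txt", "rb") as f:          # hypothetical plain-text pre-prompt
#       token = Fernet(key).encrypt(f.read())
#   with open("pre_prompt.enc", "wb") as f:
#       f.write(token)
#   print(key.decode())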

# Instead of hardcoding, load the pre-prompt dynamically.
PRE_PROMPT = load_decrypted_preprompt()

# Default parameters for the QA chain
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TOP_K = 10
DEFAULT_TOP_P = 0.95

def load_vector_db(index_path="faiss_index", model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """
    Load the FAISS vector database from disk, allowing dangerous deserialization.
    """
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    vector_db = FAISS.load_local(
        index_path,
        embeddings,
        allow_dangerous_deserialization=True  # Only set this to True if you trust your data source!
    )
    return vector_db
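
# One-off helper sketch (assumed workflow; the source document name and chunking
# parameters are hypothetical) for building the "faiss_index" folder this app loads,
# using the same embedding model:
#
#   from langchain_community.document_loaders import TextLoader
#   from langchain_text_splitters import RecursiveCharacterTextSplitter
#   docs = TextLoader("about_christopher.txt").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   FAISS.from_documents(chunks, embeddings).save_local("faiss_index")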

def initialize_qa_chain(temperature, max_tokens, top_k, vector_db):
    """
    Initialize the retrieval-augmented QA chain using the pre-built vector database.
    """
    if vector_db is None:
        return None
    HF_TOKEN = os.getenv("AMAbot_r", "")  # Hugging Face token stored in the Space secret "AMAbot_r"
    if not HF_TOKEN:
        raise ValueError("Missing AMAbot_r environment variable (Hugging Face token)!")
    llm = HuggingFaceEndpoint(
        # repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        repo_id="Qwen/Qwen2.5-1.5B-Instruct",
        # repo_id="google/gemma-2b-it",
        huggingfacehub_api_token=HF_TOKEN,  # Only needed if the model endpoint requires authentication
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        task="text-generation"
    )
    memory = ConversationSummaryMemory(
        llm=llm,
        max_token_limit=500,  # Adjust this to control the summary size
        memory_key="chat_history",
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=False,  # Do not return source documents
        verbose=False,
    )
    return qa_chain

def format_chat_history(history):
    """
    Format chat history (a list of dictionaries) into a list of strings for the QA chain.
    Each entry is prefixed with "User:" or "Assistant:" accordingly.
    """
    formatted = []
    for message in history:
        if message["role"] == "user":
            formatted.append(f"User: {message['content']}")
        elif message["role"] == "assistant":
            formatted.append(f"Assistant: {message['content']}")
    return formatted
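
# Example:
#   format_chat_history([{"role": "user", "content": "Hi"},
#                        {"role": "assistant", "content": "Hello!"}])
#   -> ["User: Hi", "Assistant: Hello!"]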

def update_chat(message, history):
    """
    Append the user's message to the chat history and clear the input box.
    Returns:
        - Updated chat history (for the Chatbot)
        - The user message (to be used as input for the next function)
        - An empty string to clear the textbox.
    """
    if history is None:
        history = []
    history = history.copy()
    history.append({"role": "user", "content": message})
    return history, message, ""
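
# Example: update_chat("Hi", []) -> ([{"role": "user", "content": "Hi"}], "Hi", "")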

def get_assistant_response(message, history, max_tokens, temperature, top_p, qa_chain_state_dict):
    """
    Generate the assistant's reply: try the RAG chain first, then fall back to plain
    chat completion via the InferenceClient if no chain or no answer is available.
    """
    qa_chain = qa_chain_state_dict.get("qa_chain")
    if qa_chain is not None:
        # Format chat history to the plain-text format expected by the QA chain.
        formatted_history = format_chat_history(history)
        # Update the pre-prompt to encourage speculative responses.
        speculative_pre_prompt = PRE_PROMPT + "\nIf you're not completely sure, please provide your best guess and mention that it is speculative."
        combined_question = speculative_pre_prompt + "\n" + message
        # Try retrieving an answer via the QA chain.
        response = qa_chain.invoke({"question": combined_question, "chat_history": formatted_history})
        answer = response.get("answer", "").strip()
        # If no answer is returned, fall back to plain chat mode with adjusted parameters.
        if not answer:
            # Increase temperature (capped at 1.0) and max tokens for a more exploratory retry.
            increased_temperature = min(temperature + 0.2, 1.0)
            increased_max_tokens = max_tokens + 128
            speculative_prompt = speculative_pre_prompt + "\n" + message
            messages = [{"role": "system", "content": speculative_prompt}] + history
            result = client.chat_completion(
                messages,
                max_tokens=increased_max_tokens,
                stream=False,
                temperature=increased_temperature,
                top_p=top_p,
            )
            # With stream=False the client returns one completed message, not token deltas.
            answer = result.choices[0].message.content.strip()
        # Final fallback if still empty.
        if not answer:
            answer = ("I'm sorry, I couldn't retrieve a clear answer. "
                      "However, based on the available context, here is my best guess: "
                      "[speculative answer].")
        history.append({"role": "assistant", "content": answer})
        return history, {"qa_chain": qa_chain}
    # Fallback: Plain Chat Mode using the InferenceClient when no QA chain is available.
    messages = [{"role": "system", "content": PRE_PROMPT}] + history
    result = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=False,
        temperature=temperature,
        top_p=top_p,
    )
    response = result.choices[0].message.content.strip()
    if not response:
        response = ("I'm sorry, I couldn't generate a response. Please try asking in a different way. "
                    "Alternatively, consider contacting Christopher directly: https://gcmarais.com/contact/")
    history.append({"role": "assistant", "content": response})
    return history, {"qa_chain": qa_chain}

HF_TOKEN = os.getenv("AMAbot_r", "")  # Hugging Face token stored in the Space secret "AMAbot_r"
if not HF_TOKEN:
    raise ValueError("Missing AMAbot_r environment variable (Hugging Face token)!")
# Global InferenceClient for plain chat (fallback)
client = InferenceClient(
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    # "google/gemma-2b-it",
    token=HF_TOKEN)
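
# Note: with stream=False, client.chat_completion(...) returns a single ChatCompletionOutput
# whose reply text is at result.choices[0].message.content; with stream=True it would instead
# yield chunks carrying incremental text at chunk.choices[0].delta.content.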

# --- Auto-load vector database and initialize QA chain at startup ---
try:
    vector_db = load_vector_db("faiss_index")
    db_status_msg = "Vector DB loaded successfully."
except Exception as e:
    vector_db = None
    db_status_msg = f"Failed to load Vector DB: {e}"

if vector_db:
    qa_chain = initialize_qa_chain(DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, DEFAULT_TOP_K, vector_db)
else:
    qa_chain = None

qa_chain_state_initial = {"qa_chain": qa_chain}

# New function to immediately send an example query:
def send_example(example_text, history, max_tokens, temperature, top_p, qa_chain_state):
    if history is None:
        history = []
    # Simulate appending the user's message.
    history, _, _ = update_chat(example_text, history)
    # Get the assistant's response.
    history, qa_chain_state = get_assistant_response(example_text, history, max_tokens, temperature, top_p, qa_chain_state)
    # Also hide the examples row.
    return history, qa_chain_state, gr.update(visible=False)

# ---------------------------
# Gradio Interface Layout
# ---------------------------
# Create a theme instance using one of Gradio's prebuilt themes.
# Custom CSS that forces light mode regardless of browser settings.
custom_css = """
:root {
    --primary-200: transparent !important;
    color-scheme: light !important;
    background-color: #fff !important;
    color: #333 !important;
}
/* Override the background color for user messages in the Chatbot */
#chatbot .message.user {
    background-color: #ccc !important;  /* Grey background */
    color: #222 !important;
}
.gradio-container footer {
    display: none !important;
}
.gradio-container {
    width: 100% !important;
    max-width: none !important;
    margin: 0;
}
.gradio-container .fillable {
    width: 100% !important;
    max-width: unset !important;
    margin: 0;
}
.hf-chat-input textarea:focus {
    outline: none !important;
    box-shadow: none !important;
    border-color: #c2c2c2 !important;
}
.hf-chat-input:focus {
    outline: none !important;
    box-shadow: none !important;
    border-color: #c2c2c2 !important;  /* or use your preferred grey */
}
.block-container {
    width: 100% !important;
    max-width: none !important;
}
"""

with gr.Blocks(fill_width=True, css=custom_css, theme=gr.themes.Default(primary_hue="sky")) as demo:
    # Insert custom CSS for layout and a script that forces the light theme:
    gr.HTML("""
    <script>
    window.addEventListener("load", () => {
        document.documentElement.setAttribute("data-theme", "light");
    });
    </script>
    <style>
    :root {
        color-scheme: light !important;
        background-color: #fff !important;
        color: #333 !important;
    }
    body .gradio-container .chatbot .hf-chat-input button .textbox textarea {
        background-color: #fff !important;
        color: #333 !important;
    }
    .example-row {
        flex-grow: 1 !important;
        width: 100% !important;
        display: flex;
        flex-direction: row;
        flex-wrap: wrap;            /* Will wrap to vertical if there's not enough space */
        justify-content: center;    /* or flex-start, depending on your layout preference */
        gap: 10px;                  /* optional: add spacing between buttons */
    }
    /* Container for the input box and embedded send button */
    .input-container {
        position: relative;
        width: 100%;
    }
    /* Style for the input text to mimic Hugging Face Chat UI */
    .hf-chat-input {
        background-color: #f9f9f9;
        border: 1px solid #e0e0e0;
        border-radius: 20px;
        padding: 10px 50px 10px 20px;  /* extra right padding to make room for the send button */
        font-size: 16px;
        width: 100%;
        box-sizing: border-box;
        transition: border-color 0.2s ease;
    }
    .hf-chat-input:focus {
        outline: none;
        border-color: #c2c2c2;
    }
    /* Style for the embedded send button */
    .send-button {
        position: absolute;
        right: 10px;                 /* adjust as needed */
        top: 50%;
        transform: translateY(-50%);
        width: 15px !important;      /* desired width */
        height: 30px !important;     /* desired height */
        padding: 0;
        background: #fff;
        border: none;
        border-radius: 50%;
        cursor: pointer;
        transition: background-color 0.2s ease;
        display: flex;               /* use flexbox for centering */
        align-items: center;
        justify-content: center;
        font-size: 16px;             /* ensure consistent text size */
        line-height: 1;
    }
    .send-button:hover,
    .send-button:focus,
    .send-button:active {
        background-color: #f0f0f0;
        outline: none;               /* remove focus outline */
        top: 50% !important;
        transform: translateY(-50%) !important;
    }
    /* Overall input row styling */
    .input-row {
        display: flex;
        align-items: center;
        width: 100%;
        gap: 10px;
    }
    </style>
    """)
    # Keep the QA chain state in Gradio
    qa_chain_state = gr.State(value=qa_chain_state_initial)
    # Hidden state to temporarily hold the user message for processing
    user_message_state = gr.State()

    # Chat window using dictionary message format; initially hidden
    chatbot = gr.Chatbot(label="AMAbot", show_label=True, elem_id="chatbot", height=250, type="messages", visible=False)

    # ---------------------------
    # Example Inputs Row (clickable examples)
    # ---------------------------
    with gr.Row(elem_classes="example-row", visible=True) as examples_container:
        ex1 = gr.Button("Who?")
        ex2 = gr.Button("Where?")
        ex3 = gr.Button("What?")

    # Immediately show the chatbot when an example button is clicked (non-blocking)
    ex1.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex2.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex3.click(lambda: gr.update(visible=True), None, chatbot, queue=False)

    # Input row: Embed the send button inside the text input box container.
    with gr.Row(elem_classes="input-row"):
        with gr.Column(elem_classes="input-container"):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Ask AMAbot anything about Christopher",
                container=False,
                elem_classes="hf-chat-input"
            )
            send_btn = gr.Button("❯❯", elem_classes="send-button")

    # Hidden inputs for fixed parameters
    max_tokens_input = gr.Number(value=DEFAULT_MAX_TOKENS, visible=False)
    temperature_input = gr.Number(value=DEFAULT_TEMPERATURE, visible=False)
    top_p_input = gr.Number(value=DEFAULT_TOP_P, visible=False)

    # Immediately show the chatbot when the send button is clicked or Enter is pressed
    user_input.submit(lambda: gr.update(visible=True), None, chatbot, queue=False)
    send_btn.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    # ---------------------------
    # Bind events for manual text submission.
    # ---------------------------
    user_input.submit(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    send_btn.click(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    # ---------------------------
    # Bind events for example buttons.
    # ---------------------------
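    # Each lambda reuses update_chat with a canned question; the [:2] slice keeps only the
    # updated history and the message, dropping the empty string meant for clearing the textbox.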
    ex1.click(
        lambda history: update_chat("Who is Christopher?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    ex2.click(
        lambda history: update_chat("Where is Christopher from?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    ex3.click(
        lambda history: update_chat("What degrees does Christopher have, and what job titles has he held?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )


if __name__ == "__main__":
    demo.queue().launch(show_api=False)