import os

import gradio as gr
from huggingface_hub import InferenceClient
from cryptography.fernet import Fernet

# --- LangChain / RAG Imports ---
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory  # ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint


def load_decrypted_preprompt(file_path="pre_prompt.enc"):
    """
    Load and decrypt the pre-prompt from the encrypted file using the key
    stored in the environment variable 'KEY'.
    """
    # Retrieve the encryption key from the environment
    key_str = os.getenv("KEY", "")
    if not key_str:
        raise ValueError("Missing KEY environment variable!")
    key = key_str.encode()  # Key must be in bytes
    fernet = Fernet(key)
    # Read the encrypted pre-prompt
    with open(file_path, "rb") as file:
        encrypted_text = file.read()
    # Decrypt and decode the text
    decrypted_text = fernet.decrypt(encrypted_text)
    return decrypted_text.decode("utf-8")


# Instead of hardcoding, load the pre-prompt dynamically.
PRE_PROMPT = load_decrypted_preprompt()

# Default parameters for the QA chain
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TOP_K = 10
DEFAULT_TOP_P = 0.95


def load_vector_db(index_path="faiss_index", model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """
    Load the FAISS vector database from disk, allowing dangerous deserialization.
    """
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    vector_db = FAISS.load_local(
        index_path,
        embeddings,
        allow_dangerous_deserialization=True  # Only set this to True if you trust your data source!
    )
    return vector_db


def initialize_qa_chain(temperature, max_tokens, top_k, vector_db):
    """
    Initialize the retrieval-augmented QA chain using the pre-built vector database.
    """
    if vector_db is None:
        return None
    HF_TOKEN = os.getenv("AMAbot_r", "")  # use for publishing
    if not HF_TOKEN:
        raise ValueError("Missing AMAbot_r environment variable (Hugging Face API token)!")
    llm = HuggingFaceEndpoint(
        # repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        repo_id="Qwen/Qwen2.5-1.5B-Instruct",
        # repo_id="google/gemma-2b-it",
        huggingfacehub_api_token=HF_TOKEN,  # Only needed if the model endpoint requires authentication
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        task="text-generation"
    )
    memory = ConversationSummaryMemory(
        llm=llm,
        max_token_limit=500,  # Adjust this to control the summary size
        memory_key="chat_history",
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=False,  # Do not return source documents
        verbose=False,
    )
    return qa_chain


def format_chat_history(history):
    """
    Format chat history (a list of dictionaries) into a list of strings for the QA chain.
    Each entry is prefixed with "User:" or "Assistant:" accordingly.
    """
    formatted = []
    for message in history:
        if message["role"] == "user":
            formatted.append(f"User: {message['content']}")
        elif message["role"] == "assistant":
            formatted.append(f"Assistant: {message['content']}")
    return formatted


def update_chat(message, history):
    """
    Append the user's message to the chat history and clear the input box.
    Returns:
        - Updated chat history (for the Chatbot)
        - The user message (to be used as input for the next function)
        - An empty string to clear the textbox.
    """
    if history is None:
        history = []
    history = history.copy()
    history.append({"role": "user", "content": message})
    return history, message, ""


def get_assistant_response(message, history, max_tokens, temperature, top_p, qa_chain_state_dict):
    qa_chain = qa_chain_state_dict.get("qa_chain")
    if qa_chain is not None:
        # Format chat history to the plain-text format expected by the QA chain.
        formatted_history = format_chat_history(history)
        # Update the pre-prompt to encourage speculative responses.
        speculative_pre_prompt = PRE_PROMPT + "\nIf you're not completely sure, please provide your best guess and mention that it is speculative."
        combined_question = speculative_pre_prompt + "\n" + message
        # Try retrieving an answer via the QA chain.
        response = qa_chain.invoke({"question": combined_question, "chat_history": formatted_history})
        answer = response.get("answer", "").strip()
        # If no answer is returned, try the fallback plain chat mode with adjusted parameters.
        if not answer:
            # Increase temperature and optionally max_tokens for fallback.
            increased_temperature = min(temperature + 0.2, 1.0)  # Cap temperature at 1.0
            increased_max_tokens = max_tokens + 128  # Increase max tokens for a longer response if needed
            speculative_prompt = speculative_pre_prompt + "\n" + message
            messages = [{"role": "system", "content": speculative_prompt}] + history
            result = client.chat_completion(
                messages,
                max_tokens=increased_max_tokens,
                stream=False,
                temperature=increased_temperature,
                top_p=top_p,
            )
            # With stream=False the client returns a single completion object,
            # not a token stream, so read the message content directly.
            answer = (result.choices[0].message.content or "").strip()
        # Final fallback if still empty.
        if not answer:
            answer = ("I'm sorry, I couldn't retrieve a clear answer. "
                      "However, based on the available context, here is my best guess: "
                      "[speculative answer].")
        history.append({"role": "assistant", "content": answer})
        return history, {"qa_chain": qa_chain}

    # Fallback: Plain Chat Mode using the InferenceClient when no QA chain is available.
    messages = [{"role": "system", "content": PRE_PROMPT}] + history
    result = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=False,
        temperature=temperature,
        top_p=top_p,
    )
    response = (result.choices[0].message.content or "").strip()
    if not response:
        response = ("I'm sorry, I couldn't generate a response. Please try asking in a different way. "
                    "Alternatively, consider contacting Christopher directly: https://gcmarais.com/contact/")
    history.append({"role": "assistant", "content": response})
    return history, {"qa_chain": qa_chain}


HF_TOKEN = os.getenv("AMAbot_r", "")  # use for publishing
if not HF_TOKEN:
    raise ValueError("Missing AMAbot_r environment variable (Hugging Face API token)!")

# Global InferenceClient for plain chat (fallback)
client = InferenceClient(
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    # "google/gemma-2b-it",
    token=HF_TOKEN)

# --- Auto-load vector database and initialize QA chain at startup ---
try:
    vector_db = load_vector_db("faiss_index")
    db_status_msg = "Vector DB loaded successfully."
except Exception as e:
    vector_db = None
    db_status_msg = f"Failed to load Vector DB: {e}"

if vector_db:
    qa_chain = initialize_qa_chain(DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS, DEFAULT_TOP_K, vector_db)
else:
    qa_chain = None

qa_chain_state_initial = {"qa_chain": qa_chain}


# New function to immediately send an example query.
# Note: this helper is not currently bound to any UI event; the example buttons
# below chain update_chat and get_assistant_response directly instead.
def send_example(example_text, history, max_tokens, temperature, top_p, qa_chain_state):
    if history is None:
        history = []
    # Simulate appending the user's message.
    history, _, _ = update_chat(example_text, history)
    # Get the assistant's response.
    history, qa_chain_state = get_assistant_response(
        example_text, history, max_tokens, temperature, top_p, qa_chain_state
    )
    # Also hide the examples row.
    return history, qa_chain_state, gr.update(visible=False)
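
# ---------------------------------------------------------------------------
# Illustrative sketches (not wired into the app)
# ---------------------------------------------------------------------------

# A minimal sketch of one way the "faiss_index" folder read by load_vector_db()
# could be built offline. The source file name and the naive blank-line
# chunking are assumptions for illustration only, not part of this app.
def build_vector_db_sketch(source_path="about_christopher.txt",
                           index_path="faiss_index",
                           model_name="sentence-transformers/all-MiniLM-L6-v2"):
    with open(source_path, "r", encoding="utf-8") as f:
        raw_text = f.read()
    # Naive chunking on blank lines; a real pipeline would likely use a proper
    # text splitter with overlap.
    chunks = [chunk.strip() for chunk in raw_text.split("\n\n") if chunk.strip()]
    # Embed with the same model used by load_vector_db() and persist to disk.
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    vector_db = FAISS.from_texts(chunks, embeddings)
    vector_db.save_local(index_path)
    return vector_db


# A minimal sketch of a streaming variant of the plain-chat fallback used in
# get_assistant_response(). With stream=True, chat_completion yields delta
# chunks whose content is accumulated token by token. The helper name and its
# default parameters are assumptions for illustration only.
def stream_chat_completion_sketch(messages,
                                  max_tokens=DEFAULT_MAX_TOKENS,
                                  temperature=DEFAULT_TEMPERATURE,
                                  top_p=DEFAULT_TOP_P):
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # Some chunks (e.g. the final one) may carry no content.
            response += token
    return response.strip()
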
# ---------------------------
# Gradio Interface Layout
# ---------------------------
# The Blocks call below uses one of Gradio's prebuilt themes (Default, "sky" primary hue).
# Custom CSS that forces light mode regardless of browser settings.
custom_css = """
:root {
    --primary-200: transparent !important;
    color-scheme: light !important;
    background-color: #fff !important;
    color: #333 !important;
}
/* Override the background color for user messages in the Chatbot */
#chatbot .message.user {
    background-color: #ccc !important;  /* Grey background */
    color: #222 !important;
}
.gradio-container footer {
    display: none !important;
}
.gradio-container {
    width: 100% !important;
    max-width: none !important;
    margin: 0;
}
.gradio-container .fillable {
    width: 100% !important;
    max-width: unset !important;
    margin: 0;
}
.hf-chat-input textarea:focus {
    outline: none !important;
    box-shadow: none !important;
    border-color: #c2c2c2 !important;
}
.hf-chat-input:focus {
    outline: none !important;
    box-shadow: none !important;
    border-color: #c2c2c2 !important;  /* or use your preferred grey */
}
.block-container {
    width: 100% !important;
    max-width: none !important;
}
"""

with gr.Blocks(fill_width=True, css=custom_css, theme=gr.themes.Default(primary_hue="sky")) as demo:
    # Placeholder HTML component; the custom CSS itself is injected via the css= argument above.
    gr.HTML(""" """)

    # Keep the QA chain state in Gradio
    qa_chain_state = gr.State(value=qa_chain_state_initial)
    # Hidden state to temporarily hold the user message for processing
    user_message_state = gr.State()

    # Chat window using dictionary message format; initially hidden
    chatbot = gr.Chatbot(label="AMAbot", show_label=True, elem_id="chatbot",
                         height=250, type="messages", visible=False)

    # ---------------------------
    # Example Inputs Row (clickable examples)
    # ---------------------------
    with gr.Row(elem_classes="example-row", visible=True) as examples_container:
        ex1 = gr.Button("Who?")
        ex2 = gr.Button("Where?")
        ex3 = gr.Button("What?")

    # Immediately show the chatbot when an example button is clicked (non-blocking)
    ex1.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex2.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex3.click(lambda: gr.update(visible=True), None, chatbot, queue=False)

    # Input row: Embed the send button inside the text input box container.
    with gr.Row(elem_classes="input-row"):
        with gr.Column(elem_classes="input-container"):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Ask AMAbot anything about Christopher",
                container=False,
                elem_classes="hf-chat-input"
            )
            send_btn = gr.Button("❯❯", elem_classes="send-button")

    # Hidden inputs for fixed parameters
    max_tokens_input = gr.Number(value=DEFAULT_MAX_TOKENS, visible=False)
    temperature_input = gr.Number(value=DEFAULT_TEMPERATURE, visible=False)
    top_p_input = gr.Number(value=DEFAULT_TOP_P, visible=False)

    # Immediately show the chatbot when the send button is clicked or Enter is pressed
    user_input.submit(lambda: gr.update(visible=True), None, chatbot, queue=False)
    send_btn.click(lambda: gr.update(visible=True), None, chatbot, queue=False)

    # ---------------------------
    # Bind events for manual text submission.
    # update_chat() first appends the user's message and clears the textbox;
    # get_assistant_response() (chained via .then) then produces the reply.
    # ---------------------------
    user_input.submit(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    send_btn.click(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    # ---------------------------
    # Bind events for example buttons.
    # ---------------------------
    ex1.click(
        lambda history: update_chat("Who is Christopher?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    ex2.click(
        lambda history: update_chat("Where is Christopher from?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )

    ex3.click(
        lambda history: update_chat("What degrees does Christopher have, and what job titles has he held?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )


if __name__ == "__main__":
    demo.queue().launch(show_api=False)
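
# Deployment notes (derived from the code above): the app expects the "KEY" and
# "AMAbot_r" environment variables, plus the local "pre_prompt.enc" file and the
# "faiss_index" directory, to be present at startup.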