# Page header residue: "Spaces: Sleeping" (likely a Hugging Face Space status banner).
import pickle

import streamlit as st  # import the Streamlit library
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationEntityMemory
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE
from langchain.llms import OpenAI, OpenAIChat  # import OpenAI models (completion + chat)
# Initialize session state
# The New Chat button is hidden at the start of every rerun; later code in the
# script run switches it back on when appropriate.
st.session_state["show_new_chat_button"] = False
# Any key not yet in session state gets a fresh default built by its factory
# (int() -> 0, list() -> [], str() -> "", dict() -> {}).
for _state_key, _make_default in (
    ("id", int),
    ("conversation", list),
    ("input", str),
    ("stored_session", dict),
    ("input_temp", str),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Set the title of the Streamlit app
st.title("HomemadeGPT 🤖 - The custom chatbot you need")
# Conversation history: an empty placeholder that show_conv() writes into
conversation_history = st.empty()
API_KEY = st.sidebar.text_input("API-Key", type="password")
with st.sidebar.expander(" 🛠️ Settings ", expanded=False):
    # The memory previews need the entity memory, which only exists once an
    # API key has been entered and the chain has been built.
    if 'entity_memory' in st.session_state:
        # Option to preview memory store
        if st.checkbox("Preview memory store"):
            st.write(st.session_state.entity_memory.store)
        # Option to preview memory buffer
        if st.checkbox("Preview memory buffer"):
            st.write(st.session_state.entity_memory.buffer)
    MODEL = st.selectbox(
        label='Model',
        options=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-32k', 'text-davinci-003', 'text-davinci-002'],
    )
    # Integer window size passed as ``k`` to ConversationEntityMemory (defaults to min_value=3)
    K = st.number_input(' (#)Summary of prompts to consider', min_value=3, max_value=1000)
def clear_text():
    """Hand the submitted text over to ``input_temp`` and blank the input box.

    Used as the ``on_change`` callback of the chat text input, so it fires when
    the user presses Enter; the main script then picks the text up from
    ``st.session_state["input_temp"]``.
    """
    state = st.session_state
    state["input_temp"], state["input"] = state["input"], ""
def get_text():
    """Draw the chat input box and return what it currently holds.

    The widget is bound to ``st.session_state["input"]`` and calls
    ``clear_text`` on change, which empties it after each submission.

    Returns:
        str: the current contents of the input box.
    """
    return st.text_input(
        "You: ",
        key="input",
        placeholder="Your AI assistant ! Ask me anything...",
        label_visibility='hidden',
        on_change=clear_text,
    )
def new_chat():
    """Archive the active chat and start an empty one.

    The current transcript and entity memory are saved under the current id,
    the screen and memory are cleared, and a new, unused id is selected for the
    next chat.
    """
    save_current_chat()
    clean_screen()
    clean_memory()
    # A plain ``id += 1`` can land on an id that is already stored: after
    # resuming session 0 out of {0, 1, 2}, it would pick 1 and the next save
    # would overwrite that session. Use one past the highest stored id instead.
    stored_sessions = st.session_state["stored_session"]
    st.session_state["id"] = max(stored_sessions, default=st.session_state["id"]) + 1
def clean_screen():
    """Empty the on-screen transcript and both input buffers.

    Stored sessions and the entity memory are not touched here.
    """
    st.session_state.conversation = []
    for buffer_key in ("input", "input_temp"):
        st.session_state[buffer_key] = ""
def clean_memory():
    """Reset the active entity memory in place.

    Replaces its entity store with an empty dict and empties its buffer.
    Expects ``st.session_state.entity_memory`` to have been created already.
    """
    memory = st.session_state.entity_memory
    memory.store = {}
    memory.buffer.clear()
def save_current_chat():
    """Snapshot the active chat into ``st.session_state["stored_session"]``.

    The snapshot is stored under the current chat id and holds the transcript
    list (``'conversation'``) plus a pickled copy of the entity memory
    (``'conversation_memory'``), so later changes to the live memory do not
    affect it.
    """
    state = st.session_state
    state["stored_session"][state["id"]] = {
        'conversation': state['conversation'],
        'conversation_memory': pickle.dumps(state.entity_memory),
    }
def resume_chat(session_id):
    """Switch the screen and memory to a previously stored chat.

    The active chat is saved first, so switching never loses it.

    Args:
        session_id: key of the chat in ``st.session_state["stored_session"]``.
    """
    save_current_chat()
    snapshot = st.session_state["stored_session"].get(session_id)
    if snapshot is None:
        # Unknown id: keep the (already saved) current chat on screen instead
        # of wiping it and then failing with a KeyError.
        return
    clean_screen()
    clean_memory()
    st.session_state["id"] = session_id
    # Work on a copy so new messages do not silently edit the stored snapshot
    # (whose pickled memory would then be out of sync with its transcript);
    # save_current_chat() refreshes the snapshot when the user switches away.
    st.session_state["conversation"] = list(snapshot["conversation"])
    # These bytes were produced by save_current_chat() in this same browser
    # session, so unpickling them is safe. Never feed external data to pickle.loads.
    st.session_state.entity_memory = pickle.loads(snapshot["conversation_memory"])
    st.session_state["show_new_chat_button"] = True
def conversation_to_html(entries):
    """Build the HTML for a list of chat entries.

    Args:
        entries: iterable of dicts. A ``"user"`` key produces a user bubble and
            a ``"chatbot"`` key produces a bot bubble (user bubble first if an
            entry has both).

    Returns:
        str: the concatenated HTML, or ``""`` when there are no entries.
        Message text is HTML-escaped, so markup typed by the user or emitted by
        the model is shown literally instead of being rendered by the browser.
    """
    from html import escape

    parts = []
    for entry in entries:
        if 'user' in entry:
            user_text = escape(str(entry["user"]))
            parts.append(f'<div style="margin: 10px; padding: 8px; border-radius: 5px; background-color: #8090FF; text-align: left;">🤵 {user_text}</div>')
        if 'chatbot' in entry:
            bot_text = escape(str(entry["chatbot"]))
            parts.append(f'<div style="margin: 10px; padding: 8px; border-radius: 5px; background-color: #D7BB2C; display: flex; align-items: center;">🤖 <pre style="color: white; background-color: #D7BB2C; padding: 8px; border-radius: 5px; max-width: calc(100% - 60px); white-space: pre-wrap; word-wrap: break-word; word-break: break-all;">{bot_text}</pre></div>')
    return "".join(parts)


def show_conv():
    """
    Render the current conversation in html into the conversation_history placeholder.
    """
    conversation_html = conversation_to_html(st.session_state.conversation)
    conversation_history.write(conversation_html, unsafe_allow_html=True)
### Main APP
# Allow the user to clear all stored conversation sessions
# (the button is only offered when something is stored)
if st.session_state.stored_session and st.sidebar.button("Clear-all"):
    st.session_state.stored_session = {}
    clean_screen()
    # The transcript just vanished from the screen; drop the entity memory of it
    # too, otherwise the next reply could still draw on the hidden context.
    if 'entity_memory' in st.session_state:
        clean_memory()
    # Nothing is stored any more, so restart session numbering from 0.
    st.session_state["id"] = 0
if API_KEY:
    # Create an OpenAI instance. OpenAIChat targets the chat-completions
    # endpoint, which does not serve the completion-only "text-davinci-*"
    # models offered in Settings; those go through the completion wrapper.
    llm_class = OpenAI if MODEL.startswith("text-") else OpenAIChat
    llm = llm_class(
        temperature=0,
        openai_api_key=API_KEY,
        model_name=MODEL,
    )
    # Create conversation memory
    if 'entity_memory' not in st.session_state:
        st.session_state.entity_memory = ConversationEntityMemory(llm=llm, k=K)
    else:
        # The memory survives reruns (and may have been restored by
        # resume_chat), so point it at the current key/model and window size
        # rather than the ones it was created with.
        st.session_state.entity_memory.llm = llm
        st.session_state.entity_memory.k = K
    # Create the Conversation Chain
    st.session_state.Conversation = ConversationChain(
        llm=llm,
        prompt=ENTITY_MEMORY_CONVERSATION_TEMPLATE,
        memory=st.session_state.entity_memory,
    )
else:
    st.markdown('''
```
- 1. Enter API Key + Hit enter 🔐
- 2. Ask anything via the text input widget
```
''')
    st.sidebar.warning('API key required to try this app. The API key is not stored in any form.')
    st.sidebar.info("Your API-key is not stored in any form by this app. However, for transparency ensure to delete your API key once used.")
# Get the user input
user_input = get_text()
pending_prompt = st.session_state["input_temp"]
if pending_prompt:
    if API_KEY:
        # Consume the prompt before calling the model. input_temp is otherwise
        # never cleared, so every later rerun (e.g. toggling a sidebar widget)
        # would send it again and append a duplicate exchange, and a failing
        # call would be retried on every rerun.
        st.session_state["input_temp"] = ""
        output = st.session_state.Conversation.run(input=pending_prompt)
        st.session_state.conversation.append({"user": pending_prompt})
        st.session_state.conversation.append({"chatbot": output})
    else:
        # The chain is only (re)built when a key is present; keep the prompt
        # pending so it is answered once the key has been entered.
        st.warning("Enter your API key in the sidebar to get an answer.")
# Offer "New Chat" whenever there is a conversation worth archiving
if st.session_state.conversation:
    st.session_state["show_new_chat_button"] = True
if st.session_state["show_new_chat_button"]:
    st.sidebar.button("New Chat", on_click=new_chat, type='primary')
if "conversation" in st.session_state:
    show_conv()
# Display stored conversation sessions in the sidebar. Use the stored ids
# themselves: resume_chat() looks sessions up by key, and those keys need not
# match their enumerate() position (e.g. after Clear-all or resuming).
for session_id, saved_session in st.session_state.stored_session.items():
    with st.sidebar.expander(label=f"Conversation-Session:{session_id}"):
        st.button(
            "Resume session",
            on_click=resume_chat,
            kwargs={"session_id": session_id},
            type='primary',
            key=f"Conversation-Session:{session_id}",
        )
        # Show the transcript only, not the pickled memory bytes stored alongside it
        st.write(saved_session["conversation"])