# Source: convosim-ui-dev / pages/comparisor.py
# Last commit by ivnban27-ctl (eda0ce6): "changes to comparisor on new role models GCT and SP"
# File size: 8.73 kB
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
from utils.app_utils import create_memory_add_initial_message, get_random_name
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label
# Logger named after this page's module.
logger = get_logger(__name__)

# Looked up eagerly: a missing OPENAI_API_KEY raises KeyError at page load.
openai_api_key = os.environ['OPENAI_API_KEY']

# Placeholder memory configuration built from the first issue and the first
# two sources. The sidebar callbacks receive this dict; it is rebuilt from the
# user's selections further down the page.
memories = {
    memory_key: {"issue": ISSUES[0], "source": source}
    for memory_key, source in (
        ('memoryA', SOURCES[0]),
        ('memoryB', SOURCES[1]),
        ('commonMemory', SOURCES[0]),
    )
}

# Seed session state, writing each key only when it is not already present.
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()

# Both "previous source" trackers start at the first configured source.
for _source_key in ('previous_sourceA', 'previous_sourceB'):
    if _source_key not in st.session_state:
        st.session_state[_source_key] = SOURCES[0]

# Counselor first, then texter, each with its own random name.
for _name_key in ('counselor_name', 'texter_name'):
    if _name_key not in st.session_state:
        st.session_state[_name_key] = get_random_name()
def delete_last_message(memory):
    """Remove the last two messages from ``memory`` and return the earlier one's text.

    Args:
        memory: Object exposing ``chat_memory.messages``, a sequence whose
            items have a ``content`` attribute.

    Returns:
        The ``content`` of the second-to-last message (read before removal).

    Raises:
        IndexError: If fewer than two messages are stored; nothing is removed
            in that case.
    """
    history = memory.chat_memory
    transcript = history.messages
    # Read the text first so a too-short history fails before any mutation.
    prompt_text = transcript[-2].content
    history.messages = transcript[:-2]
    return prompt_text
def replace_last_message(memory, new_message):
    """Drop the final message in ``memory`` and append ``new_message`` as an AI message.

    Args:
        memory: Object exposing ``chat_memory.messages`` and
            ``chat_memory.add_ai_message``.
        new_message: Text handed to ``add_ai_message``.
    """
    history = memory.chat_memory
    # Everything except the last entry survives; the replacement goes on the end.
    history.messages = history.messages[:-1]
    history.add_ai_message(new_message)
def regenerateA():
    """Discard model A's last exchange, ask model A again, and show the result.

    Reads the module-level ``memoryA``, ``llm_chainA``, ``stopperA`` and
    ``col1``.

    Returns:
        The new response produced by ``llm_chainA.predict``.
    """
    prompt_text = delete_last_message(memoryA)
    # presumably the chain records the new exchange back into memoryA,
    # since it was built from that memory; verify in get_chain.
    fresh_reply = llm_chainA.predict(input=prompt_text, stop=stopperA)
    for role, text in (("user", prompt_text), ("assistant", fresh_reply)):
        col1.chat_message(role).write(text)
    return fresh_reply
def regenerateB():
    """Discard model B's last exchange, ask model B again, and show the result.

    Reads the module-level ``memoryB``, ``llm_chainB``, ``stopperB`` and
    ``col2``.

    Returns:
        The new response produced by ``llm_chainB.predict``.
    """
    prompt_text = delete_last_message(memoryB)
    # presumably the chain records the new exchange back into memoryB,
    # since it was built from that memory; verify in get_chain.
    fresh_reply = llm_chainB.predict(input=prompt_text, stop=stopperB)
    for role, text in (("user", prompt_text), ("assistant", fresh_reply)):
        col2.chat_message(role).write(text)
    return fresh_reply
def replaceA():
    """Handle "B is better": copy model B's last reply into model A and log the win.

    Takes the last two messages of ``memoryB`` (prompt and reply), overwrites
    the last message of ``memoryA`` with that reply, saves the pair to the
    shared ``commonMemory`` in session state, and stores a battle result with
    ``winner='model_two'``.

    Does nothing when ``memoryB`` holds fewer than two messages, i.e. before
    any prompt/reply pair exists.
    """
    messages_b = memoryB.chat_memory.messages
    # Without a full exchange, indexing [-2] would raise IndexError inside the
    # button callback. Skip quietly, as bothGood does for the same state.
    if len(messages_b) < 2:
        return

    last_prompt = messages_b[-2].content
    new_message = messages_b[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_two',
    )
def replaceB():
    """Handle "A is better": copy model A's last reply into model B and log the win.

    Takes the last two messages of ``memoryA`` (prompt and reply), overwrites
    the last message of ``memoryB`` with that reply, saves the pair to the
    shared ``commonMemory`` in session state, and stores a battle result with
    ``winner='model_one'``.

    Does nothing when ``memoryA`` holds fewer than two messages, i.e. before
    any prompt/reply pair exists.
    """
    messages_a = memoryA.chat_memory.messages
    # Without a full exchange, indexing [-2] would raise IndexError inside the
    # button callback. Skip quietly, as bothGood does for the same state.
    if len(messages_a) < 2:
        return

    last_prompt = messages_a[-2].content
    new_message = messages_a[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_one',
    )
def regenerateBoth():
    """Handle "Both are bad": log the verdict, regenerate both replies, log the new pair.

    Stores a battle result with ``winner='both_bad'``, regenerates model A and
    then model B, and records the new comparison together with UTC timestamps
    taken before and after the regeneration.

    Does nothing when either model's memory holds fewer than two messages,
    i.e. before any prompt/reply pair exists to regenerate.
    """
    # Before the first exchange, delete_last_message would raise IndexError
    # inside the button callback. Skip quietly, as bothGood does.
    if len(memoryA.chat_memory.messages) < 2 or len(memoryB.chat_memory.messages) < 2:
        return

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='both_bad',
    )
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    # NOTE(review): ``prompt`` is the module-level chat-input value, not the
    # prompt that regenerateA/regenerateB re-read from memory; confirm the two
    # always match when this callback fires.
    new_comparison(
        st.session_state['db_client'], prompt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB,
    )
def bothGood():
    """Handle "Tie": keep one of the two latest exchanges at random and log a tie.

    When ``memoryA`` holds exactly one message there is nothing to record and
    the function returns immediately. Otherwise the last prompt/reply pair of
    a randomly chosen memory (A or B) is saved to the shared ``commonMemory``
    and a battle result with ``winner='tie'`` is stored.
    """
    if len(memoryA.buffer_as_messages) == 1:
        return

    chosen = random.choice([memoryA, memoryB])
    transcript = chosen.chat_memory.messages
    last_prompt = transcript[-2].content
    last_response = transcript[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='tie',
    )
# Sidebar: user identity, scenario selection, per-model settings and the
# vote / reset buttons.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)

    # Changing the issue or the language clears the conversation memories.
    # NOTE(review): these callbacks pass language "English" while the
    # "Clear History" button below passes the selected code; confirm which
    # form clear_memory expects.
    reset_kwargs = {"memories": memories, "username": username, "language": "English"}

    issue = st.selectbox(
        "Select an Issue", ISSUES, index=0,
        on_change=clear_memory, kwargs=reset_kwargs,
    )

    # Spanish is only offered for the "Anxiety" issue.
    if issue == "Anxiety":
        supported_languages = ['en', "es"]
    else:
        supported_languages = ['en']

    def _language_label(code):
        """Map a language code to the label shown in the selector."""
        return "English" if code == "en" else "Spanish"

    language = st.selectbox(
        "Select a Language", supported_languages, index=0,
        format_func=_language_label,
        on_change=clear_memory, kwargs=reset_kwargs,
    )

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox(
            "Select a source Model A", SOURCES, index=0, format_func=source2label
        )

    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox(
            "Select a source Model B", SOURCES, index=0, format_func=source2label
        )

    # Vote buttons in a 2x2 grid. "A is better" overwrites B's reply
    # (replaceB) and "B is better" overwrites A's reply (replaceA).
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # Per-model regenerate buttons are currently disabled:
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)

    clear = st.button(
        "Clear History", on_click=clear_memory,
        kwargs={"memories": memories, "username": username, "language": language},
    )
# Rebuild the memory configuration from the current sidebar selections.
# The shared history always uses the first source.
memories = {
    memory_key: {"issue": issue, "source": source}
    for memory_key, source in (
        ('memoryA', sourceA),
        ('memoryB', sourceB),
        ('commonMemory', SOURCES[0]),
    )
}

# A change of either model's source gives the conversation new participants.
changed_source = (
    st.session_state['previous_sourceA'] != sourceA
    or st.session_state['previous_sourceB'] != sourceB
)
if changed_source:
    st.session_state["counselor_name"] = get_random_name()
    st.session_state["texter_name"] = get_random_name()

create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)

# The first two keys of ``memories`` name the per-model memories held in
# session state.
memory_key_a, memory_key_b = list(memories)[:2]
memoryA = st.session_state[memory_key_a]
memoryB = st.session_state[memory_key_b]

# One chain (plus its stop sequence) per model, each bound to its own memory
# and temperature.
llm_chainA, stopperA = get_chain(
    issue, language, sourceA, memoryA, temperatureA,
    texter_name=st.session_state["texter_name"],
)
llm_chainB, stopperB = get_chain(
    issue, language, sourceB, memoryB, temperatureB,
    texter_name=st.session_state["texter_name"],
)
# Shared history: replay every message kept in the common memory.
st.title(f"💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # Exact HumanMessage instances render as the user; anything else as the
    # assistant.
    if type(msg) == HumanMessage:
        role = "user"
    else:
        role = "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side panes for the two simulators being compared.
col1, col2 = st.columns(2)
col1.title(f"💬 Simulator A")
col2.title(f"💬 Simulator B")
def reset_buttons():
    """Set the four vote-button flags (``beta``, ``betb``, ``same``, ``bbad``) to False.

    The flags are module-level names assigned from the sidebar buttons, so
    this only affects code that reads them after the call.
    """
    # Assigning to a loop variable (``for but in buttons: but = False``) only
    # rebinds that local name and leaves the flags untouched; they have to be
    # rebound at module level.
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Tell whether the chat input should be disabled.

    Returns:
        ``True`` when none of the vote-button flags (``beta``, ``betb``,
        ``same``, ``bbad``) is truthy, ``False`` as soon as one of them is.
    """
    vote_flags = (beta, betb, same, bbad)
    return not any(vote_flags)
# Main chat turn: send the same prompt to both models, show both replies and
# log the comparison.
prompt = st.chat_input(disabled=disable_chat())
if prompt:
    # Register the conversation the first time a prompt is sent.
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    promt_ts = dt.datetime.now(tz=dt.timezone.utc)
    for column in (col1, col2):
        column.chat_message("user").write(prompt)

    # Model A is queried before model B; the completion timestamp covers both.
    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)

    new_comparison(
        st.session_state['db_client'], promt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB,
    )

    for column, reply in ((col1, responseA), (col2, responseB)):
        column.chat_message("assistant").write(reply)
    reset_buttons()