Spaces:
Sleeping
Sleeping
deleted unnecessary file
Browse files- convosim.py +0 -99
convosim.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import streamlit as st
|
3 |
-
from streamlit.logger import get_logger
|
4 |
-
from langchain.schema.messages import HumanMessage
|
5 |
-
from utils.mongo_utils import get_db_client
|
6 |
-
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
|
7 |
-
from utils.memory_utils import clear_memory, push_convo2db
|
8 |
-
from utils.chain_utils import get_chain, custom_chain_predict
|
9 |
-
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
|
10 |
-
|
11 |
-
logger = get_logger(__name__)
|
12 |
-
openai_api_key = os.environ['OPENAI_API_KEY']
|
13 |
-
temperature = 0.8
|
14 |
-
# username = "barb-chase" #"ivnban-ctl"
|
15 |
-
|
16 |
-
if "sent_messages" not in st.session_state:
|
17 |
-
st.session_state['sent_messages'] = 0
|
18 |
-
if "total_messages" not in st.session_state:
|
19 |
-
st.session_state['total_messages'] = 0
|
20 |
-
if "issue" not in st.session_state:
|
21 |
-
st.session_state['issue'] = ISSUES[0]
|
22 |
-
if 'previous_source' not in st.session_state:
|
23 |
-
st.session_state['previous_source'] = SOURCES[0]
|
24 |
-
if 'db_client' not in st.session_state:
|
25 |
-
st.session_state["db_client"] = get_db_client()
|
26 |
-
if 'texter_name' not in st.session_state:
|
27 |
-
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
28 |
-
logger.debug(f"texter name is {st.session_state['texter_name']}")
|
29 |
-
|
30 |
-
memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
|
31 |
-
|
32 |
-
with st.sidebar:
|
33 |
-
username = st.text_input("Username", value='Dani', max_chars=30)
|
34 |
-
if 'counselor_name' not in st.session_state:
|
35 |
-
st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
|
36 |
-
# temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
|
37 |
-
issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
|
38 |
-
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
39 |
-
)
|
40 |
-
supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
|
41 |
-
language = st.selectbox("Select a Language", supported_languages, index=0,
|
42 |
-
format_func=lambda x: "English" if x=="en" else "Spanish",
|
43 |
-
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
44 |
-
)
|
45 |
-
|
46 |
-
source = st.selectbox("Select a source Model A", SOURCES, index=0,
|
47 |
-
format_func=source2label,
|
48 |
-
)
|
49 |
-
|
50 |
-
changed_source = any([
|
51 |
-
st.session_state['previous_source'] != source,
|
52 |
-
st.session_state['issue'] != issue,
|
53 |
-
st.session_state['counselor_name'] != username,
|
54 |
-
])
|
55 |
-
if changed_source:
|
56 |
-
st.session_state["counselor_name"] = username
|
57 |
-
st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
|
58 |
-
logger.debug(f"texter name is {st.session_state['texter_name']}")
|
59 |
-
st.session_state['previous_source'] = source
|
60 |
-
st.session_state['issue'] = issue
|
61 |
-
st.session_state['sent_messages'] = 0
|
62 |
-
st.session_state['total_messages'] = 0
|
63 |
-
create_memory_add_initial_message(memories,
|
64 |
-
issue,
|
65 |
-
language,
|
66 |
-
changed_source=changed_source,
|
67 |
-
counselor_name=st.session_state["counselor_name"],
|
68 |
-
texter_name=st.session_state["texter_name"])
|
69 |
-
st.session_state['previous_source'] = source
|
70 |
-
memoryA = st.session_state[list(memories.keys())[0]]
|
71 |
-
# issue only without "." marker for model compatibility
|
72 |
-
llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
|
73 |
-
|
74 |
-
st.title("💬 Simulator")
|
75 |
-
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
76 |
-
for msg in memoryA.buffer_as_messages:
|
77 |
-
role = "user" if type(msg) == HumanMessage else "assistant"
|
78 |
-
st.chat_message(role).write(msg.content)
|
79 |
-
|
80 |
-
if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
|
81 |
-
st.session_state['sent_messages'] += 1
|
82 |
-
st.chat_message("user").write(prompt)
|
83 |
-
if 'convo_id' not in st.session_state:
|
84 |
-
push_convo2db(memories, username, language)
|
85 |
-
responses = custom_chain_predict(llm_chain, prompt, stopper)
|
86 |
-
# responses = llm_chain.predict(input=prompt, stop=stopper)
|
87 |
-
# response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
|
88 |
-
for response in responses:
|
89 |
-
st.chat_message("assistant").write(response)
|
90 |
-
|
91 |
-
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
92 |
-
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
93 |
-
st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
|
94 |
-
elif st.session_state['total_messages'] >= WARN_MSG_COUT:
|
95 |
-
st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
|
96 |
-
|
97 |
-
with st.sidebar:
|
98 |
-
st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
|
99 |
-
st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|