Spaces:
Sleeping
Sleeping
Commit
·
cfe8e1a
1
Parent(s):
b74c038
training adherence scoring features
Browse files- app_config.py +1 -0
- models/ta_models/config.py +149 -1
- models/ta_models/ta_filter_utils.py +150 -0
- models/ta_models/ta_prompt_utils.py +128 -0
- models/ta_models/ta_utils.py +113 -0
- pages/convosim.py +6 -2
- pages/training_adherence.py +74 -0
- utils/app_utils.py +1 -0
- utils/mongo_utils.py +25 -10
app_config.py
CHANGED
@@ -57,6 +57,7 @@ DB_BATTLES = 'battles'
|
|
57 |
DB_ERRORS = 'completion_errors'
|
58 |
DB_CPC = "cpc_comparison"
|
59 |
DB_BP = "bad_practices_comparison"
|
|
|
60 |
|
61 |
MAX_MSG_COUNT = 60
|
62 |
WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
|
|
|
57 |
DB_ERRORS = 'completion_errors'
|
58 |
DB_CPC = "cpc_comparison"
|
59 |
DB_BP = "bad_practices_comparison"
|
60 |
+
DB_TA = "convo_scoring_comparison"
|
61 |
|
62 |
MAX_MSG_COUNT = 60
|
63 |
WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
|
models/ta_models/config.py
CHANGED
@@ -23,4 +23,152 @@ BP_THRESHOLD = 0.7
|
|
23 |
BP_LAB2STR = {
|
24 |
"is_advice": "Advice",
|
25 |
"is_personal_info": "Personal Info Sharing",
|
26 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
BP_LAB2STR = {
|
24 |
"is_advice": "Advice",
|
25 |
"is_personal_info": "Personal Info Sharing",
|
26 |
+
}
|
27 |
+
|
28 |
+
QUESTION2PHASE = {
|
29 |
+
"question_1": ["0_ActiveEngagement","1_Explore"],
|
30 |
+
"question_4": ["1_Explore"],
|
31 |
+
"question_5": ["0_ActiveEngagement", "1_Explore"],
|
32 |
+
# "question_7": ["1_Explore"],
|
33 |
+
# "question_9": ["4_SP&NS"],
|
34 |
+
"question_10": ["4_SP&NS"],
|
35 |
+
# "question_11": ["4_SP&NS"],
|
36 |
+
"question_14": ["6_WrappingUp"],
|
37 |
+
# "question_15": ["ALL"],
|
38 |
+
"question_19": ["ALL"],
|
39 |
+
# "question_21": ["ALL"],
|
40 |
+
# "question_22": ["ALL"],
|
41 |
+
"question_23": ["2_IRA", "3_SafetyAssessment"],
|
42 |
+
}
|
43 |
+
|
44 |
+
QUESTION2FILTERARGS = {
|
45 |
+
"question_1": {
|
46 |
+
"phases": QUESTION2PHASE["question_1"],
|
47 |
+
"pre_n": 2,
|
48 |
+
"post_n": 8,
|
49 |
+
"ignore": ["7_Other"],
|
50 |
+
},
|
51 |
+
"question_4": {
|
52 |
+
"phases": QUESTION2PHASE["question_4"],
|
53 |
+
"pre_n": 5,
|
54 |
+
"post_n": 5,
|
55 |
+
"ignore": ["7_Other"],
|
56 |
+
},
|
57 |
+
"question_5": {
|
58 |
+
"phases": QUESTION2PHASE["question_5"],
|
59 |
+
"pre_n": 5,
|
60 |
+
"post_n": 5,
|
61 |
+
"ignore": ["7_Other"],
|
62 |
+
},
|
63 |
+
# "question_7": {
|
64 |
+
# "phases": QUESTION2PHASE["question_7"],
|
65 |
+
# "pre_n": 5,
|
66 |
+
# "post_n": 15,
|
67 |
+
# "ignore": ["7_Other"],
|
68 |
+
# },
|
69 |
+
# "question_9": {
|
70 |
+
# "phases": QUESTION2PHASE["question_9"],
|
71 |
+
# "pre_n": 5,
|
72 |
+
# "post_n": 5,
|
73 |
+
# "ignore": ["7_Other"],
|
74 |
+
# },
|
75 |
+
"question_10": {
|
76 |
+
"phases": QUESTION2PHASE["question_10"],
|
77 |
+
"pre_n": 5,
|
78 |
+
"post_n": 5,
|
79 |
+
"ignore": ["7_Other"],
|
80 |
+
},
|
81 |
+
# "question_11": {
|
82 |
+
# "phases": QUESTION2PHASE["question_11"],
|
83 |
+
# "pre_n": 5,
|
84 |
+
# "post_n": 5,
|
85 |
+
# "ignore": ["7_Other"],
|
86 |
+
# },
|
87 |
+
"question_14": {
|
88 |
+
"phases": QUESTION2PHASE["question_14"],
|
89 |
+
"pre_n": 10,
|
90 |
+
"post_n": 0,
|
91 |
+
"ignore": ["7_Other"],
|
92 |
+
},
|
93 |
+
# "question_15": {
|
94 |
+
# "phases": QUESTION2PHASE["question_15"],
|
95 |
+
# "pre_n": 5,
|
96 |
+
# "post_n": 5,
|
97 |
+
# "ignore": ["7_Other"],
|
98 |
+
# },
|
99 |
+
"question_19": {
|
100 |
+
"phases": QUESTION2PHASE["question_19"],
|
101 |
+
"pre_n": 5,
|
102 |
+
"post_n": 5,
|
103 |
+
"ignore": ["7_Other"],
|
104 |
+
},
|
105 |
+
# "question_21": {
|
106 |
+
# "phases": QUESTION2PHASE["question_21"],
|
107 |
+
# "pre_n": 5,
|
108 |
+
# "post_n": 5,
|
109 |
+
# "ignore": ["7_Other"],
|
110 |
+
# },
|
111 |
+
# "question_22": {
|
112 |
+
# "phases": QUESTION2PHASE["question_22"],
|
113 |
+
# "pre_n": 5,
|
114 |
+
# "post_n": 5,
|
115 |
+
# "ignore": ["7_Other"],
|
116 |
+
# },
|
117 |
+
"question_23": {
|
118 |
+
"phases": QUESTION2PHASE["question_23"],
|
119 |
+
"pre_n": 5,
|
120 |
+
"post_n": 5,
|
121 |
+
"ignore": ["7_Other"],
|
122 |
+
},
|
123 |
+
}
|
124 |
+
|
125 |
+
|
126 |
+
START_INST = "<|user|>"
|
127 |
+
END_INST = "<|end|>\n<|assistant|>"
|
128 |
+
|
129 |
+
NAME2QUESTION = {
|
130 |
+
"question_1": "Did the helper introduce themself in the opening message? Answer only Yes or No.",
|
131 |
+
"question_4": "Did the helper actively listened to the texter's crisis? Answer only Yes or No.",
|
132 |
+
"question_5": "Did the helper reflect on the main issue that led the texter reach out? Answer only Yes or No.",
|
133 |
+
# "question_7": "Did the helper collaborated with the texter to identify the goal of the conversation? Answer only Yes or No.",
|
134 |
+
# "question_9": "Did the helper collaborated with the texter to create next steps? Answer only Yes or No.",
|
135 |
+
"question_10": "Did the helper explored texter's existing coping skills? Answer only Yes or No.",
|
136 |
+
# "question_11": "Did the helper explored texter’s social support? Answer only Yes or No.",
|
137 |
+
"question_14": "Did helper reflected the texter’s plan, reiterate coping skills, and end in a supportive way? Answer only Yes or No.",
|
138 |
+
# "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
|
139 |
+
"question_19": "Did the helper consistently reflected empathy through the conversation? Answer only Yes or No.",
|
140 |
+
# "question_21": "Did the helper shared personal information? Answer only Yes or No.",
|
141 |
+
# "question_22": "Did the helper gave advice? Answer only Yes or No.",
|
142 |
+
"question_23": "Did the helper explicitely initiated imminent risk assessment? Answer only Yes or No.",
|
143 |
+
}
|
144 |
+
|
145 |
+
NAME2PROMPT = {
|
146 |
+
k: "--------Conversation:\n{convo}\n{start_inst}" + v + "\n{end_inst}"
|
147 |
+
for k, v in NAME2QUESTION.items()
|
148 |
+
}
|
149 |
+
|
150 |
+
NAME2PROMPT_EXPL = {
|
151 |
+
k: v.split("Answer only Yes or No.")[0] + "Answer Yes or No, and give an explanation in a new line.\n{end_inst}"
|
152 |
+
for k, v in NAME2PROMPT.items()
|
153 |
+
}
|
154 |
+
|
155 |
+
QUESTIONDEFAULTS = {
|
156 |
+
"question_1": {True: "No, There was no evidence of Active Engagement", False: "No"},
|
157 |
+
"question_4": {True: "No, There was no evidence of Exploration Phase", False: "No"},
|
158 |
+
"question_5": {True: "No, There was no evidence of Exploration Phase", False: "No"},
|
159 |
+
# "question_7": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
160 |
+
# "question_9": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
161 |
+
"question_10": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
162 |
+
# "question_11": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
163 |
+
"question_14": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
164 |
+
# "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
|
165 |
+
"question_19": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
|
166 |
+
# "question_21": "Did the helper shared personal information? Answer only Yes or No.",
|
167 |
+
# "question_22": "Did the helper gave advice? Answer only Yes or No.",
|
168 |
+
"question_23": {True: "No, There was no evidence of Imminent Risk Assessment", False: "No"},
|
169 |
+
}
|
170 |
+
|
171 |
+
TEXTER_PREFIX = "texter"
|
172 |
+
HELPER_PREFIX = "helper"
|
173 |
+
|
174 |
+
TA_OPTIONS = ["N/A", "No", "Yes"]
|
models/ta_models/ta_filter_utils.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import chain
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
possible_movements = [-1, 1]
|
8 |
+
|
9 |
+
|
10 |
+
def dfs(indexes: List[int], x0: int, i: int, cur_island: List[int], d=2):
|
11 |
+
"""Deep First Search Implementation for 2D movement.
|
12 |
+
To consider an Island only move one step left or right
|
13 |
+
See possible movements
|
14 |
+
|
15 |
+
Args:
|
16 |
+
indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
|
17 |
+
x0 (int): Initial island anchor
|
18 |
+
i (int): Current index to test against anchor
|
19 |
+
cur_island (List[int]): Current Island from anchor
|
20 |
+
d (int, optional): Bounding distance to consider an island. Defaults to 2. For example
|
21 |
+
the list [20,21,23,50,51] has two islands with d=2: (20,21,23), and (50,51) but it has
|
22 |
+
three islands with d=: (20,21), (23), and (50,51)
|
23 |
+
"""
|
24 |
+
rows = len(indexes)
|
25 |
+
if i < 0 or i >= rows:
|
26 |
+
return
|
27 |
+
if indexes[i] in cur_island:
|
28 |
+
return
|
29 |
+
if abs(indexes[x0] - indexes[i]) > d:
|
30 |
+
return
|
31 |
+
# computing coordinates with x0 as base
|
32 |
+
cur_island.append(indexes[i])
|
33 |
+
|
34 |
+
# repeat dfs for neighbors
|
35 |
+
for movement in possible_movements:
|
36 |
+
dfs(indexes, i, i + movement, cur_island, d)
|
37 |
+
|
38 |
+
|
39 |
+
def get_list_islands(indexes: List[int], **kwargs) -> List[List[int]]:
|
40 |
+
"""Wrapper over DFS method to obtain islands from list of indexes of positive examples
|
41 |
+
|
42 |
+
Args:
|
43 |
+
indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
List[List[int]]: List of islands (each being a list)
|
47 |
+
"""
|
48 |
+
islands = []
|
49 |
+
rows = len(indexes)
|
50 |
+
if rows == 0:
|
51 |
+
return islands
|
52 |
+
|
53 |
+
for i, valuei in enumerate(indexes):
|
54 |
+
# If already visited index in another dfs continue
|
55 |
+
if valuei in list(chain.from_iterable(islands)):
|
56 |
+
continue
|
57 |
+
# to hold coordinates of new island
|
58 |
+
cur_island = []
|
59 |
+
dfs(indexes, i, i, cur_island, **kwargs)
|
60 |
+
|
61 |
+
islands.append(cur_island)
|
62 |
+
|
63 |
+
return islands
|
64 |
+
|
65 |
+
|
66 |
+
def get_phases_islands_minmax(
|
67 |
+
convo: pd.DataFrame,
|
68 |
+
phases: List[str],
|
69 |
+
column: str = "convo_part",
|
70 |
+
ignore: List[str] = [],
|
71 |
+
**kwargs,
|
72 |
+
) -> List[Tuple[int]]:
|
73 |
+
"""Given a conversation with predicted Phases (or Parts), get minimum and maximum index of calculated islands.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
convo (pd.DataFrame): Conversation with predicted phases stored in `column`
|
77 |
+
phases (List[str]): Phases to filter in
|
78 |
+
column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
|
79 |
+
ignore (List[str], optional): Ignore phases list. Defaults to [].
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
List[Tuple[int]]: Minimum and Maximum values of calulated islands. i.e [(20,30), (40,60)]
|
83 |
+
"""
|
84 |
+
|
85 |
+
reset = convo.query(f"{column}=={column} and {column} not in @ignore").reset_index()
|
86 |
+
sub_ = reset.query(f"{column} in @phases").copy()
|
87 |
+
indexes = sub_.index.tolist()
|
88 |
+
islands = get_list_islands(indexes, **kwargs)
|
89 |
+
if len(islands) > 1:
|
90 |
+
# If there is more than one island we want to make sure to root out comparable small islands
|
91 |
+
# I.e. if there is an island with 10 messages, and island of 1 messages is not useful in that context.
|
92 |
+
max_len = np.max([len(x) for x in islands])
|
93 |
+
len_cut = 3 if max_len > 9 else 2 if max_len > 3 else 1
|
94 |
+
islands = [x for x in islands if len(x) > len_cut]
|
95 |
+
|
96 |
+
islands = [reset.iloc[x] for x in islands]
|
97 |
+
minmax_islands = [(x["index"].min(), x["index"].max()) for x in islands]
|
98 |
+
|
99 |
+
return minmax_islands
|
100 |
+
|
101 |
+
|
102 |
+
def filter_convo(
|
103 |
+
convo: pd.DataFrame,
|
104 |
+
phases: List[str],
|
105 |
+
column: str = "convo_part",
|
106 |
+
strategy: str = "islands",
|
107 |
+
pre_n: int = 5,
|
108 |
+
post_n: int = 5,
|
109 |
+
return_all_on_empty: bool = False,
|
110 |
+
**kwargs,
|
111 |
+
) -> pd.DataFrame:
|
112 |
+
"""Filter convo to include only specified phases. Take into account that sometimes predicted phases
|
113 |
+
can be messy. i.e. a prediciton of explore, explore, explore, safety_planning, explore; should return all
|
114 |
+
these messages as explore (probably safety_planning message has a low probability here.)
|
115 |
+
|
116 |
+
Args:
|
117 |
+
convo (pd.DataFrame): Conversation with predicted phases stored in `column`
|
118 |
+
phases (List[str]): Phases to filter in
|
119 |
+
column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
|
120 |
+
strategy (str, optional): Strategy to use, can be minmax or islands. Defaults to "islands".
|
121 |
+
pre_n (int, optional): How many messages pre-phase to include. Defaults to 5.
|
122 |
+
post_n (int, optional): How many messages post-phase to include. Defaults to 5.
|
123 |
+
return_all_on_empty (bool, optional): Whether to return all messages when specified phases is not found. Defaults to False.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
pd.DataFrame: Filtered messages from the convo
|
127 |
+
"""
|
128 |
+
if phases == ["ALL"]:
|
129 |
+
minidx = convo.index.min()
|
130 |
+
maxidx = convo.index.max()
|
131 |
+
minmax = [(minidx, maxidx)]
|
132 |
+
elif strategy == "minmax":
|
133 |
+
minidx = convo.query(f"{column} in @phases").index.min()
|
134 |
+
maxidx = convo.query(f"{column} in @phases").index.max() + 1
|
135 |
+
minmax = [(minidx, maxidx)]
|
136 |
+
elif strategy == "islands":
|
137 |
+
minmax = get_phases_islands_minmax(convo, phases, column, **kwargs)
|
138 |
+
parts = []
|
139 |
+
for minidx, maxidx in minmax:
|
140 |
+
minidx = max(convo.index.min(), minidx - pre_n)
|
141 |
+
maxidx = min(convo.index.max(), maxidx + post_n)
|
142 |
+
parts.append(convo.loc[minidx:maxidx])
|
143 |
+
if len(parts) == 0:
|
144 |
+
if return_all_on_empty:
|
145 |
+
return convo
|
146 |
+
else:
|
147 |
+
return pd.DataFrame(columns=convo.columns)
|
148 |
+
filtered = pd.concat(parts)
|
149 |
+
filtered = filtered[~filtered.index.duplicated(keep="first")]
|
150 |
+
return filtered
|
models/ta_models/ta_prompt_utils.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
|
6 |
+
|
7 |
+
# Utils to filter convo according to a phase
|
8 |
+
from .ta_filter_utils import filter_convo
|
9 |
+
|
10 |
+
|
11 |
+
def join_messages(
|
12 |
+
grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
|
13 |
+
) -> str:
|
14 |
+
"""join messages from dataframe using texter an helper prefixes
|
15 |
+
|
16 |
+
Args:
|
17 |
+
grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
|
18 |
+
Must have the following columns:
|
19 |
+
- actor_role
|
20 |
+
- message
|
21 |
+
|
22 |
+
texter_prefix (str, optional): prefix to use as the texter. Defaults to "texter".
|
23 |
+
helper_prefix (str, optional): prefix to use as the counselor (helper). Defaults to "helper".
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
str: joined messages string separated by prefixes
|
27 |
+
"""
|
28 |
+
|
29 |
+
if "actor_role" not in grp:
|
30 |
+
raise Exception("Column 'actor_role' not in DataFrame")
|
31 |
+
if "message" not in grp:
|
32 |
+
raise Exception("Column 'message' not in DataFrame")
|
33 |
+
|
34 |
+
roles = grp.actor_role.replace(
|
35 |
+
{"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
|
36 |
+
)
|
37 |
+
messages = roles.str.strip() + ": " + grp.message.str.strip()
|
38 |
+
return "\n".join(messages)
|
39 |
+
|
40 |
+
|
41 |
+
def _get_context(grp: pd.DataFrame, **kwargs) -> str:
|
42 |
+
"""Get context as a str taking into account message to delete, context marker
|
43 |
+
and the type of question to use. This allows for better truncation later
|
44 |
+
|
45 |
+
Args:
|
46 |
+
grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
|
47 |
+
Must have the following columns:
|
48 |
+
- actor_role
|
49 |
+
- message
|
50 |
+
- `column`
|
51 |
+
column (str): column name in which the marker of the problem is
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
pd.DataFrame: joined messages string separated by prefixes
|
55 |
+
"""
|
56 |
+
|
57 |
+
if "actor_role" not in grp:
|
58 |
+
raise Exception("Column 'actor_role' not in DataFrame")
|
59 |
+
if "message" not in grp:
|
60 |
+
raise Exception("Column 'message' not in DataFrame")
|
61 |
+
|
62 |
+
join_args = list(inspect.signature(join_messages).parameters)
|
63 |
+
join_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in join_args}
|
64 |
+
|
65 |
+
## DEPRECATED
|
66 |
+
# context_args = list(inspect.signature(get_context_on_marker).parameters)
|
67 |
+
# context_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in context_args}
|
68 |
+
|
69 |
+
return join_messages(grp, **join_kwargs)
|
70 |
+
|
71 |
+
|
72 |
+
def load_context(
|
73 |
+
messages: pd.DataFrame,
|
74 |
+
question: str,
|
75 |
+
message_col: str,
|
76 |
+
col_type: str,
|
77 |
+
inference: bool = False,
|
78 |
+
**kwargs,
|
79 |
+
) -> pd.DataFrame:
|
80 |
+
"""Load and filter conversation from messages given a question (with configured parameters of what phase that question is answered)
|
81 |
+
|
82 |
+
Args:
|
83 |
+
messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role, `message_col` and phase prediction
|
84 |
+
question (str): Question to get context to
|
85 |
+
message_col (str): Column where messages are
|
86 |
+
col_type (str): type of message_col, can be "individual" or "joined"
|
87 |
+
base_dir (str, optional): Base directory to find model base args. Defaults to "../../".
|
88 |
+
|
89 |
+
Raises:
|
90 |
+
Exception: If question is not supported
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
pd.DataFrame: filtered messages according to question configuration
|
94 |
+
"""
|
95 |
+
|
96 |
+
if question not in QUESTION2FILTERARGS:
|
97 |
+
raise Exception(f"Question {question} not supported")
|
98 |
+
|
99 |
+
texter_prefix = TEXTER_PREFIX
|
100 |
+
helper_prefix = HELPER_PREFIX
|
101 |
+
context_data = messages.copy()
|
102 |
+
|
103 |
+
def convo_cpc_get_context(grp, **kwargs):
|
104 |
+
"""Filter convo according to Convo Phase Classifier (CPC) predictions"""
|
105 |
+
context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
|
106 |
+
return _get_context(context_, **kwargs)
|
107 |
+
|
108 |
+
if col_type == "individual":
|
109 |
+
if "actor_role" in context_data:
|
110 |
+
context_data.dropna(subset=["actor_role"], inplace=True)
|
111 |
+
if "delete_message" in context_data:
|
112 |
+
context_data.delete_message.replace({1: True}, inplace=True)
|
113 |
+
context_data.delete_message.fillna(False, inplace=True)
|
114 |
+
|
115 |
+
context_data = (
|
116 |
+
context_data.groupby("conversation_id")
|
117 |
+
.apply(
|
118 |
+
convo_cpc_get_context,
|
119 |
+
helper_prefix=helper_prefix,
|
120 |
+
texter_prefix=texter_prefix,
|
121 |
+
)
|
122 |
+
.rename("q_context")
|
123 |
+
)
|
124 |
+
elif col_type == "joined":
|
125 |
+
context_data = context_data.groupby("conversation_id")[[message_col]].max()
|
126 |
+
context_data.rename(columns={message_col: "q_context"}, inplace=True)
|
127 |
+
|
128 |
+
return context_data
|
models/ta_models/ta_utils.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import requests
|
4 |
+
import string
|
5 |
+
import streamlit as st
|
6 |
+
from streamlit.logger import get_logger
|
7 |
+
from app_config import ENDPOINT_NAMES
|
8 |
+
from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
|
9 |
+
import pandas as pd
|
10 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
11 |
+
from models.ta_models.ta_prompt_utils import load_context
|
12 |
+
from utils.mongo_utils import new_convo_scoring_comparison
|
13 |
+
|
14 |
+
logger = get_logger(__name__)
|
15 |
+
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
|
16 |
+
HEADERS = {
|
17 |
+
"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
|
18 |
+
"Content-Type": "application/json",
|
19 |
+
}
|
20 |
+
|
21 |
+
def memory2df(memory, conversation_id="convo1234"):
|
22 |
+
df = []
|
23 |
+
for i, msg in enumerate(memory.buffer_as_messages):
|
24 |
+
actor_role = "texter" if type(msg) == AIMessage else "helper" if type(msg) == HumanMessage else None
|
25 |
+
if actor_role:
|
26 |
+
convo_part = msg.response_metadata.get("phase",None)
|
27 |
+
row = {"conversation_id":conversation_id, "message_number":i+1, "actor_role":actor_role, "message":msg.content, "convo_part":convo_part}
|
28 |
+
df.append(row)
|
29 |
+
|
30 |
+
return pd.DataFrame(df)
|
31 |
+
|
32 |
+
def get_default(question, make_explanation=False):
|
33 |
+
return QUESTIONDEFAULTS[question][make_explanation]
|
34 |
+
|
35 |
+
def get_context(memory, question, make_explanation=False, **kwargs):
|
36 |
+
df = memory2df(memory, **kwargs)
|
37 |
+
contexti = load_context(df, question, "messages", "individual").iloc[0]
|
38 |
+
if contexti == "":
|
39 |
+
return ""
|
40 |
+
|
41 |
+
if make_explanation:
|
42 |
+
return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
|
43 |
+
else:
|
44 |
+
return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
|
45 |
+
|
46 |
+
def post_process_response(full_response, delimiter="\n\n", n=2):
|
47 |
+
parts = full_response.split(delimiter)[:n]
|
48 |
+
response = extract_response(parts[0])
|
49 |
+
logger.debug(f"Response extracted is {response}")
|
50 |
+
if len(parts) > 1:
|
51 |
+
if len(parts[0]) < len(parts[1]):
|
52 |
+
full_response = parts[1]
|
53 |
+
else: full_response = parts[0]
|
54 |
+
else:
|
55 |
+
full_response = parts[0]
|
56 |
+
explanation = full_response.lstrip(response).lstrip(string.punctuation)
|
57 |
+
explanation = explanation.strip()
|
58 |
+
logger.debug(f"Explanation extracted is {explanation}")
|
59 |
+
return response, explanation
|
60 |
+
|
61 |
+
def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
|
62 |
+
full_convo = memory.load_memory_variables({})[memory.memory_key]
|
63 |
+
PROMPT = get_context(memory, question, make_explanation=make_explanation, **kwargs)
|
64 |
+
logger.debug(f"Raw TA prompt is {PROMPT}")
|
65 |
+
if PROMPT == "":
|
66 |
+
full_response = get_default(question, make_explanation)
|
67 |
+
# response, explanation = post_process_response(full_response)
|
68 |
+
return full_convo, PROMPT, full_response
|
69 |
+
|
70 |
+
max_tokens = 128 if make_explanation else 3
|
71 |
+
body_request = {
|
72 |
+
"prompt": PROMPT,
|
73 |
+
"temperature": 0,
|
74 |
+
"max_tokens": max_tokens,
|
75 |
+
}
|
76 |
+
|
77 |
+
try:
|
78 |
+
# Send request to Serving
|
79 |
+
response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
|
80 |
+
if response.status_code == 200:
|
81 |
+
response = response.json()
|
82 |
+
full_response = response[0]['choices'][0]['text']
|
83 |
+
logger.debug(f"Raw TA response is {full_response}")
|
84 |
+
# response, explanation = post_process_response(full_response)
|
85 |
+
return full_convo, PROMPT, full_response
|
86 |
+
except:
|
87 |
+
pass
|
88 |
+
|
89 |
+
def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
|
90 |
+
"""Extract Response from generated answer
|
91 |
+
Extract only search strings
|
92 |
+
|
93 |
+
Args:
|
94 |
+
x (str): prediction
|
95 |
+
default (str, optional): default in case no response founds. Defaults to "N/A".
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
str: _description_
|
99 |
+
"""
|
100 |
+
|
101 |
+
try:
|
102 |
+
return re.findall("|".join(TA_OPTIONS), x)[0]
|
103 |
+
except Exception:
|
104 |
+
return default
|
105 |
+
|
106 |
+
def ta_push_convo_comparison(ytrue, ypred):
|
107 |
+
new_convo_scoring_comparison(**{
|
108 |
+
"client": st.session_state['db_client'],
|
109 |
+
"convo_id": st.session_state['convo_id'],
|
110 |
+
"context": st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
|
111 |
+
"ytrue": ytrue,
|
112 |
+
"ypred": ypred,
|
113 |
+
})
|
pages/convosim.py
CHANGED
@@ -49,7 +49,7 @@ with st.sidebar:
|
|
49 |
if 'counselor_name' not in st.session_state:
|
50 |
st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
|
51 |
# temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
|
52 |
-
issue = st.selectbox("Select a Scenario", ISSUES, index=
|
53 |
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
54 |
)
|
55 |
supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
|
@@ -135,8 +135,9 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
|
|
135 |
if any([x['score'] for x in st.session_state['bp_prediction']]):
|
136 |
for bp in st.session_state['bp_prediction']:
|
137 |
if bp["score"]:
|
138 |
-
st.
|
139 |
st.session_state.changed_bp = True
|
|
|
140 |
else:
|
141 |
sent_request_llm(llm_chain, prompt)
|
142 |
|
@@ -171,6 +172,9 @@ with st.sidebar:
|
|
171 |
key="sel_bp"
|
172 |
)
|
173 |
|
|
|
|
|
|
|
174 |
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
175 |
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
176 |
st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
|
|
|
49 |
if 'counselor_name' not in st.session_state:
|
50 |
st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
|
51 |
# temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
|
52 |
+
issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
|
53 |
on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
|
54 |
)
|
55 |
supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
|
|
|
135 |
if any([x['score'] for x in st.session_state['bp_prediction']]):
|
136 |
for bp in st.session_state['bp_prediction']:
|
137 |
if bp["score"]:
|
138 |
+
st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
|
139 |
st.session_state.changed_bp = True
|
140 |
+
sent_request_llm(llm_chain, prompt)
|
141 |
else:
|
142 |
sent_request_llm(llm_chain, prompt)
|
143 |
|
|
|
172 |
key="sel_bp"
|
173 |
)
|
174 |
|
175 |
+
if st.button("Score Conversation"):
|
176 |
+
st.switch_page("pages/training_adherence.py")
|
177 |
+
|
178 |
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
|
179 |
if st.session_state['total_messages'] >= MAX_MSG_COUNT:
|
180 |
st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
|
pages/training_adherence.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
from collections import defaultdict
|
4 |
+
from langchain_core.messages import HumanMessage
|
5 |
+
from utils.app_utils import are_models_alive
|
6 |
+
from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
|
7 |
+
from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
|
8 |
+
|
9 |
+
if "memory" not in st.session_state:
|
10 |
+
st.switch_page("pages/convosim.py")
|
11 |
+
|
12 |
+
if not are_models_alive():
|
13 |
+
st.switch_page("pages/model_loader.py")
|
14 |
+
|
15 |
+
memory = st.session_state['memory']
|
16 |
+
@st.cache_data(show_spinner="Retrieving responses from the server ...")
|
17 |
+
def get_ta_responses():
|
18 |
+
data = defaultdict(defaultdict)
|
19 |
+
# with st.spinner("Retrieving responses from the server ...")
|
20 |
+
for question in QUESTION2PHASE.keys():
|
21 |
+
# responses = ["Yes, The helper showed some respect.",
|
22 |
+
# "Yes. The helper is good! No doubt",
|
23 |
+
# "N/A, Texter disengaged.",
|
24 |
+
# "No. While texter is trying is lacking.",
|
25 |
+
# "No \n\n This is an explanation."]
|
26 |
+
# full_response = np.random.choice(responses)
|
27 |
+
full_convo, prompt, full_response = TA_predict_convo(memory, question, make_explanation=True, conversation_id=st.session_state['convo_id'])
|
28 |
+
response, explanation = post_process_response(full_response)
|
29 |
+
data[question]["response"] = response
|
30 |
+
data[question]["explanation"] = explanation
|
31 |
+
return data
|
32 |
+
|
33 |
+
with st.container():
|
34 |
+
col1, col2 = st.columns(2)
|
35 |
+
if col1.button("Go Back"):
|
36 |
+
get_ta_responses.clear()
|
37 |
+
st.switch_page("pages/convosim.py")
|
38 |
+
expl = col2.checkbox("Show Scoring Explanations")
|
39 |
+
|
40 |
+
tab1, tab2 = st.tabs(["Scoring", "Conversation"])
|
41 |
+
data = get_ta_responses()
|
42 |
+
|
43 |
+
with tab2:
|
44 |
+
for msg in memory.buffer_as_messages:
|
45 |
+
role = "user" if type(msg) == HumanMessage else "assistant"
|
46 |
+
st.chat_message(role).write(msg.content)
|
47 |
+
|
48 |
+
with tab1:
|
49 |
+
for question in QUESTION2PHASE.keys():
|
50 |
+
with st.container(border=True):
|
51 |
+
question_str = NAME2QUESTION[question].split(' Answer')[0]
|
52 |
+
st.radio(
|
53 |
+
f"**{question_str}**", options=TA_OPTIONS,
|
54 |
+
index=TA_OPTIONS.index(data[question]['response']), horizontal=True,
|
55 |
+
key=f"{question}_manual"
|
56 |
+
)
|
57 |
+
if expl:
|
58 |
+
st.write(data[question]["explanation"])
|
59 |
+
|
60 |
+
with st.container():
|
61 |
+
col1, col2 = st.columns(2)
|
62 |
+
if col1.button("Go Back", key="goback2"):
|
63 |
+
get_ta_responses.clear()
|
64 |
+
st.switch_page("pages/convosim.py")
|
65 |
+
if col2.button("Submit Scoring", type="primary"):
|
66 |
+
ytrue = {
|
67 |
+
question: {
|
68 |
+
"response":st.session_state[f"{question}_manual"]
|
69 |
+
}
|
70 |
+
for question in QUESTION2PHASE.keys()
|
71 |
+
}
|
72 |
+
ta_push_convo_comparison(ytrue, data)
|
73 |
+
get_ta_responses.clear()
|
74 |
+
st.switch_page("pages/convosim.py")
|
utils/app_utils.py
CHANGED
@@ -101,6 +101,7 @@ def is_model_alive(name, timeout=2, model_type="classificator"):
|
|
101 |
except:
|
102 |
return "404"
|
103 |
|
|
|
104 |
def are_models_alive():
|
105 |
models_alive = []
|
106 |
for config in ENDPOINT_NAMES.values():
|
|
|
101 |
except:
|
102 |
return "404"
|
103 |
|
104 |
+
@st.cache_data(ttl=300, show_spinner=False)
|
105 |
def are_models_alive():
|
106 |
models_alive = []
|
107 |
for config in ENDPOINT_NAMES.values():
|
utils/mongo_utils.py
CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
|
|
4 |
from streamlit.logger import get_logger
|
5 |
from pymongo.mongo_client import MongoClient
|
6 |
from pymongo.server_api import ServerApi
|
7 |
-
from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP
|
8 |
|
9 |
DB_URL = os.environ['MONGO_URL']
|
10 |
DB_USR = os.environ['MONGO_USR']
|
@@ -19,7 +19,7 @@ def get_db_client():
|
|
19 |
# Send a ping to confirm a successful connection
|
20 |
try:
|
21 |
client.admin.command('ping')
|
22 |
-
logger.
|
23 |
return client
|
24 |
except Exception as e:
|
25 |
logger.error(e)
|
@@ -38,7 +38,7 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
|
|
38 |
db = client[DB_SCHEMA]
|
39 |
convos = db[DB_CONVOS]
|
40 |
convo_id = convos.insert_one(convo).inserted_id
|
41 |
-
logger.
|
42 |
st.session_state['convo_id'] = convo_id
|
43 |
|
44 |
def new_comparison(client, prompt_timestamp, completion_timestamp,
|
@@ -66,7 +66,7 @@ def new_comparison(client, prompt_timestamp, completion_timestamp,
|
|
66 |
db = client[DB_SCHEMA]
|
67 |
comparisons = db[DB_COMPLETIONS]
|
68 |
comparison_id = comparisons.insert_one(comparison).inserted_id
|
69 |
-
logger.
|
70 |
st.session_state['comparison_id'] = comparison_id
|
71 |
|
72 |
def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
|
@@ -84,7 +84,7 @@ def new_battle_result(client, comparison_id, convo_id, username, model_one, mode
|
|
84 |
db = client[DB_SCHEMA]
|
85 |
battles = db[DB_BATTLES]
|
86 |
battle_id = battles.insert_one(battle).inserted_id
|
87 |
-
logger.
|
88 |
|
89 |
def new_completion_error(client, comparison_id, username, model):
|
90 |
error = {
|
@@ -97,12 +97,12 @@ def new_completion_error(client, comparison_id, username, model):
|
|
97 |
db = client[DB_SCHEMA]
|
98 |
errors = db[DB_ERRORS]
|
99 |
error_id = errors.insert_one(error).inserted_id
|
100 |
-
logger.
|
101 |
|
102 |
def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
103 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
104 |
comp = {
|
105 |
-
"
|
106 |
"conversation_id": convo_id,
|
107 |
"model": model,
|
108 |
"context": context,
|
@@ -114,12 +114,12 @@ def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, yp
|
|
114 |
db = client[DB_SCHEMA]
|
115 |
cpc_comps = db[DB_CPC]
|
116 |
comarison_id = cpc_comps.insert_one(comp).inserted_id
|
117 |
-
|
118 |
|
119 |
def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
120 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
121 |
comp = {
|
122 |
-
"
|
123 |
"conversation_id": convo_id,
|
124 |
"model": model,
|
125 |
"context": context,
|
@@ -133,7 +133,22 @@ def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypr
|
|
133 |
db = client[DB_SCHEMA]
|
134 |
bp_comps = db[DB_BP]
|
135 |
comarison_id = bp_comps.insert_one(comp).inserted_id
|
136 |
-
logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
def get_non_assesed_comparison(client, username):
|
139 |
from bson.son import SON
|
|
|
4 |
from streamlit.logger import get_logger
|
5 |
from pymongo.mongo_client import MongoClient
|
6 |
from pymongo.server_api import ServerApi
|
7 |
+
from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP, DB_TA
|
8 |
|
9 |
DB_URL = os.environ['MONGO_URL']
|
10 |
DB_USR = os.environ['MONGO_USR']
|
|
|
19 |
# Send a ping to confirm a successful connection
|
20 |
try:
|
21 |
client.admin.command('ping')
|
22 |
+
logger.debug(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
|
23 |
return client
|
24 |
except Exception as e:
|
25 |
logger.error(e)
|
|
|
38 |
db = client[DB_SCHEMA]
|
39 |
convos = db[DB_CONVOS]
|
40 |
convo_id = convos.insert_one(convo).inserted_id
|
41 |
+
logger.debug(f"DBUTILS: new convo id is {convo_id}")
|
42 |
st.session_state['convo_id'] = convo_id
|
43 |
|
44 |
def new_comparison(client, prompt_timestamp, completion_timestamp,
|
|
|
66 |
db = client[DB_SCHEMA]
|
67 |
comparisons = db[DB_COMPLETIONS]
|
68 |
comparison_id = comparisons.insert_one(comparison).inserted_id
|
69 |
+
logger.debug(f"DBUTILS: new comparison id is {comparison_id}")
|
70 |
st.session_state['comparison_id'] = comparison_id
|
71 |
|
72 |
def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
|
|
|
84 |
db = client[DB_SCHEMA]
|
85 |
battles = db[DB_BATTLES]
|
86 |
battle_id = battles.insert_one(battle).inserted_id
|
87 |
+
logger.debug(f"DBUTILS: new battle id is {battle_id}")
|
88 |
|
89 |
def new_completion_error(client, comparison_id, username, model):
|
90 |
error = {
|
|
|
97 |
db = client[DB_SCHEMA]
|
98 |
errors = db[DB_ERRORS]
|
99 |
error_id = errors.insert_one(error).inserted_id
|
100 |
+
logger.debug(f"DBUTILS: new error id is {error_id}")
|
101 |
|
102 |
def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
103 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
104 |
comp = {
|
105 |
+
"CPC_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
106 |
"conversation_id": convo_id,
|
107 |
"model": model,
|
108 |
"context": context,
|
|
|
114 |
db = client[DB_SCHEMA]
|
115 |
cpc_comps = db[DB_CPC]
|
116 |
comarison_id = cpc_comps.insert_one(comp).inserted_id
|
117 |
+
logger.debug(f"DBUTILS: new error id is {comarison_id}")
|
118 |
|
119 |
def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
|
120 |
# context = memory.load_memory_variables({})[memory.memory_key]
|
121 |
comp = {
|
122 |
+
"BP_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
123 |
"conversation_id": convo_id,
|
124 |
"model": model,
|
125 |
"context": context,
|
|
|
133 |
db = client[DB_SCHEMA]
|
134 |
bp_comps = db[DB_BP]
|
135 |
comarison_id = bp_comps.insert_one(comp).inserted_id
|
136 |
+
logger.debug(f"DBUTILS: new BP id is {comarison_id}")
|
137 |
+
|
138 |
+
def new_convo_scoring_comparison(client, convo_id, context, ytrue, ypred):
|
139 |
+
# context = memory.load_memory_variables({})[memory.memory_key]
|
140 |
+
comp = {
|
141 |
+
"scoring_timestamp": dt.datetime.now(tz=dt.timezone.utc),
|
142 |
+
"conversation_id": convo_id,
|
143 |
+
"context": context,
|
144 |
+
"manual_scoring": ytrue,
|
145 |
+
"model_scoring": ypred,
|
146 |
+
}
|
147 |
+
|
148 |
+
db = client[DB_SCHEMA]
|
149 |
+
ta_comps = db[DB_TA]
|
150 |
+
comarison_id = ta_comps.insert_one(comp).inserted_id
|
151 |
+
logger.debug(f"DBUTILS: new TA convo comparison id is {comarison_id}")
|
152 |
|
153 |
def get_non_assesed_comparison(client, username):
|
154 |
from bson.son import SON
|