ivnban27-ctl commited on
Commit
cfe8e1a
·
1 Parent(s): b74c038

training adherence scoring features

Browse files
app_config.py CHANGED
@@ -57,6 +57,7 @@ DB_BATTLES = 'battles'
57
  DB_ERRORS = 'completion_errors'
58
  DB_CPC = "cpc_comparison"
59
  DB_BP = "bad_practices_comparison"
 
60
 
61
  MAX_MSG_COUNT = 60
62
  WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
 
57
  DB_ERRORS = 'completion_errors'
58
  DB_CPC = "cpc_comparison"
59
  DB_BP = "bad_practices_comparison"
60
+ DB_TA = "convo_scoring_comparison"
61
 
62
  MAX_MSG_COUNT = 60
63
  WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
models/ta_models/config.py CHANGED
@@ -23,4 +23,152 @@ BP_THRESHOLD = 0.7
23
  BP_LAB2STR = {
24
  "is_advice": "Advice",
25
  "is_personal_info": "Personal Info Sharing",
26
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  BP_LAB2STR = {
24
  "is_advice": "Advice",
25
  "is_personal_info": "Personal Info Sharing",
26
+ }
27
+
28
+ QUESTION2PHASE = {
29
+ "question_1": ["0_ActiveEngagement","1_Explore"],
30
+ "question_4": ["1_Explore"],
31
+ "question_5": ["0_ActiveEngagement", "1_Explore"],
32
+ # "question_7": ["1_Explore"],
33
+ # "question_9": ["4_SP&NS"],
34
+ "question_10": ["4_SP&NS"],
35
+ # "question_11": ["4_SP&NS"],
36
+ "question_14": ["6_WrappingUp"],
37
+ # "question_15": ["ALL"],
38
+ "question_19": ["ALL"],
39
+ # "question_21": ["ALL"],
40
+ # "question_22": ["ALL"],
41
+ "question_23": ["2_IRA", "3_SafetyAssessment"],
42
+ }
43
+
44
+ QUESTION2FILTERARGS = {
45
+ "question_1": {
46
+ "phases": QUESTION2PHASE["question_1"],
47
+ "pre_n": 2,
48
+ "post_n": 8,
49
+ "ignore": ["7_Other"],
50
+ },
51
+ "question_4": {
52
+ "phases": QUESTION2PHASE["question_4"],
53
+ "pre_n": 5,
54
+ "post_n": 5,
55
+ "ignore": ["7_Other"],
56
+ },
57
+ "question_5": {
58
+ "phases": QUESTION2PHASE["question_5"],
59
+ "pre_n": 5,
60
+ "post_n": 5,
61
+ "ignore": ["7_Other"],
62
+ },
63
+ # "question_7": {
64
+ # "phases": QUESTION2PHASE["question_7"],
65
+ # "pre_n": 5,
66
+ # "post_n": 15,
67
+ # "ignore": ["7_Other"],
68
+ # },
69
+ # "question_9": {
70
+ # "phases": QUESTION2PHASE["question_9"],
71
+ # "pre_n": 5,
72
+ # "post_n": 5,
73
+ # "ignore": ["7_Other"],
74
+ # },
75
+ "question_10": {
76
+ "phases": QUESTION2PHASE["question_10"],
77
+ "pre_n": 5,
78
+ "post_n": 5,
79
+ "ignore": ["7_Other"],
80
+ },
81
+ # "question_11": {
82
+ # "phases": QUESTION2PHASE["question_11"],
83
+ # "pre_n": 5,
84
+ # "post_n": 5,
85
+ # "ignore": ["7_Other"],
86
+ # },
87
+ "question_14": {
88
+ "phases": QUESTION2PHASE["question_14"],
89
+ "pre_n": 10,
90
+ "post_n": 0,
91
+ "ignore": ["7_Other"],
92
+ },
93
+ # "question_15": {
94
+ # "phases": QUESTION2PHASE["question_15"],
95
+ # "pre_n": 5,
96
+ # "post_n": 5,
97
+ # "ignore": ["7_Other"],
98
+ # },
99
+ "question_19": {
100
+ "phases": QUESTION2PHASE["question_19"],
101
+ "pre_n": 5,
102
+ "post_n": 5,
103
+ "ignore": ["7_Other"],
104
+ },
105
+ # "question_21": {
106
+ # "phases": QUESTION2PHASE["question_21"],
107
+ # "pre_n": 5,
108
+ # "post_n": 5,
109
+ # "ignore": ["7_Other"],
110
+ # },
111
+ # "question_22": {
112
+ # "phases": QUESTION2PHASE["question_22"],
113
+ # "pre_n": 5,
114
+ # "post_n": 5,
115
+ # "ignore": ["7_Other"],
116
+ # },
117
+ "question_23": {
118
+ "phases": QUESTION2PHASE["question_23"],
119
+ "pre_n": 5,
120
+ "post_n": 5,
121
+ "ignore": ["7_Other"],
122
+ },
123
+ }
124
+
125
+
126
+ START_INST = "<|user|>"
127
+ END_INST = "<|end|>\n<|assistant|>"
128
+
129
+ NAME2QUESTION = {
130
+ "question_1": "Did the helper introduce themself in the opening message? Answer only Yes or No.",
131
+ "question_4": "Did the helper actively listened to the texter's crisis? Answer only Yes or No.",
132
+ "question_5": "Did the helper reflect on the main issue that led the texter reach out? Answer only Yes or No.",
133
+ # "question_7": "Did the helper collaborated with the texter to identify the goal of the conversation? Answer only Yes or No.",
134
+ # "question_9": "Did the helper collaborated with the texter to create next steps? Answer only Yes or No.",
135
+ "question_10": "Did the helper explored texter's existing coping skills? Answer only Yes or No.",
136
+ # "question_11": "Did the helper explored texter’s social support? Answer only Yes or No.",
137
+ "question_14": "Did helper reflected the texter’s plan, reiterate coping skills, and end in a supportive way? Answer only Yes or No.",
138
+ # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
139
+ "question_19": "Did the helper consistently reflected empathy through the conversation? Answer only Yes or No.",
140
+ # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
141
+ # "question_22": "Did the helper gave advice? Answer only Yes or No.",
142
+ "question_23": "Did the helper explicitely initiated imminent risk assessment? Answer only Yes or No.",
143
+ }
144
+
145
+ NAME2PROMPT = {
146
+ k: "--------Conversation:\n{convo}\n{start_inst}" + v + "\n{end_inst}"
147
+ for k, v in NAME2QUESTION.items()
148
+ }
149
+
150
+ NAME2PROMPT_EXPL = {
151
+ k: v.split("Answer only Yes or No.")[0] + "Answer Yes or No, and give an explanation in a new line.\n{end_inst}"
152
+ for k, v in NAME2PROMPT.items()
153
+ }
154
+
155
+ QUESTIONDEFAULTS = {
156
+ "question_1": {True: "No, There was no evidence of Active Engagement", False: "No"},
157
+ "question_4": {True: "No, There was no evidence of Exploration Phase", False: "No"},
158
+ "question_5": {True: "No, There was no evidence of Exploration Phase", False: "No"},
159
+ # "question_7": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
160
+ # "question_9": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
161
+ "question_10": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
162
+ # "question_11": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
163
+ "question_14": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
164
+ # "question_15": "Did the helper consistently used Good Contact Techniques? Answer only Yes or No.",
165
+ "question_19": {True: "N/A Texter disengaged, Not Applicable", False: "N/A"},
166
+ # "question_21": "Did the helper shared personal information? Answer only Yes or No.",
167
+ # "question_22": "Did the helper gave advice? Answer only Yes or No.",
168
+ "question_23": {True: "No, There was no evidence of Imminent Risk Assessment", False: "No"},
169
+ }
170
+
171
+ TEXTER_PREFIX = "texter"
172
+ HELPER_PREFIX = "helper"
173
+
174
+ TA_OPTIONS = ["N/A", "No", "Yes"]
models/ta_models/ta_filter_utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import chain
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ possible_movements = [-1, 1]
8
+
9
+
10
+ def dfs(indexes: List[int], x0: int, i: int, cur_island: List[int], d=2):
11
+ """Deep First Search Implementation for 2D movement.
12
+ To consider an Island only move one step left or right
13
+ See possible movements
14
+
15
+ Args:
16
+ indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
17
+ x0 (int): Initial island anchor
18
+ i (int): Current index to test against anchor
19
+ cur_island (List[int]): Current Island from anchor
20
+ d (int, optional): Bounding distance to consider an island. Defaults to 2. For example
21
+ the list [20,21,23,50,51] has two islands with d=2: (20,21,23), and (50,51) but it has
22
+ three islands with d=: (20,21), (23), and (50,51)
23
+ """
24
+ rows = len(indexes)
25
+ if i < 0 or i >= rows:
26
+ return
27
+ if indexes[i] in cur_island:
28
+ return
29
+ if abs(indexes[x0] - indexes[i]) > d:
30
+ return
31
+ # computing coordinates with x0 as base
32
+ cur_island.append(indexes[i])
33
+
34
+ # repeat dfs for neighbors
35
+ for movement in possible_movements:
36
+ dfs(indexes, i, i + movement, cur_island, d)
37
+
38
+
39
+ def get_list_islands(indexes: List[int], **kwargs) -> List[List[int]]:
40
+ """Wrapper over DFS method to obtain islands from list of indexes of positive examples
41
+
42
+ Args:
43
+ indexes (List[int]): Indexes of positive examples. i.e [20,21,23,50,51]
44
+
45
+ Returns:
46
+ List[List[int]]: List of islands (each being a list)
47
+ """
48
+ islands = []
49
+ rows = len(indexes)
50
+ if rows == 0:
51
+ return islands
52
+
53
+ for i, valuei in enumerate(indexes):
54
+ # If already visited index in another dfs continue
55
+ if valuei in list(chain.from_iterable(islands)):
56
+ continue
57
+ # to hold coordinates of new island
58
+ cur_island = []
59
+ dfs(indexes, i, i, cur_island, **kwargs)
60
+
61
+ islands.append(cur_island)
62
+
63
+ return islands
64
+
65
+
66
+ def get_phases_islands_minmax(
67
+ convo: pd.DataFrame,
68
+ phases: List[str],
69
+ column: str = "convo_part",
70
+ ignore: List[str] = [],
71
+ **kwargs,
72
+ ) -> List[Tuple[int]]:
73
+ """Given a conversation with predicted Phases (or Parts), get minimum and maximum index of calculated islands.
74
+
75
+ Args:
76
+ convo (pd.DataFrame): Conversation with predicted phases stored in `column`
77
+ phases (List[str]): Phases to filter in
78
+ column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
79
+ ignore (List[str], optional): Ignore phases list. Defaults to [].
80
+
81
+ Returns:
82
+ List[Tuple[int]]: Minimum and Maximum values of calulated islands. i.e [(20,30), (40,60)]
83
+ """
84
+
85
+ reset = convo.query(f"{column}=={column} and {column} not in @ignore").reset_index()
86
+ sub_ = reset.query(f"{column} in @phases").copy()
87
+ indexes = sub_.index.tolist()
88
+ islands = get_list_islands(indexes, **kwargs)
89
+ if len(islands) > 1:
90
+ # If there is more than one island we want to make sure to root out comparable small islands
91
+ # I.e. if there is an island with 10 messages, and island of 1 messages is not useful in that context.
92
+ max_len = np.max([len(x) for x in islands])
93
+ len_cut = 3 if max_len > 9 else 2 if max_len > 3 else 1
94
+ islands = [x for x in islands if len(x) > len_cut]
95
+
96
+ islands = [reset.iloc[x] for x in islands]
97
+ minmax_islands = [(x["index"].min(), x["index"].max()) for x in islands]
98
+
99
+ return minmax_islands
100
+
101
+
102
+ def filter_convo(
103
+ convo: pd.DataFrame,
104
+ phases: List[str],
105
+ column: str = "convo_part",
106
+ strategy: str = "islands",
107
+ pre_n: int = 5,
108
+ post_n: int = 5,
109
+ return_all_on_empty: bool = False,
110
+ **kwargs,
111
+ ) -> pd.DataFrame:
112
+ """Filter convo to include only specified phases. Take into account that sometimes predicted phases
113
+ can be messy. i.e. a prediciton of explore, explore, explore, safety_planning, explore; should return all
114
+ these messages as explore (probably safety_planning message has a low probability here.)
115
+
116
+ Args:
117
+ convo (pd.DataFrame): Conversation with predicted phases stored in `column`
118
+ phases (List[str]): Phases to filter in
119
+ column (str, optional): Column where predicted phases information is stored. Defaults to "convo_part".
120
+ strategy (str, optional): Strategy to use, can be minmax or islands. Defaults to "islands".
121
+ pre_n (int, optional): How many messages pre-phase to include. Defaults to 5.
122
+ post_n (int, optional): How many messages post-phase to include. Defaults to 5.
123
+ return_all_on_empty (bool, optional): Whether to return all messages when specified phases is not found. Defaults to False.
124
+
125
+ Returns:
126
+ pd.DataFrame: Filtered messages from the convo
127
+ """
128
+ if phases == ["ALL"]:
129
+ minidx = convo.index.min()
130
+ maxidx = convo.index.max()
131
+ minmax = [(minidx, maxidx)]
132
+ elif strategy == "minmax":
133
+ minidx = convo.query(f"{column} in @phases").index.min()
134
+ maxidx = convo.query(f"{column} in @phases").index.max() + 1
135
+ minmax = [(minidx, maxidx)]
136
+ elif strategy == "islands":
137
+ minmax = get_phases_islands_minmax(convo, phases, column, **kwargs)
138
+ parts = []
139
+ for minidx, maxidx in minmax:
140
+ minidx = max(convo.index.min(), minidx - pre_n)
141
+ maxidx = min(convo.index.max(), maxidx + post_n)
142
+ parts.append(convo.loc[minidx:maxidx])
143
+ if len(parts) == 0:
144
+ if return_all_on_empty:
145
+ return convo
146
+ else:
147
+ return pd.DataFrame(columns=convo.columns)
148
+ filtered = pd.concat(parts)
149
+ filtered = filtered[~filtered.index.duplicated(keep="first")]
150
+ return filtered
models/ta_models/ta_prompt_utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+
3
+ import pandas as pd
4
+
5
+ from .config import QUESTION2FILTERARGS, TEXTER_PREFIX, HELPER_PREFIX
6
+
7
+ # Utils to filter convo according to a phase
8
+ from .ta_filter_utils import filter_convo
9
+
10
+
11
+ def join_messages(
12
+ grp: pd.DataFrame, texter_prefix: str = "texter", helper_prefix: str = "helper"
13
+ ) -> str:
14
+ """join messages from dataframe using texter an helper prefixes
15
+
16
+ Args:
17
+ grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
18
+ Must have the following columns:
19
+ - actor_role
20
+ - message
21
+
22
+ texter_prefix (str, optional): prefix to use as the texter. Defaults to "texter".
23
+ helper_prefix (str, optional): prefix to use as the counselor (helper). Defaults to "helper".
24
+
25
+ Returns:
26
+ str: joined messages string separated by prefixes
27
+ """
28
+
29
+ if "actor_role" not in grp:
30
+ raise Exception("Column 'actor_role' not in DataFrame")
31
+ if "message" not in grp:
32
+ raise Exception("Column 'message' not in DataFrame")
33
+
34
+ roles = grp.actor_role.replace(
35
+ {"texter": texter_prefix, "counselor": helper_prefix, "helper": helper_prefix}
36
+ )
37
+ messages = roles.str.strip() + ": " + grp.message.str.strip()
38
+ return "\n".join(messages)
39
+
40
+
41
+ def _get_context(grp: pd.DataFrame, **kwargs) -> str:
42
+ """Get context as a str taking into account message to delete, context marker
43
+ and the type of question to use. This allows for better truncation later
44
+
45
+ Args:
46
+ grp (pd.DataFrame): conversation in DataFrame with each row corresponding to each **message**.
47
+ Must have the following columns:
48
+ - actor_role
49
+ - message
50
+ - `column`
51
+ column (str): column name in which the marker of the problem is
52
+
53
+ Returns:
54
+ pd.DataFrame: joined messages string separated by prefixes
55
+ """
56
+
57
+ if "actor_role" not in grp:
58
+ raise Exception("Column 'actor_role' not in DataFrame")
59
+ if "message" not in grp:
60
+ raise Exception("Column 'message' not in DataFrame")
61
+
62
+ join_args = list(inspect.signature(join_messages).parameters)
63
+ join_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in join_args}
64
+
65
+ ## DEPRECATED
66
+ # context_args = list(inspect.signature(get_context_on_marker).parameters)
67
+ # context_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in context_args}
68
+
69
+ return join_messages(grp, **join_kwargs)
70
+
71
+
72
+ def load_context(
73
+ messages: pd.DataFrame,
74
+ question: str,
75
+ message_col: str,
76
+ col_type: str,
77
+ inference: bool = False,
78
+ **kwargs,
79
+ ) -> pd.DataFrame:
80
+ """Load and filter conversation from messages given a question (with configured parameters of what phase that question is answered)
81
+
82
+ Args:
83
+ messages (pd.DataFrame): Messages dataframe with conversation_id, actor_role, `message_col` and phase prediction
84
+ question (str): Question to get context to
85
+ message_col (str): Column where messages are
86
+ col_type (str): type of message_col, can be "individual" or "joined"
87
+ base_dir (str, optional): Base directory to find model base args. Defaults to "../../".
88
+
89
+ Raises:
90
+ Exception: If question is not supported
91
+
92
+ Returns:
93
+ pd.DataFrame: filtered messages according to question configuration
94
+ """
95
+
96
+ if question not in QUESTION2FILTERARGS:
97
+ raise Exception(f"Question {question} not supported")
98
+
99
+ texter_prefix = TEXTER_PREFIX
100
+ helper_prefix = HELPER_PREFIX
101
+ context_data = messages.copy()
102
+
103
+ def convo_cpc_get_context(grp, **kwargs):
104
+ """Filter convo according to Convo Phase Classifier (CPC) predictions"""
105
+ context_ = filter_convo(grp, **QUESTION2FILTERARGS[question])
106
+ return _get_context(context_, **kwargs)
107
+
108
+ if col_type == "individual":
109
+ if "actor_role" in context_data:
110
+ context_data.dropna(subset=["actor_role"], inplace=True)
111
+ if "delete_message" in context_data:
112
+ context_data.delete_message.replace({1: True}, inplace=True)
113
+ context_data.delete_message.fillna(False, inplace=True)
114
+
115
+ context_data = (
116
+ context_data.groupby("conversation_id")
117
+ .apply(
118
+ convo_cpc_get_context,
119
+ helper_prefix=helper_prefix,
120
+ texter_prefix=texter_prefix,
121
+ )
122
+ .rename("q_context")
123
+ )
124
+ elif col_type == "joined":
125
+ context_data = context_data.groupby("conversation_id")[[message_col]].max()
126
+ context_data.rename(columns={message_col: "q_context"}, inplace=True)
127
+
128
+ return context_data
models/ta_models/ta_utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import string
5
+ import streamlit as st
6
+ from streamlit.logger import get_logger
7
+ from app_config import ENDPOINT_NAMES
8
+ from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
9
+ import pandas as pd
10
+ from langchain_core.messages import AIMessage, HumanMessage
11
+ from models.ta_models.ta_prompt_utils import load_context
12
+ from utils.mongo_utils import new_convo_scoring_comparison
13
+
14
+ logger = get_logger(__name__)
15
+ TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
16
+ HEADERS = {
17
+ "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
18
+ "Content-Type": "application/json",
19
+ }
20
+
21
+ def memory2df(memory, conversation_id="convo1234"):
22
+ df = []
23
+ for i, msg in enumerate(memory.buffer_as_messages):
24
+ actor_role = "texter" if type(msg) == AIMessage else "helper" if type(msg) == HumanMessage else None
25
+ if actor_role:
26
+ convo_part = msg.response_metadata.get("phase",None)
27
+ row = {"conversation_id":conversation_id, "message_number":i+1, "actor_role":actor_role, "message":msg.content, "convo_part":convo_part}
28
+ df.append(row)
29
+
30
+ return pd.DataFrame(df)
31
+
32
+ def get_default(question, make_explanation=False):
33
+ return QUESTIONDEFAULTS[question][make_explanation]
34
+
35
+ def get_context(memory, question, make_explanation=False, **kwargs):
36
+ df = memory2df(memory, **kwargs)
37
+ contexti = load_context(df, question, "messages", "individual").iloc[0]
38
+ if contexti == "":
39
+ return ""
40
+
41
+ if make_explanation:
42
+ return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
43
+ else:
44
+ return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
45
+
46
+ def post_process_response(full_response, delimiter="\n\n", n=2):
47
+ parts = full_response.split(delimiter)[:n]
48
+ response = extract_response(parts[0])
49
+ logger.debug(f"Response extracted is {response}")
50
+ if len(parts) > 1:
51
+ if len(parts[0]) < len(parts[1]):
52
+ full_response = parts[1]
53
+ else: full_response = parts[0]
54
+ else:
55
+ full_response = parts[0]
56
+ explanation = full_response.lstrip(response).lstrip(string.punctuation)
57
+ explanation = explanation.strip()
58
+ logger.debug(f"Explanation extracted is {explanation}")
59
+ return response, explanation
60
+
61
+ def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
62
+ full_convo = memory.load_memory_variables({})[memory.memory_key]
63
+ PROMPT = get_context(memory, question, make_explanation=make_explanation, **kwargs)
64
+ logger.debug(f"Raw TA prompt is {PROMPT}")
65
+ if PROMPT == "":
66
+ full_response = get_default(question, make_explanation)
67
+ # response, explanation = post_process_response(full_response)
68
+ return full_convo, PROMPT, full_response
69
+
70
+ max_tokens = 128 if make_explanation else 3
71
+ body_request = {
72
+ "prompt": PROMPT,
73
+ "temperature": 0,
74
+ "max_tokens": max_tokens,
75
+ }
76
+
77
+ try:
78
+ # Send request to Serving
79
+ response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
80
+ if response.status_code == 200:
81
+ response = response.json()
82
+ full_response = response[0]['choices'][0]['text']
83
+ logger.debug(f"Raw TA response is {full_response}")
84
+ # response, explanation = post_process_response(full_response)
85
+ return full_convo, PROMPT, full_response
86
+ except:
87
+ pass
88
+
89
+ def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
90
+ """Extract Response from generated answer
91
+ Extract only search strings
92
+
93
+ Args:
94
+ x (str): prediction
95
+ default (str, optional): default in case no response founds. Defaults to "N/A".
96
+
97
+ Returns:
98
+ str: _description_
99
+ """
100
+
101
+ try:
102
+ return re.findall("|".join(TA_OPTIONS), x)[0]
103
+ except Exception:
104
+ return default
105
+
106
+ def ta_push_convo_comparison(ytrue, ypred):
107
+ new_convo_scoring_comparison(**{
108
+ "client": st.session_state['db_client'],
109
+ "convo_id": st.session_state['convo_id'],
110
+ "context": st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
111
+ "ytrue": ytrue,
112
+ "ypred": ypred,
113
+ })
pages/convosim.py CHANGED
@@ -49,7 +49,7 @@ with st.sidebar:
49
  if 'counselor_name' not in st.session_state:
50
  st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
51
  # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
52
- issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
53
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
54
  )
55
  supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
@@ -135,8 +135,9 @@ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG
135
  if any([x['score'] for x in st.session_state['bp_prediction']]):
136
  for bp in st.session_state['bp_prediction']:
137
  if bp["score"]:
138
- st.error(f"Detected {BP_LAB2STR[bp['label']]} in the last message!")
139
  st.session_state.changed_bp = True
 
140
  else:
141
  sent_request_llm(llm_chain, prompt)
142
 
@@ -171,6 +172,9 @@ with st.sidebar:
171
  key="sel_bp"
172
  )
173
 
 
 
 
174
  st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
175
  if st.session_state['total_messages'] >= MAX_MSG_COUNT:
176
  st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
 
49
  if 'counselor_name' not in st.session_state:
50
  st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
51
  # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
52
+ issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
53
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
54
  )
55
  supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
 
135
  if any([x['score'] for x in st.session_state['bp_prediction']]):
136
  for bp in st.session_state['bp_prediction']:
137
  if bp["score"]:
138
+ st.toast(f"Detected {BP_LAB2STR[bp['label']]} in the last message!", icon=":material/warning:")
139
  st.session_state.changed_bp = True
140
+ sent_request_llm(llm_chain, prompt)
141
  else:
142
  sent_request_llm(llm_chain, prompt)
143
 
 
172
  key="sel_bp"
173
  )
174
 
175
+ if st.button("Score Conversation"):
176
+ st.switch_page("pages/training_adherence.py")
177
+
178
  st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
179
  if st.session_state['total_messages'] >= MAX_MSG_COUNT:
180
  st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
pages/training_adherence.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from collections import defaultdict
4
+ from langchain_core.messages import HumanMessage
5
+ from utils.app_utils import are_models_alive
6
+ from models.ta_models.ta_utils import TA_predict_convo, ta_push_convo_comparison, post_process_response
7
+ from models.ta_models.config import QUESTION2PHASE, NAME2QUESTION, TA_OPTIONS
8
+
9
+ if "memory" not in st.session_state:
10
+ st.switch_page("pages/convosim.py")
11
+
12
+ if not are_models_alive():
13
+ st.switch_page("pages/model_loader.py")
14
+
15
+ memory = st.session_state['memory']
16
+ @st.cache_data(show_spinner="Retrieving responses from the server ...")
17
+ def get_ta_responses():
18
+ data = defaultdict(defaultdict)
19
+ # with st.spinner("Retrieving responses from the server ...")
20
+ for question in QUESTION2PHASE.keys():
21
+ # responses = ["Yes, The helper showed some respect.",
22
+ # "Yes. The helper is good! No doubt",
23
+ # "N/A, Texter disengaged.",
24
+ # "No. While texter is trying is lacking.",
25
+ # "No \n\n This is an explanation."]
26
+ # full_response = np.random.choice(responses)
27
+ full_convo, prompt, full_response = TA_predict_convo(memory, question, make_explanation=True, conversation_id=st.session_state['convo_id'])
28
+ response, explanation = post_process_response(full_response)
29
+ data[question]["response"] = response
30
+ data[question]["explanation"] = explanation
31
+ return data
32
+
33
+ with st.container():
34
+ col1, col2 = st.columns(2)
35
+ if col1.button("Go Back"):
36
+ get_ta_responses.clear()
37
+ st.switch_page("pages/convosim.py")
38
+ expl = col2.checkbox("Show Scoring Explanations")
39
+
40
+ tab1, tab2 = st.tabs(["Scoring", "Conversation"])
41
+ data = get_ta_responses()
42
+
43
+ with tab2:
44
+ for msg in memory.buffer_as_messages:
45
+ role = "user" if type(msg) == HumanMessage else "assistant"
46
+ st.chat_message(role).write(msg.content)
47
+
48
+ with tab1:
49
+ for question in QUESTION2PHASE.keys():
50
+ with st.container(border=True):
51
+ question_str = NAME2QUESTION[question].split(' Answer')[0]
52
+ st.radio(
53
+ f"**{question_str}**", options=TA_OPTIONS,
54
+ index=TA_OPTIONS.index(data[question]['response']), horizontal=True,
55
+ key=f"{question}_manual"
56
+ )
57
+ if expl:
58
+ st.write(data[question]["explanation"])
59
+
60
+ with st.container():
61
+ col1, col2 = st.columns(2)
62
+ if col1.button("Go Back", key="goback2"):
63
+ get_ta_responses.clear()
64
+ st.switch_page("pages/convosim.py")
65
+ if col2.button("Submit Scoring", type="primary"):
66
+ ytrue = {
67
+ question: {
68
+ "response":st.session_state[f"{question}_manual"]
69
+ }
70
+ for question in QUESTION2PHASE.keys()
71
+ }
72
+ ta_push_convo_comparison(ytrue, data)
73
+ get_ta_responses.clear()
74
+ st.switch_page("pages/convosim.py")
utils/app_utils.py CHANGED
@@ -101,6 +101,7 @@ def is_model_alive(name, timeout=2, model_type="classificator"):
101
  except:
102
  return "404"
103
 
 
104
  def are_models_alive():
105
  models_alive = []
106
  for config in ENDPOINT_NAMES.values():
 
101
  except:
102
  return "404"
103
 
104
+ @st.cache_data(ttl=300, show_spinner=False)
105
  def are_models_alive():
106
  models_alive = []
107
  for config in ENDPOINT_NAMES.values():
utils/mongo_utils.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
4
  from streamlit.logger import get_logger
5
  from pymongo.mongo_client import MongoClient
6
  from pymongo.server_api import ServerApi
7
- from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP
8
 
9
  DB_URL = os.environ['MONGO_URL']
10
  DB_USR = os.environ['MONGO_USR']
@@ -19,7 +19,7 @@ def get_db_client():
19
  # Send a ping to confirm a successful connection
20
  try:
21
  client.admin.command('ping')
22
- logger.info(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
23
  return client
24
  except Exception as e:
25
  logger.error(e)
@@ -38,7 +38,7 @@ def new_convo(client, issue, language, username, is_comparison, model_one, model
38
  db = client[DB_SCHEMA]
39
  convos = db[DB_CONVOS]
40
  convo_id = convos.insert_one(convo).inserted_id
41
- logger.info(f"DBUTILS: new convo id is {convo_id}")
42
  st.session_state['convo_id'] = convo_id
43
 
44
  def new_comparison(client, prompt_timestamp, completion_timestamp,
@@ -66,7 +66,7 @@ def new_comparison(client, prompt_timestamp, completion_timestamp,
66
  db = client[DB_SCHEMA]
67
  comparisons = db[DB_COMPLETIONS]
68
  comparison_id = comparisons.insert_one(comparison).inserted_id
69
- logger.info(f"DBUTILS: new comparison id is {comparison_id}")
70
  st.session_state['comparison_id'] = comparison_id
71
 
72
  def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
@@ -84,7 +84,7 @@ def new_battle_result(client, comparison_id, convo_id, username, model_one, mode
84
  db = client[DB_SCHEMA]
85
  battles = db[DB_BATTLES]
86
  battle_id = battles.insert_one(battle).inserted_id
87
- logger.info(f"DBUTILS: new battle id is {battle_id}")
88
 
89
  def new_completion_error(client, comparison_id, username, model):
90
  error = {
@@ -97,12 +97,12 @@ def new_completion_error(client, comparison_id, username, model):
97
  db = client[DB_SCHEMA]
98
  errors = db[DB_ERRORS]
99
  error_id = errors.insert_one(error).inserted_id
100
- logger.info(f"DBUTILS: new error id is {error_id}")
101
 
102
  def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
103
  # context = memory.load_memory_variables({})[memory.memory_key]
104
  comp = {
105
- "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
106
  "conversation_id": convo_id,
107
  "model": model,
108
  "context": context,
@@ -114,12 +114,12 @@ def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, yp
114
  db = client[DB_SCHEMA]
115
  cpc_comps = db[DB_CPC]
116
  comarison_id = cpc_comps.insert_one(comp).inserted_id
117
- # logger.info(f"DBUTILS: new error id is {error_id}")
118
 
119
  def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
120
  # context = memory.load_memory_variables({})[memory.memory_key]
121
  comp = {
122
- "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
123
  "conversation_id": convo_id,
124
  "model": model,
125
  "context": context,
@@ -133,7 +133,22 @@ def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypr
133
  db = client[DB_SCHEMA]
134
  bp_comps = db[DB_BP]
135
  comarison_id = bp_comps.insert_one(comp).inserted_id
136
- logger.info(f"DBUTILS: new BP id is {comarison_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  def get_non_assesed_comparison(client, username):
139
  from bson.son import SON
 
4
  from streamlit.logger import get_logger
5
  from pymongo.mongo_client import MongoClient
6
  from pymongo.server_api import ServerApi
7
+ from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP, DB_TA
8
 
9
  DB_URL = os.environ['MONGO_URL']
10
  DB_USR = os.environ['MONGO_USR']
 
19
  # Send a ping to confirm a successful connection
20
  try:
21
  client.admin.command('ping')
22
+ logger.debug(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
23
  return client
24
  except Exception as e:
25
  logger.error(e)
 
38
  db = client[DB_SCHEMA]
39
  convos = db[DB_CONVOS]
40
  convo_id = convos.insert_one(convo).inserted_id
41
+ logger.debug(f"DBUTILS: new convo id is {convo_id}")
42
  st.session_state['convo_id'] = convo_id
43
 
44
  def new_comparison(client, prompt_timestamp, completion_timestamp,
 
66
  db = client[DB_SCHEMA]
67
  comparisons = db[DB_COMPLETIONS]
68
  comparison_id = comparisons.insert_one(comparison).inserted_id
69
+ logger.debug(f"DBUTILS: new comparison id is {comparison_id}")
70
  st.session_state['comparison_id'] = comparison_id
71
 
72
  def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
 
84
  db = client[DB_SCHEMA]
85
  battles = db[DB_BATTLES]
86
  battle_id = battles.insert_one(battle).inserted_id
87
+ logger.debug(f"DBUTILS: new battle id is {battle_id}")
88
 
89
  def new_completion_error(client, comparison_id, username, model):
90
  error = {
 
97
  db = client[DB_SCHEMA]
98
  errors = db[DB_ERRORS]
99
  error_id = errors.insert_one(error).inserted_id
100
+ logger.debug(f"DBUTILS: new error id is {error_id}")
101
 
102
  def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
103
  # context = memory.load_memory_variables({})[memory.memory_key]
104
  comp = {
105
+ "CPC_timestamp": dt.datetime.now(tz=dt.timezone.utc),
106
  "conversation_id": convo_id,
107
  "model": model,
108
  "context": context,
 
114
  db = client[DB_SCHEMA]
115
  cpc_comps = db[DB_CPC]
116
  comarison_id = cpc_comps.insert_one(comp).inserted_id
117
+ logger.debug(f"DBUTILS: new error id is {comarison_id}")
118
 
119
  def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
120
  # context = memory.load_memory_variables({})[memory.memory_key]
121
  comp = {
122
+ "BP_timestamp": dt.datetime.now(tz=dt.timezone.utc),
123
  "conversation_id": convo_id,
124
  "model": model,
125
  "context": context,
 
133
  db = client[DB_SCHEMA]
134
  bp_comps = db[DB_BP]
135
  comarison_id = bp_comps.insert_one(comp).inserted_id
136
+ logger.debug(f"DBUTILS: new BP id is {comarison_id}")
137
+
138
+ def new_convo_scoring_comparison(client, convo_id, context, ytrue, ypred):
139
+ # context = memory.load_memory_variables({})[memory.memory_key]
140
+ comp = {
141
+ "scoring_timestamp": dt.datetime.now(tz=dt.timezone.utc),
142
+ "conversation_id": convo_id,
143
+ "context": context,
144
+ "manual_scoring": ytrue,
145
+ "model_scoring": ypred,
146
+ }
147
+
148
+ db = client[DB_SCHEMA]
149
+ ta_comps = db[DB_TA]
150
+ comarison_id = ta_comps.insert_one(comp).inserted_id
151
+ logger.debug(f"DBUTILS: new TA convo comparison id is {comarison_id}")
152
 
153
  def get_non_assesed_comparison(client, username):
154
  from bson.son import SON