Spaces:
Build error
Build error
File size: 13,304 Bytes
f30862a 93152cc f30862a 93152cc f30862a 617fa8c 342c3ab de337bd f30862a 342c3ab f30862a 93152cc f30862a 342c3ab f30862a 342c3ab f30862a 342c3ab f30862a 342c3ab f30862a 93152cc f30862a 93152cc f30862a 93152cc f30862a 93152cc f30862a 342c3ab f30862a 342c3ab f30862a 342c3ab f30862a 342c3ab f30862a 342c3ab f30862a de337bd f30862a 342c3ab f30862a 93152cc 342c3ab f30862a 93152cc f30862a 342c3ab f30862a 93152cc f30862a 93152cc f30862a 93152cc f30862a de337bd f30862a 93152cc 617fa8c de337bd 93152cc 342c3ab 93152cc f30862a 342c3ab f30862a de337bd 617fa8c 93152cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
'''
Dialog System of PsyPlus (dvq)
reference:
https://huggingface.co/spaces/bentrevett/emotion-prediction
https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues
gradio vs streamlit
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
https://gradio.app/interface_state/ -> global vs. local variables affect the separation of sessions
TODO
Add command to reset/jump to a function, e.g >reset, >euc_100
Add a diagram to the Gradio interface showing the sentiment analysis
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
Personalize: create database, load and save data
Run command
python app.py --run_on_own_server 1 --initial_chat_state free_chat
'''
import argparse
import re, time
import matplotlib.pyplot as plt
from threading import Timer
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
def option():
    """Parse the demo's command-line options and return the argparse Namespace.

    --run_on_own_server   int; non-zero launches Gradio with share=True and
                          enables the debug prints of the chat session.
    --dialog_model        name/path passed to GPT2LMHeadModel/GPT2Tokenizer.from_pretrained.
    --emotion_model       name/path of the text-classification (emotion) model.
    --account             optional account name (personalization placeholder).
    --initial_chat_state  dialog segment to start in: euc_100, euc_200 or free_chat.
    """
    specs = (
        ('--run_on_own_server',
         dict(type=int, default=0, help='if test on own server, need to use share mode')),
        ('--dialog_model',
         dict(type=str, default='tareknaous/dialogpt-empathetic-dialogues')),
        ('--emotion_model',
         dict(type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')),
        ('--account',
         dict(type=str, default=None)),
        ('--initial_chat_state',
         dict(type=str, default='euc_100', choices=['euc_100', 'euc_200', 'free_chat'])),
    )
    cli = argparse.ArgumentParser()
    for flag, settings in specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
# Parse the command-line options once, at import time; ChatHelper reads
# args.emotion_model / args.dialog_model when its class body executes.
args = option()
# Models and canned messages shared by all sessions are class attributes of ChatHelper;
# every chat-session-specific variable lives in a TherapyChatBot instance.
class ChatHelper:
    """Resources shared by every chat session: models and canned messages.

    Everything here is a class attribute, so the emotion pipeline and the dialog
    model/tokenizer are loaded once, when this class body executes (module import),
    and are read by all TherapyChatBot instances via ``ChatHelper.<name>``.
    """
    # chat and emotion-detection models
    # Emotion classifier; top_k=5 keeps the five best-scoring labels per input.
    ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
    # A negative top emotion must score above this to trigger the euc_200 flow.
    ed_threshold = 0.3
    dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
    dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
    # Appended to each user message to mark the end of a dialog turn.
    eos = dialog_tokenizer.eos_token
    # tokenizer.__call__ -> input_ids, attention_mask
    # tokenizer.encode -> only input_ids, which is what model.generate requires
    invalid_input = 'Invalid input, my friend :) Plz input again'
    good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
    good_case = 'Nice to hear that!'
    bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
    not_answer = "It's okay, maybe you don't want to answer this question."
    # Formatted with the detected emotion label when euc_200 is triggered.
    fill_form = ('It has come to our attention that you may suffer from {}.\n'
                 'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
                 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
                 'you can fill out these scales again to see if you have improved.\n'
                 'Do you want to fill in the form now? (Okay or Later)')
    display_form = '<Display the form>.\n'
    reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'
    # Categories the user rates 1-10 in euc_100, asked in this order.
    emotion_types = ['Overall', 'Happiness', 'Anxiety']  # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
    # Scripted euc_100 prompts; the last 'bad_mood' entry is shown just before switching to free chat.
    euc_100 = {
        'q': emotion_types,
        'good_mood': [
            'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
            'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
        ],
        'bad_mood': [
            'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
            'I see. So when it is happening, what feelings or emotions have you got?',
            'And what do you think about those feelings or emotions at that time?',
            'Could you think of any evidence for your above-mentioned thought?',
            'Here are some reference articles about bad emotions. You can take a look :)',
        ],
    }
    # Classifier labels treated as negative (they can trigger euc_200).
    negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
                         'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
    # Formatted with the message that triggered euc_200 when returning to the previous chat.
    euc_200 = 'Now go back to the last chat. You said that "{}".\n'
    # First bot message, keyed by the initial chat state.
    greeting_template = {
        'euc_100': 'How was your day? On the scale 1 to 10, '
                   'how would you judge your emotion through the following categories:\nOverall',
        # euc_200 is only triggered when the user says something more negative than a certain threshold,
        # thus the greeting here is only for debugging euc_200
        'euc_200': fill_form.format('anxiety'),
        'free_chat': 'Hi you! How is it going?',
    }
def plot_emotion_distribution(predictions):
    """Draw and show a bar chart of emotion-classifier scores.

    Args:
        predictions: sequence of dicts with 'label' and 'score' keys, e.g. the
            sorted output used in TherapyChatBot.__call__.

    Returns:
        The matplotlib Figure that was drawn, so callers (e.g. a Gradio plot
        output) can reuse it; callers that ignore the result are unaffected.
    """
    fig, ax = plt.subplots()
    ax.bar(
        x=list(range(len(predictions))),
        height=[p['score'] for p in predictions],
        tick_label=[p['label'] for p in predictions],
    )
    # Only the x tick labels (emotion names) need rotating; y-axis scores stay horizontal.
    ax.tick_params(axis='x', rotation=90)
    ax.set_ylim(0, 1)  # scores are expected to lie in [0, 1]
    plt.show()
    return fig
def ed_rulebase(text, *, ask=input):
    """Rule-based crisis check for a user message.

    The rule fires when the text mentions a life-safety keyword AND at least one
    immediacy or manifestation keyword. On a hit the user is offered the Hong Kong
    Lifeline number and asked follow-up questions through ``ask``.

    Keywords match case-insensitively and only as whole words/phrases, so
    "Suicide" is detected while "know" does not count as "now" and "every" does
    not count as "very".

    Args:
        text: the user message to scan.
        ask: prompt function returning the user's answer (defaults to the
            built-in ``input``); injectable for non-console front ends.

    Returns:
        True if the dangerous-content rule fired, otherwise False.
    """
    keywords = {
        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
        'manifestation': ['never stop', 'every moment', 'strong', 'very'],
    }

    def _mentions(group):
        # Whole-word alternation over the escaped keywords of one group.
        pattern = r'\b(?:' + '|'.join(map(re.escape, keywords[group])) + r')\b'
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    # if found dangerous kw/topics
    if not (_mentions('life_safety') and (_mentions('immediacy') or _mentions('manifestation'))):
        return False

    print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
          'The Hong Kong Lifeline number is (852) 2382 0000')
    choice = ask('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
    if choice.strip() == '1':
        print('Let you connect to the office')
    else:
        print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
              'Would you mind if we send this conversation to the cloud to finetune the model.')
        consent = ask('Yes or No: ')
        if consent.strip().lower() == 'yes':
            pass  # TODO: upload the conversation for fine-tuning (not implemented)
    return True
class TherapyChatBot:
    """One chat session: the dialog state machine plus its history.

    ``chat_state`` names the dialog segment the session is in:
      * ``'euc_100.<subsection>.<entry>'``: the scripted check-in, where
        ``subsection`` is ``q`` (1-10 ratings), ``good_mood`` or ``bad_mood`` and
        ``entry`` indexes the prompt list; ``bad_mood.-1`` means the bot has asked
        whether the bad mood is over.
      * ``'euc_200'``: a negative emotion was detected and the self-assessment
        form is being offered.
      * ``'free_chat'``: open conversation with the dialog model.

    ``history['text']`` holds (user, bot) pairs for display; ``history['input_ids']``
    is the token context fed to the dialog model.
    """

    def __init__(self, args):
        # check state to control the dialog
        self.chat_state = args.initial_chat_state  # name of the chat function/therapy segment the model is in
        self.message_prev = None  # message that triggered euc_200, replayed after the form prompt
        self.chat_state_prev = None  # state to return to after euc_200
        self.run_on_own_server = args.run_on_own_server
        self.account = args.account
        # additional attributes for euc_100
        self.euc_100_input_time = []
        self.euc_100_emotion_degree = []
        # bad_mood prompt index to resume from after asking whether the bad mood is over;
        # kept separate from chat_state_prev so euc_200 cannot overwrite it
        self.bad_mood_entry_prev = None
        self.already_trigger_euc_200 = False
        # chat history
        self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
        self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
                        'text': self.greeting}
        if self.account:
            # TODO: personalization - load/save this account's history from a database.
            # Until a storage format exists the session always starts fresh. Note that the
            # builtin hash() of a str is salted per process, so it cannot name a persistent file.
            import warnings
            warnings.warn('Loading saved history per account is not implemented; starting a new session.')
        if 'euc_100' in self.chat_state:
            self.chat_state = 'euc_100.q.0'

    def __call__(self, message, prefix=''):
        """Process one user message and append the exchange to ``history['text']``.

        Args:
            message: the user's text.
            prefix: non-empty only when replayed from ``euc_200``. It is prepended to
                the reply, emotion detection is skipped, and the history entry is
                recorded under ``self.message_prev`` (the user's answer to the form offer).

        Returns None; callers read the result from ``self.history``.
        """
        # Detect emotion only for fresh messages, outside euc_200, and at most once per session.
        emotion = None
        if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
            # NOTE(review): assumes the pipeline returns one list of {'label', 'score'} dicts
            # per input; this output shape differs across transformers versions - confirm.
            prediction = ChatHelper.ed_pipe(message)[0]
            prediction = sorted(prediction, key=lambda p: p['score'], reverse=True)
            if self.run_on_own_server:
                print(prediction)
            # plot_emotion_distribution(prediction)
            emotion = prediction[0]

        # if message is negative, change state immediately
        if emotion is not None and (emotion['label'] in ChatHelper.negative_emotions
                                    and emotion['score'] > ChatHelper.ed_threshold):
            self.chat_state_prev = self.chat_state
            self.chat_state = 'euc_200'
            self.message_prev = message
            self.already_trigger_euc_200 = True
            response = ChatHelper.fill_form.format(emotion['label'])
        # set up rule to update state inside each dialog function
        elif self.chat_state.startswith('euc_100'):
            response = self.euc_100(message)
            if self.chat_state == 'free_chat':
                # The check-in just ended: put the user's last message into the dialog
                # model's context so free chat continues from it.
                message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
                self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        elif self.chat_state.startswith('euc_200'):
            # euc_200 replays the stored message through __call__, which records history itself.
            return self.euc_200(message)
        else:  # free_chat
            response = self.free_chat(message)

        if prefix:
            response = prefix + response
            self.history['text'].append((self.message_prev, response))
        else:
            self.history['text'].append((message, response))

    def euc_100(self, x):
        """Advance the scripted check-in by one user answer ``x`` and return the bot reply."""
        _, subsection, entry = self.chat_state.split('.')
        entry = int(entry)
        if subsection == 'q':
            answer = x.strip()
            # isdecimal() guarantees int() accepts the text ('½' is numeric but not decimal).
            if answer.isdecimal() and 0 < int(answer) < 11:
                self.euc_100_emotion_degree.append(int(answer))
                self.euc_100_input_time.append(time.gmtime())
                if entry == len(ChatHelper.euc_100['q']) - 1:
                    if self.run_on_own_server:
                        print(self.euc_100_emotion_degree)
                    # The first rating ('Overall') decides which branch follows.
                    mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
                    self.chat_state = f'euc_100.{mood}.0'
                    response = ChatHelper.euc_100[mood][0]
                else:
                    self.chat_state = f'euc_100.q.{entry + 1}'
                    response = ChatHelper.euc_100['q'][entry + 1]
            else:
                response = ChatHelper.invalid_input
        elif subsection == 'good_mood':
            if x == '':
                response = ChatHelper.good_mood_over
            else:
                response = ChatHelper.good_case + '\n' + ChatHelper.euc_100['good_mood'][1]
                self.chat_state = 'free_chat'
        elif subsection == 'bad_mood':
            if entry == -1:
                # Answer to "Whether your bad mood is over?"
                if 'yes' in x.lower() or 'better' in x.lower():
                    response = ChatHelper.good_case
                    self.chat_state = 'free_chat'  # leave the check-in rather than re-asking forever
                else:
                    response = ChatHelper.not_answer + '\n' + self._next_bad_mood_prompt(self.bad_mood_entry_prev)
            elif x == '':
                # Empty answer: ask whether the bad mood is over, remembering where we were.
                response = ChatHelper.bad_mood_over
                self.bad_mood_entry_prev = entry
                self.chat_state = 'euc_100.bad_mood.-1'
            else:
                response = self._next_bad_mood_prompt(entry)
        return response

    def _next_bad_mood_prompt(self, entry):
        """Return the bad_mood prompt after index ``entry`` and move ``chat_state`` to it.

        When the last prompt (the reference articles) is returned, the session
        switches to free chat.
        """
        prompts = ChatHelper.euc_100['bad_mood']
        if entry == len(prompts) - 2:
            self.chat_state = 'free_chat'
        else:
            self.chat_state = f'euc_100.bad_mood.{entry + 1}'
        return prompts[entry + 1]

    def euc_200(self, x):
        """Handle the answer ``x`` to the form offer, then resume the interrupted flow.

        The reply (the form or reference articles, plus a reminder of the message
        that triggered euc_200) is passed as ``prefix`` while that stored message is
        replayed through ``__call__`` in the previous state.
        """
        # don't ask questions in euc_200, because they're similar to the questions in euc_100
        if x.strip().lower() == 'okay':
            response = ChatHelper.display_form
        else:
            response = ChatHelper.reference
        if self.message_prev is None:
            # Entered euc_200 directly (initial_chat_state='euc_200', debug only):
            # there is no earlier message to replay and no state to return to.
            self.chat_state = self.chat_state_prev or 'free_chat'
            self.history['text'].append((x, response))
            return None
        response += ChatHelper.euc_200.format(self.message_prev)
        message = self.message_prev
        self.message_prev = x
        self.chat_state = self.chat_state_prev
        return self.__call__(message, prefix=response)

    def free_chat(self, message):
        """Generate a dialog-model reply to ``message`` and extend the token history.

        If sampling yields an empty reply, the oldest turn (through the first EOS
        token) is dropped from the context and generation is retried; when that
        would leave no context, only the current message is used.
        """
        tokenizer = ChatHelper.dialog_tokenizer
        eos_id = tokenizer.eos_token_id
        message_ids = tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        input_ids = self.history['input_ids'].clone()
        while True:
            bot_output_ids = ChatHelper.dialog_model.generate(input_ids, max_length=1000,
                                                              do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
                                                              pad_token_id=eos_id)
            response = tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            if response.strip() != '':
                break
            tokens = input_ids[0].tolist()
            # Only trim at an EOS that is not the final token; trimming at the final
            # EOS would leave an empty context.
            if eos_id in tokens[:-1]:
                idx = tokens.index(eos_id)
                input_ids = input_ids[:, (idx + 1):]
            else:
                input_ids = message_ids
            if self.run_on_own_server:
                print(input_ids)
        # Append only the newly generated tokens (those after the context that was fed in).
        self.history['input_ids'] = torch.cat([self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]],
                                              dim=-1)
        if self.run_on_own_server:
            print((message, response), '\n', self.history['input_ids'])
        return response
if __name__ == '__main__':
    def chat(message, bot):
        """Gradio callback: route ``message`` to this browser session's bot.

        ``bot`` is the per-session Gradio state (presumably None on the first call,
        Gradio's default) and a TherapyChatBot is created when it is falsy. Returns
        the history to display and the bot to store back as state.
        """
        if not bot:
            bot = TherapyChatBot(args)
        bot(message)
        return bot.history['text'], bot

    # Shown before the first message; same greeting a new TherapyChatBot records.
    opening = [('', ChatHelper.greeting_template[args.initial_chat_state])]
    iface = gr.Interface(
        chat,
        ['text', 'state'],
        [gr.Chatbot(value=opening), 'state'],
        allow_flagging='never',
        title='PsyPlus Empathetic Chatbot',
        description='Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT',
    )
    launch_options = {'debug': True}
    if args.run_on_own_server != 0:
        launch_options['share'] = True  # public share link when hosting on our own server
    iface.launch(**launch_options)