Quyet commited on
Commit
de337bd
·
1 Parent(s): 617fa8c

update chat state history, add initial greeting

Browse files
Files changed (2) hide show
  1. README.md +3 -3
  2. app.py +42 -29
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: PsyPlus
3
  emoji: 🤖
4
- colorFrom: green
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: gpl-3.0
 
1
  ---
2
  title: PsyPlus
3
  emoji: 🤖
4
+ colorFrom: pink
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.10.0
8
  app_file: app.py
9
  pinned: false
10
  license: gpl-3.0
app.py CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
4
  from threading import Timer
5
  import gradio as gr
6
 
 
7
  from transformers import (
8
  GPT2LMHeadModel, GPT2Tokenizer,
9
  AutoModelForSequenceClassification, AutoTokenizer,
@@ -11,6 +12,8 @@ from transformers import (
11
  )
12
  # reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
13
  # and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
 
 
14
 
15
  def euc_100():
16
  # 1,2,3. asks about the user's emotions and store data
@@ -77,16 +80,14 @@ def euc_100():
77
 
78
 
79
  def load_neural_emotion_detector():
80
- model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
81
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
82
  tokenizer = AutoTokenizer.from_pretrained(model_name)
83
  pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
84
  return_all_scores=True, truncation=True)
85
  return pipe
86
 
87
- def sort_predictions(predictions):
88
- return sorted(predictions, key=lambda x: x['score'], reverse=True)
89
-
90
  def plot_emotion_distribution(predictions):
91
  fig, ax = plt.subplots()
92
  ax.bar(x=[i for i, _ in enumerate(prediction)],
@@ -98,9 +99,9 @@ def plot_emotion_distribution(predictions):
98
 
99
  def rulebase(text):
100
  keywords = {
101
- 'life_safety': ["death", "suicide", "murder", "to perish together", "jump off the building"],
102
- 'immediacy': ["now", "immediately", "tomorrow", "today"],
103
- 'manifestation': ["never stop", "every moment", "strong", "very"]
104
  }
105
 
106
  # if found dangerous kw/topics
@@ -127,7 +128,7 @@ def euc_200(text, testing=True):
127
  if not testing:
128
  pipe = load_neural_emotion_detector()
129
  prediction = pipe(text)[0]
130
- prediction = sort_predictions(prediction)
131
  plot_emotion_distribution(prediction)
132
 
133
  # get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
@@ -174,46 +175,58 @@ def euc_200(text, testing=True):
174
  pass
175
 
176
 
177
- tokenizer = GPT2Tokenizer.from_pretrained("tareknaous/dialogpt-empathetic-dialogues")
178
- model = GPT2LMHeadModel.from_pretrained("tareknaous/dialogpt-empathetic-dialogues")
179
- model.eval()
180
-
181
- def chat(message, history):
182
- history = history or []
183
  eos = tokenizer.eos_token
184
- input_str = eos.join([x for turn in history for x in turn] + [message])
 
 
 
 
185
 
186
- bot_input_ids = tokenizer.encode(input_str, return_tensors='pt')
187
- bot_output_ids = model.generate(bot_input_ids,
 
 
188
  max_length=1000,
189
  do_sample=True, top_p=0.9, temperature=0.8,
190
  pad_token_id=tokenizer.eos_token_id)
191
- response = tokenizer.decode(bot_output_ids[:, bot_input_ids.shape[-1]:][0],
192
  skip_special_tokens=True)
 
 
193
 
194
- history.append((message, response))
195
- return history, history
 
196
 
197
 
198
  if __name__ == '__main__':
199
  # euc_100()
200
  # euc_200('I am happy about my academic record.')
201
  parser = argparse.ArgumentParser()
202
- parser.add_argument('--run_share_mode', type=int, default=0, help='if test on own server, need to use share mode')
203
  args = parser.parse_args()
 
 
 
 
 
 
 
204
 
205
- title = "PsyPlus Empathetic Chatbot"
206
- description = "Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT"
 
207
  iface = gr.Interface(
208
  chat,
209
- ["text", "state"],
210
- ["chatbot", "state"],
211
- allow_screenshot=False,
212
- allow_flagging="never",
213
  title=title,
214
  description=description,
215
  )
216
- if args.run_share_mode == 0:
217
  iface.launch(debug=True)
218
  else:
219
- iface.launch(debug=True, server_name="0.0.0.0", server_port=2022, share=True)
 
4
  from threading import Timer
5
  import gradio as gr
6
 
7
+ import torch
8
  from transformers import (
9
  GPT2LMHeadModel, GPT2Tokenizer,
10
  AutoModelForSequenceClassification, AutoTokenizer,
 
12
  )
13
  # reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
14
  # and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
15
+ # gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
16
+ # https://gradio.app/interface_state/
17
 
18
  def euc_100():
19
  # 1,2,3. asks about the user's emotions and store data
 
80
 
81
 
82
  def load_neural_emotion_detector():
83
+ model_name = 'joeddav/distilbert-base-uncased-go-emotions-student'
84
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
85
+ model.eval()
86
  tokenizer = AutoTokenizer.from_pretrained(model_name)
87
  pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
88
  return_all_scores=True, truncation=True)
89
  return pipe
90
 
 
 
 
91
  def plot_emotion_distribution(predictions):
92
  fig, ax = plt.subplots()
93
  ax.bar(x=[i for i, _ in enumerate(prediction)],
 
99
 
100
  def rulebase(text):
101
  keywords = {
102
+ 'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
103
+ 'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
104
+ 'manifestation': ['never stop', 'every moment', 'strong', 'very']
105
  }
106
 
107
  # if found dangerous kw/topics
 
128
  if not testing:
129
  pipe = load_neural_emotion_detector()
130
  prediction = pipe(text)[0]
131
+ prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)
132
  plot_emotion_distribution(prediction)
133
 
134
  # get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
 
175
  pass
176
 
177
 
178
+ def _chat(message, history, model, tokenizer, args):
 
 
 
 
 
179
  eos = tokenizer.eos_token
180
+ history = history or {
181
+ 'text': args.greeting,
182
+ 'input_ids': tokenizer.encode(args.greeting[-1][1] + eos, return_tensors='pt'),
183
+ }
184
+ # TODO: only take the latest X turns, otherwise the text becomes longer and takes more time to process
185
 
186
+ message_ids = tokenizer.encode(message + eos, return_tensors='pt')
187
+ history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
188
+
189
+ bot_output_ids = model.generate(history['input_ids'],
190
  max_length=1000,
191
  do_sample=True, top_p=0.9, temperature=0.8,
192
  pad_token_id=tokenizer.eos_token_id)
193
+ response = tokenizer.decode(bot_output_ids[:, history['input_ids'].shape[-1]:][0],
194
  skip_special_tokens=True)
195
+ if args.run_on_own_server == 1:
196
+ print((message, response), bot_output_ids[0][-10:])
197
 
198
+ history['input_ids'] = bot_output_ids
199
+ history['text'].append((message, response))
200
+ return history['text'], history
201
 
202
 
203
  if __name__ == '__main__':
204
  # euc_100()
205
  # euc_200('I am happy about my academic record.')
206
  parser = argparse.ArgumentParser()
207
+ parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
208
  args = parser.parse_args()
209
+ args.greeting = [('','Hi you!')]
210
+
211
+ tokenizer = GPT2Tokenizer.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
212
+ model = GPT2LMHeadModel.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
213
+ model.eval()
214
+ def chat(message, history):
215
+ return _chat(message, history, model, tokenizer, args)
216
 
217
+ title = 'PsyPlus Empathetic Chatbot'
218
+ description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
219
+ chatbot = gr.Chatbot(value=args.greeting)
220
  iface = gr.Interface(
221
  chat,
222
+ ['text', 'state'],
223
+ [chatbot, 'state'],
224
+ # css=".gradio-container {background-color: white}",
225
+ allow_flagging='never',
226
  title=title,
227
  description=description,
228
  )
229
+ if args.run_on_own_server == 0:
230
  iface.launch(debug=True)
231
  else:
232
+ iface.launch(debug=True, server_name='0.0.0.0', server_port=2022, share=True)