Spaces:
Build error
Build error
Quyet
commited on
Commit
·
de337bd
1
Parent(s):
617fa8c
update chat state history, add initial greeting
Browse files
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
title: PsyPlus
|
3 |
emoji: 🤖
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: gpl-3.0
|
|
|
1 |
---
|
2 |
title: PsyPlus
|
3 |
emoji: 🤖
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.10.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: gpl-3.0
|
app.py
CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
|
|
4 |
from threading import Timer
|
5 |
import gradio as gr
|
6 |
|
|
|
7 |
from transformers import (
|
8 |
GPT2LMHeadModel, GPT2Tokenizer,
|
9 |
AutoModelForSequenceClassification, AutoTokenizer,
|
@@ -11,6 +12,8 @@ from transformers import (
|
|
11 |
)
|
12 |
# reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
|
13 |
# and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
|
|
|
|
14 |
|
15 |
def euc_100():
|
16 |
# 1,2,3. asks about the user's emotions and store data
|
@@ -77,16 +80,14 @@ def euc_100():
|
|
77 |
|
78 |
|
79 |
def load_neural_emotion_detector():
|
80 |
-
model_name =
|
81 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
|
82 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
83 |
pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
|
84 |
return_all_scores=True, truncation=True)
|
85 |
return pipe
|
86 |
|
87 |
-
def sort_predictions(predictions):
|
88 |
-
return sorted(predictions, key=lambda x: x['score'], reverse=True)
|
89 |
-
|
90 |
def plot_emotion_distribution(predictions):
|
91 |
fig, ax = plt.subplots()
|
92 |
ax.bar(x=[i for i, _ in enumerate(prediction)],
|
@@ -98,9 +99,9 @@ def plot_emotion_distribution(predictions):
|
|
98 |
|
99 |
def rulebase(text):
|
100 |
keywords = {
|
101 |
-
'life_safety': [
|
102 |
-
'immediacy': [
|
103 |
-
'manifestation': [
|
104 |
}
|
105 |
|
106 |
# if found dangerous kw/topics
|
@@ -127,7 +128,7 @@ def euc_200(text, testing=True):
|
|
127 |
if not testing:
|
128 |
pipe = load_neural_emotion_detector()
|
129 |
prediction = pipe(text)[0]
|
130 |
-
prediction =
|
131 |
plot_emotion_distribution(prediction)
|
132 |
|
133 |
# get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
|
@@ -174,46 +175,58 @@ def euc_200(text, testing=True):
|
|
174 |
pass
|
175 |
|
176 |
|
177 |
-
tokenizer
|
178 |
-
model = GPT2LMHeadModel.from_pretrained("tareknaous/dialogpt-empathetic-dialogues")
|
179 |
-
model.eval()
|
180 |
-
|
181 |
-
def chat(message, history):
|
182 |
-
history = history or []
|
183 |
eos = tokenizer.eos_token
|
184 |
-
|
|
|
|
|
|
|
|
|
185 |
|
186 |
-
|
187 |
-
|
|
|
|
|
188 |
max_length=1000,
|
189 |
do_sample=True, top_p=0.9, temperature=0.8,
|
190 |
pad_token_id=tokenizer.eos_token_id)
|
191 |
-
response = tokenizer.decode(bot_output_ids[:,
|
192 |
skip_special_tokens=True)
|
|
|
|
|
193 |
|
194 |
-
history
|
195 |
-
|
|
|
196 |
|
197 |
|
198 |
if __name__ == '__main__':
|
199 |
# euc_100()
|
200 |
# euc_200('I am happy about my academic record.')
|
201 |
parser = argparse.ArgumentParser()
|
202 |
-
parser.add_argument('--
|
203 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
title =
|
206 |
-
description =
|
|
|
207 |
iface = gr.Interface(
|
208 |
chat,
|
209 |
-
[
|
210 |
-
[
|
211 |
-
|
212 |
-
allow_flagging=
|
213 |
title=title,
|
214 |
description=description,
|
215 |
)
|
216 |
-
if args.
|
217 |
iface.launch(debug=True)
|
218 |
else:
|
219 |
-
iface.launch(debug=True, server_name=
|
|
|
4 |
from threading import Timer
|
5 |
import gradio as gr
|
6 |
|
7 |
+
import torch
|
8 |
from transformers import (
|
9 |
GPT2LMHeadModel, GPT2Tokenizer,
|
10 |
AutoModelForSequenceClassification, AutoTokenizer,
|
|
|
12 |
)
|
13 |
# reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
|
14 |
# and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
15 |
+
# gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
16 |
+
# https://gradio.app/interface_state/
|
17 |
|
18 |
def euc_100():
|
19 |
# 1,2,3. asks about the user's emotions and store data
|
|
|
80 |
|
81 |
|
82 |
def load_neural_emotion_detector():
|
83 |
+
model_name = 'joeddav/distilbert-base-uncased-go-emotions-student'
|
84 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
85 |
+
model.eval()
|
86 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
87 |
pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
|
88 |
return_all_scores=True, truncation=True)
|
89 |
return pipe
|
90 |
|
|
|
|
|
|
|
91 |
def plot_emotion_distribution(predictions):
|
92 |
fig, ax = plt.subplots()
|
93 |
ax.bar(x=[i for i, _ in enumerate(prediction)],
|
|
|
99 |
|
100 |
def rulebase(text):
|
101 |
keywords = {
|
102 |
+
'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
|
103 |
+
'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
|
104 |
+
'manifestation': ['never stop', 'every moment', 'strong', 'very']
|
105 |
}
|
106 |
|
107 |
# if found dangerous kw/topics
|
|
|
128 |
if not testing:
|
129 |
pipe = load_neural_emotion_detector()
|
130 |
prediction = pipe(text)[0]
|
131 |
+
prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)
|
132 |
plot_emotion_distribution(prediction)
|
133 |
|
134 |
# get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
|
|
|
175 |
pass
|
176 |
|
177 |
|
178 |
+
def _chat(message, history, model, tokenizer, args):
|
|
|
|
|
|
|
|
|
|
|
179 |
eos = tokenizer.eos_token
|
180 |
+
history = history or {
|
181 |
+
'text': args.greeting,
|
182 |
+
'input_ids': tokenizer.encode(args.greeting[-1][1] + eos, return_tensors='pt'),
|
183 |
+
}
|
184 |
+
# TODO: only take the latest X turns, otherwise the text becomes longer and takes more time to process
|
185 |
|
186 |
+
message_ids = tokenizer.encode(message + eos, return_tensors='pt')
|
187 |
+
history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
|
188 |
+
|
189 |
+
bot_output_ids = model.generate(history['input_ids'],
|
190 |
max_length=1000,
|
191 |
do_sample=True, top_p=0.9, temperature=0.8,
|
192 |
pad_token_id=tokenizer.eos_token_id)
|
193 |
+
response = tokenizer.decode(bot_output_ids[:, history['input_ids'].shape[-1]:][0],
|
194 |
skip_special_tokens=True)
|
195 |
+
if args.run_on_own_server == 1:
|
196 |
+
print((message, response), bot_output_ids[0][-10:])
|
197 |
|
198 |
+
history['input_ids'] = bot_output_ids
|
199 |
+
history['text'].append((message, response))
|
200 |
+
return history['text'], history
|
201 |
|
202 |
|
203 |
if __name__ == '__main__':
|
204 |
# euc_100()
|
205 |
# euc_200('I am happy about my academic record.')
|
206 |
parser = argparse.ArgumentParser()
|
207 |
+
parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
|
208 |
args = parser.parse_args()
|
209 |
+
args.greeting = [('','Hi you!')]
|
210 |
+
|
211 |
+
tokenizer = GPT2Tokenizer.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
|
212 |
+
model = GPT2LMHeadModel.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
|
213 |
+
model.eval()
|
214 |
+
def chat(message, history):
|
215 |
+
return _chat(message, history, model, tokenizer, args)
|
216 |
|
217 |
+
title = 'PsyPlus Empathetic Chatbot'
|
218 |
+
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
219 |
+
chatbot = gr.Chatbot(value=args.greeting)
|
220 |
iface = gr.Interface(
|
221 |
chat,
|
222 |
+
['text', 'state'],
|
223 |
+
[chatbot, 'state'],
|
224 |
+
# css=".gradio-container {background-color: white}",
|
225 |
+
allow_flagging='never',
|
226 |
title=title,
|
227 |
description=description,
|
228 |
)
|
229 |
+
if args.run_on_own_server == 0:
|
230 |
iface.launch(debug=True)
|
231 |
else:
|
232 |
+
iface.launch(debug=True, server_name='0.0.0.0', server_port=2022, share=True)
|