Spaces:
Runtime error
Runtime error
Added code for predict_cluster_bloom
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import sys
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
@@ -6,8 +7,11 @@ import transformers
|
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
|
8 |
sys.path.insert(0, './petals/')
|
|
|
|
|
9 |
|
10 |
from petals.client.remote_model import DistributedBloomForCausalLM
|
|
|
11 |
|
12 |
MODEL_NAME = "bigscience/bloom-petals"
|
13 |
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
@@ -63,6 +67,43 @@ def predict_common_bloom(model, tokenizer, input_text, history, person_descripti
|
|
63 |
return response_new, history_new
|
64 |
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
|
67 |
person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
|
68 |
new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
|
@@ -110,6 +151,11 @@ def predict(
|
|
110 |
tokenizer = tokenizer_bloomd
|
111 |
print(f'Lets go history: {history}')
|
112 |
return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
|
|
|
|
|
|
|
|
|
|
|
113 |
else:
|
114 |
model_name = 'DialoGPT-medium'
|
115 |
model = model_DialoGPT_medium
|
|
|
1 |
import sys
|
2 |
+
import json
|
3 |
|
4 |
import gradio as gr
|
5 |
import torch
|
|
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
|
9 |
sys.path.insert(0, './petals/')
|
10 |
+
sys.path.insert(0, './personalized-chat-bot/')
|
11 |
+
|
12 |
|
13 |
from petals.client.remote_model import DistributedBloomForCausalLM
|
14 |
+
from models.personality_clustering import PersonalityClustering
|
15 |
|
16 |
MODEL_NAME = "bigscience/bloom-petals"
|
17 |
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
67 |
return response_new, history_new
|
68 |
|
69 |
|
70 |
+
def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Generate one chat turn with the distributed BLOOM model, using
    newline-separated turns (the cluster/persona variant of the app).

    Args:
        model: causal LM exposing ``generate`` (here: DistributedBloomForCausalLM).
        tokenizer: matching tokenizer with ``encode``/``decode``/``eos_token_id``.
        input_text: the user's new message.
        history: nested token-id list (``[[...]]``) of the dialogue so far,
            or ``[]`` on the first turn.
        person_description: unused in this variant; kept so the signature
            matches the sibling ``predict_*`` helpers.
        number_of_new_tokens: cap on how many tokens to generate.

    Returns:
        ``(response_new, history_new)`` — ``response_new`` is a list of
        ``(user, bot)`` string pairs for the Gradio chatbot widget and
        ``history_new`` is the re-encoded token-id tensor of the dialogue.
    """
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    # BUGFIX: the original logged 'Started predict_common_bloom' here — a
    # copy-paste slip from the sibling function; log the real name.
    print('Started predict_cluster_bloom')
    print(f'history: {history}')
    if history:  # idiomatic emptiness test (was: history != [])
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    history = model.generate(
        bot_input_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(f'history: {history}')

    # Keep the prompt portion verbatim, then append the first non-empty
    # generated line as the bot's reply (turns are separated by '\n').
    decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
    all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
    if all_responses[0]:
        decode_all += all_responses[0] + '\n'
    elif len(all_responses) > 1:
        # BUGFIX: guard the fallback — the original read all_responses[1]
        # unconditionally and raised IndexError when the model produced an
        # empty continuation (split('\n') on '' yields a single element).
        decode_all += all_responses[1] + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')

    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')

    # Pair consecutive lines as (user, bot) turns for the chatbot UI.
    response_new = [
        (decode_all_split[i], decode_all_split[i + 1])
        for i in range(0, len(decode_all_split) - 1, 2)
    ]
    print(f'response_new: {response_new}')

    return response_new, history_new
|
105 |
+
|
106 |
+
|
107 |
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
|
108 |
person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
|
109 |
new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
|
|
|
151 |
tokenizer = tokenizer_bloomd
|
152 |
print(f'Lets go history: {history}')
|
153 |
return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
|
154 |
+
elif model_name == 'bloom-petals-cluster':
|
155 |
+
model = model_bloomd
|
156 |
+
tokenizer = tokenizer_bloomd
|
157 |
+
print(f'Lets go history: {history}')
|
158 |
+
return predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
|
159 |
else:
|
160 |
model_name = 'DialoGPT-medium'
|
161 |
model = model_DialoGPT_medium
|