gosha6037 committed on
Commit
dd71c9a
1 Parent(s): a4f8b32

Added code for predict_cluster_bloom

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import sys
 
2
 
3
  import gradio as gr
4
  import torch
@@ -6,8 +7,11 @@ import transformers
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
  sys.path.insert(0, './petals/')
 
 
9
 
10
  from petals.client.remote_model import DistributedBloomForCausalLM
 
11
 
12
  MODEL_NAME = "bigscience/bloom-petals"
13
  tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
@@ -63,6 +67,43 @@ def predict_common_bloom(model, tokenizer, input_text, history, person_descripti
63
  return response_new, history_new
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
67
  person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
68
  new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
@@ -110,6 +151,11 @@ def predict(
110
  tokenizer = tokenizer_bloomd
111
  print(f'Lets go history: {history}')
112
  return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
 
 
 
 
 
113
  else:
114
  model_name = 'DialoGPT-medium'
115
  model = model_DialoGPT_medium
 
1
  import sys
2
+ import json
3
 
4
  import gradio as gr
5
  import torch
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  sys.path.insert(0, './petals/')
10
+ sys.path.insert(0, './personalized-chat-bot/')
11
+
12
 
13
  from petals.client.remote_model import DistributedBloomForCausalLM
14
+ from models.personality_clustering import PersonalityClustering
15
 
16
  MODEL_NAME = "bigscience/bloom-petals"
17
  tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
 
67
  return response_new, history_new
68
 
69
 
70
def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Generate one chat reply with a (Petals-distributed) BLOOM model.

    Dialogue turns are delimited by newline characters rather than an EOS
    token: the user message is appended with a trailing ``'\\n'``, and the
    first non-empty generated line is taken as the bot's reply.

    Args:
        model: causal LM exposing ``generate`` (e.g. DistributedBloomForCausalLM).
        tokenizer: tokenizer used to encode/decode the dialogue text.
        input_text: the user's new message.
        history: token ids of the dialogue so far (list of id lists), or ``[]``.
        person_description: unused here; kept so the signature matches the
            other ``predict_*`` backends dispatched from ``predict``.
        number_of_new_tokens: cap on freshly generated tokens.

    Returns:
        tuple: ``(response_new, history_new)`` where ``response_new`` is a
        list of ``(user, bot)`` line pairs for the Gradio chat widget and
        ``history_new`` is a tensor of token ids encoding the updated dialogue.
    """
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    # Fixed: this previously logged 'Started predict_common_bloom' (wrong name).
    print('Started predict_cluster_bloom')
    print(f'history: {history}')
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    history = model.generate(
        bot_input_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(f'history: {history}')

    # Keep the prompt verbatim, then append only the first non-empty generated
    # line so the dialogue stays one turn per line. Fixed: the original indexed
    # all_responses[1] unconditionally and raised IndexError when generation
    # produced a single empty line.
    decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
    all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
    first_reply = next((line for line in all_responses if line), '')
    decode_all += first_reply + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')

    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')

    # Pair alternating lines as (user, bot) tuples for the chat widget.
    response_new = [
        (decode_all_split[i], decode_all_split[i + 1])
        for i in range(0, len(decode_all_split) - 1, 2)
    ]
    print(f'response_new: {response_new}')

    return response_new, history_new
+
106
+
107
  def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
108
  person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
109
  new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
 
151
  tokenizer = tokenizer_bloomd
152
  print(f'Lets go history: {history}')
153
  return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
154
+ elif model_name == 'bloom-petals-cluster':
155
+ model = model_bloomd
156
+ tokenizer = tokenizer_bloomd
157
+ print(f'Lets go history: {history}')
158
+ return predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
159
  else:
160
  model_name = 'DialoGPT-medium'
161
  model = model_DialoGPT_medium