gosha6037 commited on
Commit
b649ec8
·
1 Parent(s): 776e43c

Added common bloom

Browse files
Files changed (2) hide show
  1. app.py +66 -29
  2. petals +1 -1
app.py CHANGED
@@ -9,18 +9,9 @@ sys.path.insert(0, './petals/')
9
 
10
  from petals.client.remote_model import DistributedBloomForCausalLM
11
 
12
- MODEL_NAME = "bigscience/test-bloomd-6b3"
13
- # INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
14
- tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
15
- model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME,
16
- # initial_peers=INITIAL_PEERS,
17
- low_cpu_mem_usage=True, torch_dtype=torch.float32)
18
-
19
  MODEL_NAME = "bigscience/bloom-petals"
20
  tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
21
- model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME,
22
- low_cpu_mem_usage=True, torch_dtype=torch.float32)
23
-
24
 
25
  tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
26
  model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
@@ -32,48 +23,95 @@ tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-lar
32
  model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def predict(
36
  input_text,
37
  history=None,
38
  person_description=None,
39
- number_of_new_tokens=1000,
40
  model_name=None,
41
  del_hist=None
42
  ):
 
43
  if history is None or del_hist == 'delete history':
44
  history = []
 
45
  if model_name == 'DialoGPT-small':
46
  model = model_DialoGPT_small
47
  tokenizer = tokenizer_DialoGPT_small
 
48
  elif model_name == 'DialoGPT-medium':
49
  model = model_DialoGPT_medium
50
  tokenizer = tokenizer_DialoGPT_medium
 
51
  elif model_name == 'DialoGPT-large':
52
  model = model_DialoGPT_large
53
  tokenizer = tokenizer_DialoGPT_large
54
- elif model_name == 'test-bloomd-6b3':
55
- model = model_bloomd_6b3
56
- tokenizer = tokenizer_bloomd_6b3
57
  elif model_name == 'bloom-petals':
58
  model = model_bloomd
59
  tokenizer = tokenizer_bloomd
 
 
60
  else:
 
61
  model = model_DialoGPT_medium
62
  tokenizer = tokenizer_DialoGPT_medium
63
-
64
- person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
65
- new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
66
-
67
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
68
- input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
69
- max_token_count = number_of_new_tokens + len(input_with_desc_ids[0])
70
- history = model.generate(input_with_desc_ids, max_length=max_token_count,
71
- pad_token_id=tokenizer.eos_token_id).tolist()
72
- history[0] = history[0][len(person_description_ids[0]):]
73
-
74
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
75
- response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
76
- return response, history
77
 
78
 
79
  gr.Interface(
@@ -89,7 +127,6 @@ gr.Interface(
89
  'DialoGPT-small',
90
  'DialoGPT-medium',
91
  'DialoGPT-large',
92
- 'test-bloomd-6b3',
93
  'bloom-petals',
94
  ]
95
  ),
 
9
 
10
  from petals.client.remote_model import DistributedBloomForCausalLM
11
 
 
 
 
 
 
 
 
12
  MODEL_NAME = "bigscience/bloom-petals"
13
  tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
14
+ model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)
 
 
15
 
16
  tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
17
  model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
 
23
  model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
24
 
25
 
26
+ def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
27
+ new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
28
+ print('Started predict_common_bloom')
29
+ print(f'history: {history}')
30
+ if history != []:
31
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
32
+ else:
33
+ bot_input_ids = new_user_input_ids
34
+ print(f'bot_input_ids: {bot_input_ids}')
35
+
36
+ history = model.generate(
37
+ bot_input_ids,
38
+ max_new_tokens=number_of_new_tokens,
39
+ pad_token_id=tokenizer.eos_token_id
40
+ ).tolist()
41
+ print(f'history: {history}')
42
+
43
+ decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
44
+ all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
45
+ if all_responses[0]:
46
+ decode_all += all_responses[0] + '\n'
47
+ else:
48
+ decode_all += all_responses[1] + '\n'
49
+ print(f'decode_all: {decode_all}')
50
+
51
+ history_new = tokenizer.encode(decode_all, return_tensors='pt')
52
+ print(f'history_new: {history_new}')
53
+
54
+ decode_all_split = decode_all.split('\n')
55
+ print(f'decode_all_split: {decode_all_split}')
56
+
57
+ response_new = [(decode_all_split[i], decode_all_split[i + 1]) for i in range(0, len(decode_all_split) - 1, 2)]
58
+ print(f'response_new: {response_new}')
59
+
60
+ return response_new, history_new
61
+
62
+
63
+ def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
64
+ person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
65
+ new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
66
+
67
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
68
+ input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
69
+ history = model.generate(
70
+ input_with_desc_ids,
71
+ max_new_tokens=number_of_new_tokens,
72
+ pad_token_id=tokenizer.eos_token_id
73
+ ).tolist()
74
+ history[0] = history[0][len(person_description_ids[0]):]
75
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
76
+ response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
77
+
78
+ return response, history
79
+
80
+
81
  def predict(
82
  input_text,
83
  history=None,
84
  person_description=None,
85
+ number_of_new_tokens=10,
86
  model_name=None,
87
  del_hist=None
88
  ):
89
+
90
  if history is None or del_hist == 'delete history':
91
  history = []
92
+
93
  if model_name == 'DialoGPT-small':
94
  model = model_DialoGPT_small
95
  tokenizer = tokenizer_DialoGPT_small
96
+ return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
97
  elif model_name == 'DialoGPT-medium':
98
  model = model_DialoGPT_medium
99
  tokenizer = tokenizer_DialoGPT_medium
100
+ return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
101
  elif model_name == 'DialoGPT-large':
102
  model = model_DialoGPT_large
103
  tokenizer = tokenizer_DialoGPT_large
104
+ return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
 
 
105
  elif model_name == 'bloom-petals':
106
  model = model_bloomd
107
  tokenizer = tokenizer_bloomd
108
+ print(f'Lets go history: {history}')
109
+ return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
110
  else:
111
+ model_name = 'DialoGPT-medium'
112
  model = model_DialoGPT_medium
113
  tokenizer = tokenizer_DialoGPT_medium
114
+ return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  gr.Interface(
 
127
  'DialoGPT-small',
128
  'DialoGPT-medium',
129
  'DialoGPT-large',
 
130
  'bloom-petals',
131
  ]
132
  ),
petals CHANGED
@@ -1 +1 @@
1
- Subproject commit ab41223b17c17dd1035a42318b03d4b92decd063
 
1
+ Subproject commit d6992fca6363a83909f20f4c46839a004eb469e3