Spaces: Runtime error

arnocandel committed · Commit 8c85f5b
1 Parent(s): c7b6f0f

Update with h2oGPT hash c146f88a7a4b65fe180dae6d92358a898b140e4a

Files changed:
- app.py +122 -44
- finetune.py +5 -1
- utils.py +62 -0
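The app.py diff below keys most of its new behavior off a handful of module-level environment flags. As orientation, here is a condensed sketch of that gating; the flag names, comments, and fallback model names are taken from the diff itself, while pick_model() is only an illustrative wrapper and not a function that exists in app.py:

# Condensed sketch of the deployment gating added at the top of app.py.
# Flag names and comments come from the diff; pick_model() is illustrative only.
import os

is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))   # running inside a Hugging Face Space
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))     # alternate hosted-deployment flag
is_public = is_hf or is_gpth2oai                # multi-user case with fixed model and disclaimer
is_low_mem = is_hf                              # assumes run on 24GB consumer GPU
admin_pass = os.getenv("ADMIN_PASS")            # unlocks the System tab when public


def pick_model(base_model: str = '', load_8bit: bool = False):
    """Apply the same public-deployment defaults that main() applies in the diff."""
    if is_public:
        if is_low_mem:
            base_model = 'h2oai/h2ogpt-oasst1-512-12b'
        else:
            base_model = 'h2oai/h2ogpt-oasst1-512-20b'
    if is_low_mem:
        load_8bit = True
    return base_model, load_8bit


print(pick_model())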
app.py CHANGED

@@ -4,8 +4,7 @@ import sys
 import os
 import traceback
 import typing
-
-from utils import set_seed, flatten_list, clear_torch_cache, system_info_print
+from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
 
 SEED = 1236
 set_seed(SEED)
@@ -27,6 +26,12 @@ from finetune import get_loaders, example_data_points, generate_prompt, get_gith
     human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
 from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
 
+is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
+is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
+is_public = is_hf or is_gpth2oai  # multi-user case with fixed model and disclaimer
+is_low_mem = is_hf  # assumes run on 24GB consumer GPU
+admin_pass = os.getenv("ADMIN_PASS")
+
 
 def main(
         load_8bit: bool = False,
@@ -53,6 +58,7 @@ def main(
 
         llama_type: bool = None,
         debug: bool = False,
+        save_dir: str = None,
         share: bool = True,
         local_files_only: bool = False,
         resume_download: bool = True,
@@ -90,15 +96,23 @@ def main(
 ):
     # allow set token directly
     use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
-
-    if
-
-
-
-        temperature = 0.7
-        top_p = 1
-        top_k = 100
+
+    if is_public:
+        temperature = 0.4
+        top_p = 0.85
+        top_k = 70
         do_sample = True
+        if is_low_mem:
+            base_model = 'h2oai/h2ogpt-oasst1-512-12b'
+            load_8bit = True
+        else:
+            base_model = 'h2oai/h2ogpt-oasst1-512-20b'
+    if is_low_mem:
+        load_8bit = True
+    if is_hf:
+        # must override share if in spaces
+        share = False
+    save_dir = os.getenv('SAVE_DIR', save_dir)
 
     # get defaults
     model_lower = base_model.lower()
@@ -166,7 +180,7 @@ def main(
     if not eval_sharegpt_as_output:
         model, tokenizer, device = get_model(**locals())
         model_state = [model, tokenizer, device, base_model]
-        fun = partial(evaluate, model_state, debug=debug, chat=chat)
+        fun = partial(evaluate, model_state, debug=debug, chat=chat, save_dir=save_dir)
     else:
         assert eval_sharegpt_prompts_only > 0
 
@@ -202,7 +216,7 @@ def main(
             assert ex[1] in [None, '']  # should be no iinput
             assert ex[2] in [None, '']  # should be no context
             prompt = ex[0]
-            cutoff_len = 768 if
+            cutoff_len = 768 if is_low_mem else 2048
             inputs = stokenizer(prompt, res,
                                 return_tensors="pt",
                                 truncation=True,
@@ -215,8 +229,9 @@ def main(
                 score = 0.0
                 clear_torch_cache()
             except RuntimeError as e:
-                if 'Expected all tensors to be on the same device' in str(
-
+                if 'Expected all tensors to be on the same device' in str(e) or \
+                        'expected scalar type Half but found Float' in str(e) or \
+                        'probability tensor contains either' in str(e):
                     print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
                           flush=True)
                     traceback.print_exc()
@@ -526,11 +541,12 @@ def go_gradio(**kwargs):
                       """
     else:
         description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
-    if
+    if is_public:
         description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
         if kwargs['load_8bit']:
-            description += """<i><li> Model is loaded in 8-bit
-            description += """<i><li>
+            description += """<i><li> Model is loaded in 8-bit, model loading-unloading is disabled, and other limitations exist in order to fit on GPUs with lower amounts of VRAM, so UX can be worse than non-hosted version.</i></li>"""
+        description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
+        description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
 
     if kwargs['verbose']:
         task_info_md = f"""
@@ -538,14 +554,43 @@ def go_gradio(**kwargs):
     else:
         task_info_md = ''
 
-    css_code = """footer {visibility: hidden}
-    body{background-
+    css_code = """footer {visibility: hidden;}
+    body{background:linear-gradient(#f5f5f5,#e5e5e5);}
+    body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
 
-    from gradio.themes.utils import colors, fonts, sizes
+    from gradio.themes.utils import Color, colors, fonts, sizes
     if kwargs['h2ocolors']:
-
-
-
+        h2o_yellow = Color(
+            name="yellow",
+            c50="#fffef2",
+            c100="#fff9e6",
+            c200="#ffecb3",
+            c300="#ffe28c",
+            c400="#ffd659",
+            c500="#fec925",
+            c600="#e6ac00",
+            c700="#bf8f00",
+            c800="#a67c00",
+            c900="#664d00",
+            c950="#403000",
+        )
+        h2o_gray = Color(
+            name="gray",
+            c50="#f2f2f2",
+            c100="#e5e5e5",
+            c200="#cccccc",
+            c300="#b2b2b2",
+            c400="#999999",
+            c500="#7f7f7f",
+            c600="#666666",
+            c700="#4c4c4c",
+            c800="#333333",
+            c900="#191919",
+            c950="#0d0d0d",
+        )
+        colors_dict = dict(primary_hue=h2o_yellow,
+                           secondary_hue=h2o_yellow,
+                           neutral_hue=h2o_gray,
                            spacing_size=sizes.spacing_md,
                            radius_size=sizes.radius_md,
                            text_size=sizes.text_md,
@@ -617,12 +662,12 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                 {description}
                 {task_info_md}
                 """)
-        if
+        if is_hf:
            gr.HTML('''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
 
        # go button visible if
        base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
-        go_btn = gr.Button(value="
+        go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
        normal_block = gr.Row(visible=not base_wanted)
        with normal_block:
            with gr.Tabs():
@@ -685,7 +730,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                                                 value=kwargs['stream_output'])
                    prompt_type = gr.Dropdown(prompt_types_strings,
                                              value=kwargs['prompt_type'], label="Prompt Type",
-                                              visible=not
+                                              visible=not is_public)
                    temperature = gr.Slider(minimum=0, maximum=3,
                                            value=kwargs['temperature'],
                                            label="Temperature",
@@ -698,12 +743,12 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                                      value=kwargs['top_k'], label="Top k",
                                      info='Num. tokens to sample from'
                                      )
-                    max_beams = 8 if not
+                    max_beams = 8 if not is_low_mem else 2
                    num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
                                          value=min(max_beams, kwargs['num_beams']), label="Beams",
                                          info="Number of searches for optimal overall probability. "
                                               "Uses more GPU memory/compute")
-                    max_max_new_tokens = 2048 if not
+                    max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
                    max_new_tokens = gr.Slider(
                        minimum=1, maximum=max_max_new_tokens, step=1,
                        value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
@@ -714,7 +759,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                        )
                    early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
                                                 value=kwargs['early_stopping'])
-                    max_max_time = 60 * 5 if not
+                    max_max_time = 60 * 5 if not is_low_mem else 60
                    max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
                                         value=min(max_max_time, kwargs['max_time']), label="Max. time",
                                         info="Max. time to search optimal output.")
@@ -724,17 +769,17 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                    num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
                                                     value=kwargs['num_return_sequences'],
                                                     label="Number Returns", info="Must be <= num_beams",
-                                                     visible=not
+                                                     visible=not is_public)
                    do_sample = gr.Checkbox(label="Sample", info="Sample, for diverse output(s)",
                                            value=kwargs['do_sample'])
                    if kwargs['chat']:
                        iinput = gr.Textbox(lines=4, label="Input",
                                            placeholder=kwargs['placeholder_input'],
-                                            visible=not
+                                            visible=not is_public)
                        # nominally empty for chat mode
                        context = gr.Textbox(lines=1, label="Context",
                                             info="Ignored in chat mode.",
-                                             visible=not
+                                             visible=not is_public)
 
                with gr.TabItem("Models"):
                    with gr.Row():
@@ -744,8 +789,8 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                        model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", value=kwargs['base_model'])
                        lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
                        with gr.Column(scale=1):
-                            load_msg = "Load Model/LORA" if not
-                            else "LOAD DISABLED
+                            load_msg = "Load Model/LORA" if not is_public \
+                                else "LOAD DISABLED FOR HOSTED DEMO"
                            load_model_button = gr.Button(load_msg)
                            model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
                            lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
@@ -757,12 +802,27 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                        add_model_button = gr.Button("Add new model name")
                        add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
                with gr.TabItem("System"):
-
+                    system_row = gr.Row(visible=not is_public)
+                    admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
+                    admin_btn = gr.Button(value="admin", visible=is_public)
+                    with system_row:
                        with gr.Column():
                            system_text = gr.Textbox(label='System Info')
                            system_btn = gr.Button(value='Get System Info')
 
+                            zip_btn = gr.Button("Zip")
+                            file_output = gr.File()
+
+        # Get flagged data
+        zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
+        zip_btn.click(zip_data1, inputs=None, outputs=file_output)
 
+        def check_admin_pass(x):
+            return gr.update(visible=x == admin_pass)
+
+        admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)
+
+        # Get inputs to evaluate()
        inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
        from functools import partial
        all_kwargs = kwargs.copy()
@@ -811,7 +871,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                    len(history[-1]) >= 2:
                os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
-                max_length_tokenize = 512 if
+                max_length_tokenize = 512 if is_low_mem else 2048
                cutoff_len = max_length_tokenize*4  # restrict deberta related to max for LLM
 
                question = history[-1][0]
@@ -833,7 +893,9 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                    clear_torch_cache()
                    return 'Response Score: GPU OOM'
                except RuntimeError as e:
-                    if 'Expected all tensors to be on the same device' in str(e) or
+                    if 'Expected all tensors to be on the same device' in str(e) or \
+                            'expected scalar type Half but found Float' in str(e) or \
+                            'probability tensor contains either' in str(e):
                        print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
                        traceback.print_exc()
                        clear_torch_cache()
@@ -1025,7 +1087,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
                                   outputs=[model_state, model_used, lora_used, prompt_type])
        prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
        chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
-        if not
+        if not is_public:
            load_model_event = load_model_button.click(**load_model_args) \
                .then(**prompt_update_args) \
                .then(**chatbot_update_args) \
@@ -1079,7 +1141,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
 
 
 input_args_list = ['model_state']
-inputs_kwargs_list = ['debug', 'chat', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
+inputs_kwargs_list = ['debug', 'chat', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
 
 
 def get_inputs_list(inputs_dict, model_lower):
@@ -1142,6 +1204,7 @@ def evaluate(
        src_lang=None,
        tgt_lang=None,
        debug=False,
+        save_dir=None,
        chat=False,
        hard_stop_list=None,
        sanitize_bot_response=True,
@@ -1204,7 +1267,7 @@ def evaluate(
        # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
        # stopping only starts once output is beyond prompt
        # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
-        stop_words = [human, bot]
+        stop_words = [human, bot, '\n' + human, '\n' + bot]
        encounters = [1, 2]
    elif prompt_type == 'instruct_vicuna':
        # even below is not enough, generic strings and many ways to encode
@@ -1235,6 +1298,9 @@
        # avoid padding in front of tokens
        if tokenizer.pad_token:
            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
+        # handle fake \n added
+        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
+        # build stopper
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
    else:
        stopping_criteria = StoppingCriteriaList()
@@ -1243,7 +1309,7 @@
    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
    # RuntimeError: expected scalar type Half but found Float
    # with - 256
-    max_length_tokenize = 768 - 256 if
+    max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
    cutoff_len = max_length_tokenize * 4  # if reaches limit, then can't generate new tokens
    output_smallest = 30 * 4
    prompt = prompt[-cutoff_len - output_smallest:]
@@ -1332,8 +1398,9 @@ def evaluate(
                clear_torch_cache()
                return
            except RuntimeError as e:
-                if 'Expected all tensors to be on the same device' in str(
-
+                if 'Expected all tensors to be on the same device' in str(e) or \
+                        'expected scalar type Half but found Float' in str(e) or \
+                        'probability tensor contains either' in str(e):
                    print(
                        "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
                        flush=True)
@@ -1343,6 +1410,7 @@ def evaluate(
                else:
                    raise
 
+            decoded_output = None
            for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
                decoded_output = decoder(output)
                if output[-1] in [tokenizer.eos_token_id]:
@@ -1353,12 +1421,16 @@ def evaluate(
                    raise StopIteration
                yield prompter.get_response(decoded_output, prompt=inputs_decoded,
                                            sanitize_bot_response=sanitize_bot_response)
-
+            if save_dir and decoded_output:
+                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
        else:
            outputs = model.generate(**gen_kwargs)
            outputs = [decoder(s) for s in outputs.sequences]
            yield prompter.get_response(outputs, prompt=inputs_decoded,
                                        sanitize_bot_response=sanitize_bot_response)
+            if save_dir and outputs and len(outputs) >= 1:
+                decoded_output = prompt + outputs[0]
+                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 def get_generate_params(model_lower, chat,
@@ -1569,5 +1641,11 @@ if __name__ == "__main__":
 
        python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
 
+        must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
+        can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
+        python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
+
+        python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
+
        """, flush=True)
    fire.Fire(main)
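A notable UI change above is the admin-password gate on the System tab: on public deployments the system-info row starts hidden and is revealed only when the entered password matches the ADMIN_PASS environment variable. A stripped-down, self-contained Gradio sketch of just that pattern (assuming Gradio 3.x as used at the time; everything else in the tab is omitted):

# Minimal sketch of the System-tab admin gate added above (Gradio 3.x style).
# Only the password gate is shown; the rest of the h2oGPT UI is omitted.
import os

import gradio as gr

admin_pass = os.getenv("ADMIN_PASS")  # same env var the diff reads

with gr.Blocks() as demo:
    # Row starts hidden, mirroring visible=not is_public for a public deployment.
    system_row = gr.Row(visible=False)
    admin_pass_textbox = gr.Textbox(label="Admin Password", type="password")
    admin_btn = gr.Button(value="admin")
    with system_row:
        with gr.Column():
            system_text = gr.Textbox(label="System Info")
            system_btn = gr.Button(value="Get System Info")

    def check_admin_pass(x):
        # Reveal the hidden row only when the entered password matches ADMIN_PASS.
        return gr.update(visible=x == admin_pass)

    admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)

if __name__ == "__main__":
    demo.launch()

Returning gr.update(visible=...) from the click handler toggles visibility of the whole row in place, so the admin-only controls never need to be rebuilt.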
finetune.py CHANGED

@@ -73,6 +73,7 @@ prompt_type_to_model_name = {
     'decapoda-research/llama-7b-hf',
     'decapoda-research/llama-13b-hf',
     'decapoda-research/llama-30b-hf',
+    'decapoda-research/llama-65b-hf',
     'facebook/mbart-large-50-many-to-many-mmt',
     'philschmid/bart-large-cnn-samsum',
     'philschmid/flan-t5-base-samsum',
@@ -120,7 +121,10 @@ def train(
         save_code: bool = False,
         run_id: int = None,
 
-        base_model: str = '
+        base_model: str = 'h2oai/h2ogpt-oig-oasst1-256-6.9b',
+        # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
+        # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
+        # base_model: str = 'EleutherAI/gpt-neox-20b',
         # base_model: str = 'EleutherAI/pythia-12b-deduped',
         # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
         # base_model: str = 'decapoda-research/llama-7b-hf',
utils.py CHANGED

@@ -1,7 +1,13 @@
+import contextlib
 import os
 import gc
 import random
+import shutil
 import time
+import traceback
+import zipfile
+
+import filelock
 import numpy as np
 import pandas as pd
 import torch
@@ -87,3 +93,59 @@ def system_info_print():
         return df.to_markdown()
     except Exception as e:
         return "Error: %s" % str(e)
+
+
+def zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
+    try:
+        return _zip_data(zip_path=zip_path, base_dir=base_dir, root_dirs=root_dirs)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in zipping: %s' % str(e))
+
+
+def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
+    assert root_dirs is not None
+    with zipfile.ZipFile(zip_path, "w") as expt_zip:
+        for root_dir in root_dirs:
+            if root_dir is None:
+                continue
+            for root, d, files in os.walk(root_dir):
+                for file in files:
+                    file_to_archive = os.path.join(root, file)
+                    assert os.path.exists(file_to_archive)
+                    path_to_archive = os.path.relpath(file_to_archive, base_dir)
+                    expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
+    return "data.zip"
+
+
+def save_generate_output(output=None, base_model=None, save_dir=None):
+    try:
+        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in saving: %s' % str(e))
+
+
+def _save_generate_output(output=None, base_model=None, save_dir=None):
+    """
+    Save conversation to .json, row by row.
+    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
+    Appends if file exists
+    """
+    assert save_dir, "save_dir must be provided"
+    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
+        raise RuntimeError("save_dir already exists and is not a directory!")
+    os.makedirs(save_dir, exist_ok=True)
+    import json
+    if output[-10:] == '\n\n<human>:':
+        # remove trailing <human>:
+        output = output[:-10]
+    with filelock.FileLock("save_dir.lock"):
+        # lock logging in case have concurrency
+        with open(os.path.join(save_dir, "history.json"), "a") as f:
+            # just add [ at start, and ] at end, and have proper JSON dataset
+            f.write(
+                " " + json.dumps(
+                    dict(text=output, time=time.ctime(), base_model=base_model)
+                ) + ",\n"
+            )
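For reference, a short usage sketch of the two helpers utils.py gains above; the paths and the example output string are placeholders, and the bracket-wrapping trick for reading history.json follows the comment inside _save_generate_output:

# Usage sketch for the new utils.py helpers (illustrative paths and strings only).
import json

from utils import save_generate_output, zip_data

# Append one generation record; creates save_dir and history.json as needed.
save_generate_output(output="<human>: hi\n<bot>: hello",
                     base_model="h2oai/h2ogpt-oasst1-512-12b",
                     save_dir="generate_logs")

# Bundle flagged data and generation logs into data.zip (what the new Zip button wires up).
zip_data(root_dirs=["flagged_data_points", "generate_logs"], zip_path="data.zip")

# history.json stores one comma-terminated JSON object per record; wrap in [ ] to load:
with open("generate_logs/history.json") as f:
    records = json.loads("[" + f.read().rstrip().rstrip(",") + "]")
print(len(records), "records")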