pseudotensor committed
Commit 9aa08b9 · 1 Parent(s): 9e9d047

Update with h2oGPT hash d5a4556404029122394e3b1c0a4ea97d8c996bb6

Files changed (3):
  1. generate.py +187 -73
  2. gradio_runner.py +107 -84
  3. utils.py +34 -0
generate.py CHANGED
@@ -1,14 +1,15 @@
1
  import functools
 
2
  import sys
3
  import os
 
4
  import traceback
5
  import typing
6
- from threading import Thread
7
  from datetime import datetime
8
  import filelock
9
  import psutil
10
 
11
- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
12
 
13
  SEED = 1236
14
  set_seed(SEED)
@@ -35,11 +36,11 @@ eval_extra_columns = ['prompt', 'response', 'score']
35
  def main(
36
  load_8bit: bool = False,
37
  load_half: bool = True,
38
- infer_devices: bool = True, # really if to "control" devices now
39
  base_model: str = '',
40
  tokenizer_base_model: str = '',
41
  lora_weights: str = "",
42
- gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
43
 
44
  prompt_type: Union[int, str] = None,
45
  # input to generation
@@ -60,7 +61,7 @@ def main(
60
  share: bool = True,
61
  local_files_only: bool = False,
62
  resume_download: bool = True,
63
- use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
64
 
65
  src_lang: str = "English",
66
  tgt_lang: str = "Russian",
@@ -68,20 +69,18 @@ def main(
68
  gradio: bool = True,
69
  gradio_avoid_processing_markdown: bool = False,
70
  chat: bool = True,
71
- chat_history: int = 4096, # character length of chat context/history
72
- chat_context: bool = False, # use default context if human_bot
73
  stream_output: bool = True,
74
  show_examples: bool = None,
75
  verbose: bool = False,
76
  h2ocolors: bool = True,
77
  height: int = 400,
78
  show_lora: bool = True,
79
- # set to True to load --base_model after client logs in,
80
- # to be able to free GPU memory when model is swapped
81
  login_mode_if_model0: bool = False,
82
  block_gradio_exit: bool = True,
83
  concurrency_count: int = 1,
84
- api_open: bool = False, # don't let API skip queue
85
  allow_api: bool = True,
86
  input_lines: int = 1,
87
 
@@ -97,9 +96,64 @@ def main(
97
  eval_sharegpt_prompts_only: int = 0,
98
  eval_sharegpt_prompts_only_seed: int = 1234,
99
  eval_sharegpt_as_output: bool = False,
100
-
101
- hard_stop_list: typing.List[str] = [],
102
  ):
103
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
104
  is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
105
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
@@ -107,7 +161,7 @@ def main(
107
  admin_pass = os.getenv("ADMIN_PASS")
108
  # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
109
  # but becomes unrecoverable sometimes if raise, so just be silent for now
110
- raise_generate_gpu_exceptions = not is_public
111
 
112
  # allow set token directly
113
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
@@ -223,9 +277,10 @@ def main(
223
  eval_filename = os.path.join(scoring_path, eval_filename)
224
 
225
  # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
226
- context_class = NullContext() if n_gpus > 1 or n_gpus == 0 else torch.device("cuda")
 
227
 
228
- with context_class:
229
  # ensure was set right above before examples generated
230
  assert not stream_output, "stream_output=True does not make sense with example loop"
231
  import time
@@ -240,7 +295,8 @@ def main(
240
  fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
241
  raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
242
  chat_context=chat_context,
243
- concurrency_count=concurrency_count)
 
244
  else:
245
  assert eval_sharegpt_prompts_only > 0
246
 
@@ -288,7 +344,7 @@ def main(
288
  truncation=True,
289
  max_length=cutoff_len)
290
  try:
291
- score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
292
  except torch.cuda.OutOfMemoryError as e:
293
  print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
294
  traceback.print_exc()
@@ -649,12 +705,12 @@ def evaluate(
649
  debug=False,
650
  concurrency_count=None,
651
  save_dir=None,
652
- hard_stop_list=None,
653
  sanitize_bot_response=True,
654
  model_state0=None,
655
  is_low_mem=None,
656
  raise_generate_gpu_exceptions=None,
657
  chat_context=None,
 
658
  ):
659
  # ensure passed these
660
  assert concurrency_count is not None
@@ -710,10 +766,6 @@ def evaluate(
710
  prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
711
  prompt = prompter.generate_prompt(data_point)
712
 
713
- if hard_stop_list is None:
714
- # acts like undo on user entry and bot response
715
- hard_stop_list = []
716
-
717
  if isinstance(tokenizer, str):
718
  # pipeline
719
  if tokenizer == "summarization":
@@ -829,55 +881,115 @@ def evaluate(
829
  )
830
 
831
  with torch.no_grad():
832
- # protection for gradio not keeping track of closed users,
833
- # else hit bitsandbytes lack of thread safety:
834
- # https://github.com/h2oai/h2ogpt/issues/104
835
- # but only makes sense if concurrency_count == 1
836
- context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
837
- print('Pre-Generate: %s' % str(datetime.now()), flush=True)
838
- decoded_output = None
839
- with context_class("generate.lock"):
840
- print('Generate: %s' % str(datetime.now()), flush=True)
841
- # decoded tokenized prompt can deviate from prompt due to special characters
842
- inputs_decoded = decoder(input_ids[0])
843
- inputs_decoded_raw = decoder_raw(input_ids[0])
844
- if inputs_decoded == prompt:
845
- # normal
846
- pass
847
- elif inputs_decoded.lstrip() == prompt.lstrip():
848
- # sometimes extra space in front, make prompt same for prompt removal
849
- prompt = inputs_decoded
850
- elif inputs_decoded_raw == prompt:
851
- # some models specify special tokens that are part of normal prompt, so can't skip them
852
- inputs_decoded_raw = inputs_decoded
853
- decoder = decoder_raw
854
- else:
855
- print("WARNING: Special characters in prompt", flush=True)
856
- if stream_output:
857
- skip_prompt = False
858
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
859
- gen_kwargs.update(dict(streamer=streamer))
860
- target_func = generate_with_exceptions
861
- target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
862
- raise_generate_gpu_exceptions, **gen_kwargs)
863
- thread = Thread(target=target)
864
- thread.start()
865
- outputs = ""
866
- for new_text in streamer:
867
- outputs += new_text
868
  yield prompter.get_response(outputs, prompt=inputs_decoded,
869
  sanitize_bot_response=sanitize_bot_response)
870
- decoded_output = outputs
871
- else:
872
- outputs = model.generate(**gen_kwargs)
873
- outputs = [decoder(s) for s in outputs.sequences]
874
- yield prompter.get_response(outputs, prompt=inputs_decoded,
875
- sanitize_bot_response=sanitize_bot_response)
876
- if outputs and len(outputs) >= 1:
877
- decoded_output = prompt + outputs[0]
878
- if save_dir and decoded_output:
879
- save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
880
- print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
881
 
882
 
883
  def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
@@ -908,7 +1020,8 @@ def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_ex
908
  return
909
  else:
910
  clear_torch_cache()
911
- raise
 
912
 
913
 
914
  def get_generate_params(model_lower, chat,
@@ -1154,7 +1267,9 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
1154
 
1155
 
1156
  if __name__ == "__main__":
1157
- print("""
 
 
1158
  WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1159
  python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1160
  python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
@@ -1180,6 +1295,5 @@ if __name__ == "__main__":
1180
  python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1181
 
1182
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1183
-
1184
- """, flush=True)
1185
  fire.Fire(main)
 
1
  import functools
2
+ import queue
3
  import sys
4
  import os
5
+ import time
6
  import traceback
7
  import typing
 
8
  from datetime import datetime
9
  import filelock
10
  import psutil
11
 
12
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread
13
 
14
  SEED = 1236
15
  set_seed(SEED)
 
36
  def main(
37
  load_8bit: bool = False,
38
  load_half: bool = True,
39
+ infer_devices: bool = True,
40
  base_model: str = '',
41
  tokenizer_base_model: str = '',
42
  lora_weights: str = "",
43
+ gpu_id: int = 0,
44
 
45
  prompt_type: Union[int, str] = None,
46
  # input to generation
 
61
  share: bool = True,
62
  local_files_only: bool = False,
63
  resume_download: bool = True,
64
+ use_auth_token: Union[str, bool] = False,
65
 
66
  src_lang: str = "English",
67
  tgt_lang: str = "Russian",
 
69
  gradio: bool = True,
70
  gradio_avoid_processing_markdown: bool = False,
71
  chat: bool = True,
72
+ chat_history: int = 4096,
73
+ chat_context: bool = False,
74
  stream_output: bool = True,
75
  show_examples: bool = None,
76
  verbose: bool = False,
77
  h2ocolors: bool = True,
78
  height: int = 400,
79
  show_lora: bool = True,
 
 
80
  login_mode_if_model0: bool = False,
81
  block_gradio_exit: bool = True,
82
  concurrency_count: int = 1,
83
+ api_open: bool = False,
84
  allow_api: bool = True,
85
  input_lines: int = 1,
86
 
 
96
  eval_sharegpt_prompts_only: int = 0,
97
  eval_sharegpt_prompts_only_seed: int = 1234,
98
  eval_sharegpt_as_output: bool = False,
 
 
99
  ):
100
+ """
101
+
102
+ :param load_8bit: load model in 8-bit using bitsandbytes
103
+ :param load_half: load model in float16
104
+ :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
105
+ :param base_model: model HF-type name
106
+ :param tokenizer_base_model: tokenizer HF-type name
107
+ :param lora_weights: LORA weights path/HF link
108
+ :param gpu_id: if infer_devices, then use gpu_id as the CUDA device ID (auto mode if gpu_id == -1)
109
+ :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
110
+ :param temperature: generation temperature
111
+ :param top_p: generation top_p
112
+ :param top_k: generation top_k
113
+ :param num_beams: generation number of beams
114
+ :param repetition_penalty: generation repetition penalty
115
+ :param num_return_sequences: generation number of sequences (1 forced for chat)
116
+ :param do_sample: generation sample
117
+ :param max_new_tokens: generation max new tokens
118
+ :param min_new_tokens: generation min tokens
119
+ :param early_stopping: generation early stopping
120
+ :param max_time: maximum time to allow for generation
121
+ :param debug: enable debug mode
122
+ :param save_dir: directory chat data is saved to
123
+ :param share: whether to share the gradio app with sharable URL
124
+ :param local_files_only: whether to only use local files instead of going to HF for models
125
+ :param resume_download: whether to resume downloads from HF for models
126
+ :param use_auth_token: whether to use HF auth token (requires having run huggingface-cli login beforehand)
127
+ :param src_lang: source languages to include if doing translation (None = all)
128
+ :param tgt_lang: target languages to include if doing translation (None = all)
129
+ :param gradio: whether to enable gradio; if False, run in benchmark mode instead
130
+ :param gradio_avoid_processing_markdown:
131
+ :param chat: whether to enable chat mode with chat history
132
+ :param chat_history: maximum character length of chat context/history
133
+ :param chat_context: whether to use extra helpful context if human_bot
134
+ :param stream_output: whether to stream output from generate
135
+ :param show_examples: whether to show clickable examples in gradio
136
+ :param verbose: whether to show verbose prints
137
+ :param h2ocolors: whether to use H2O.ai theme
138
+ :param height: height of chat window
139
+ :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
140
+ :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
141
+ :param block_gradio_exit: whether to block gradio exit (used for testing)
142
+ :param concurrency_count: gradio concurrency count (1 is optimal for LLMs)
143
+ :param api_open: If False, don't let API calls skip gradio queue
144
+ :param allow_api: whether to allow API calls at all to gradio server
145
+ :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
146
+ :param sanitize_user_prompt: whether to remove profanity from user input
147
+ :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
148
+ :param extra_model_options: extra models to show in list in gradio
149
+ :param extra_lora_options: extra LORAs to show in list in gradio
150
+ :param score_model: which model to score responses (None means no scoring)
151
+ :param auto_score: whether to automatically score responses
152
+ :param eval_sharegpt_prompts_only: for no-gradio benchmark, number of ShareGPT prompts to use for eval
153
+ :param eval_sharegpt_prompts_only_seed: for no-gradio benchmark, seed for sampling ShareGPT prompts
154
+ :param eval_sharegpt_as_output: for no-gradio benchmark, whether to test the ShareGPT output itself
155
+ :return:
156
+ """
157
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
158
  is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
159
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
 
161
  admin_pass = os.getenv("ADMIN_PASS")
162
  # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
163
  # but becomes unrecoverable sometimes if raise, so just be silent for now
164
+ raise_generate_gpu_exceptions = True
165
 
166
  # allow set token directly
167
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
 
277
  eval_filename = os.path.join(scoring_path, eval_filename)
278
 
279
  # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
280
+ device = 'cpu' if n_gpus == 0 else 'cuda'
281
+ context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
282
 
283
+ with context_class(device):
284
  # ensure was set right above before examples generated
285
  assert not stream_output, "stream_output=True does not make sense with example loop"
286
  import time
 
295
  fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
296
  raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
297
  chat_context=chat_context,
298
+ concurrency_count=concurrency_count,
299
+ lora_weights=lora_weights)
300
  else:
301
  assert eval_sharegpt_prompts_only > 0
302
 
 
344
  truncation=True,
345
  max_length=cutoff_len)
346
  try:
347
+ score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
348
  except torch.cuda.OutOfMemoryError as e:
349
  print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
350
  traceback.print_exc()
 
705
  debug=False,
706
  concurrency_count=None,
707
  save_dir=None,
 
708
  sanitize_bot_response=True,
709
  model_state0=None,
710
  is_low_mem=None,
711
  raise_generate_gpu_exceptions=None,
712
  chat_context=None,
713
+ lora_weights=None,
714
  ):
715
  # ensure passed these
716
  assert concurrency_count is not None
 
766
  prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
767
  prompt = prompter.generate_prompt(data_point)
768
 
 
 
 
 
769
  if isinstance(tokenizer, str):
770
  # pipeline
771
  if tokenizer == "summarization":
 
881
  )
882
 
883
  with torch.no_grad():
884
+ context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast
885
+ with context_class_cast(device):
886
+ # protection for gradio not keeping track of closed users,
887
+ # else hit bitsandbytes lack of thread safety:
888
+ # https://github.com/h2oai/h2ogpt/issues/104
889
+ # but only makes sense if concurrency_count == 1
890
+ context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
891
+ print('Pre-Generate: %s' % str(datetime.now()), flush=True)
892
+ decoded_output = None
893
+ with context_class("generate.lock"):
894
+ print('Generate: %s' % str(datetime.now()), flush=True)
895
+ # decoded tokenized prompt can deviate from prompt due to special characters
896
+ inputs_decoded = decoder(input_ids[0])
897
+ inputs_decoded_raw = decoder_raw(input_ids[0])
898
+ if inputs_decoded == prompt:
899
+ # normal
900
+ pass
901
+ elif inputs_decoded.lstrip() == prompt.lstrip():
902
+ # sometimes extra space in front, make prompt same for prompt removal
903
+ prompt = inputs_decoded
904
+ elif inputs_decoded_raw == prompt:
905
+ # some models specify special tokens that are part of normal prompt, so can't skip them
906
+ inputs_decoded_raw = inputs_decoded
907
+ decoder = decoder_raw
908
+ else:
909
+ print("WARNING: Special characters in prompt", flush=True)
910
+ if stream_output:
911
+ skip_prompt = False
912
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False)
913
+ gen_kwargs.update(dict(streamer=streamer))
914
+ target_func = generate_with_exceptions
915
+ target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
916
+ raise_generate_gpu_exceptions, **gen_kwargs)
917
+ bucket = queue.Queue()
918
+ thread = EThread(target=target, kwargs=dict(streamer=streamer), bucket=bucket)
919
+ thread.start()
920
+ outputs = ""
921
+ try:
922
+ for new_text in streamer:
923
+ if bucket.qsize() > 0 or thread.exc:
924
+ thread.join()
925
+ outputs += new_text
926
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
927
+ sanitize_bot_response=sanitize_bot_response)
928
+ except BaseException:
929
+ # if any exception, raise that exception if was from thread, first
930
+ if thread.exc:
931
+ raise thread.exc
932
+ raise
933
+ finally:
934
+ # in case no exception and didn't join with thread yet, then join
935
+ if not thread.exc:
936
+ thread.join()
937
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
938
+ if thread.exc:
939
+ raise thread.exc
940
+ decoded_output = outputs
941
+ else:
942
+ outputs = model.generate(**gen_kwargs)
943
+ outputs = [decoder(s) for s in outputs.sequences]
944
  yield prompter.get_response(outputs, prompt=inputs_decoded,
945
  sanitize_bot_response=sanitize_bot_response)
946
+ if outputs and len(outputs) >= 1:
947
+ decoded_output = prompt + outputs[0]
948
+ if save_dir and decoded_output:
949
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
950
+ print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
951
+
952
+
953
+ class H2OTextIteratorStreamer(TextIteratorStreamer):
954
+ """
955
+ normally, a timeout is required to handle exceptions, else get() blocks indefinitely;
956
+ with this H2O version of TextIteratorStreamer, loop over a non-blocking get() to handle stop and exceptions
957
+ """
958
+ def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
959
+ block=True, **decode_kwargs):
960
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
961
+ self.text_queue = queue.Queue()
962
+ self.stop_signal = None
963
+ self.do_stop = False
964
+ self.timeout = timeout
965
+ self.block = block
966
+
967
+ def on_finalized_text(self, text: str, stream_end: bool = False):
968
+ """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
969
+ self.text_queue.put(text, timeout=self.timeout)
970
+ if stream_end:
971
+ self.text_queue.put(self.stop_signal, timeout=self.timeout)
972
+
973
+ def __iter__(self):
974
+ return self
975
+
976
+ def __next__(self):
977
+ while True:
978
+ try:
979
+ value = self.stop_signal # value looks unused in pycharm, not true
980
+ if self.do_stop:
981
+ print("hit stop", flush=True)
982
+ # could raise or break, maybe best to raise and make parent see if any exception in thread
983
+ raise StopIteration()
984
+ #break
985
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
986
+ break
987
+ except queue.Empty:
988
+ time.sleep(0.01)
989
+ if value == self.stop_signal:
990
+ raise StopIteration()
991
+ else:
992
+ return value
993
 
994
 
995
  def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
 
1020
  return
1021
  else:
1022
  clear_torch_cache()
1023
+ if raise_generate_gpu_exceptions:
1024
+ raise
1025
 
1026
 
1027
  def get_generate_params(model_lower, chat,
 
1267
 
1268
 
1269
  if __name__ == "__main__":
1270
+ """
1271
+ Examples:
1272
+
1273
  WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1274
  python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1275
  python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
 
1295
  python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1296
 
1297
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1298
+ """
 
1299
  fire.Fire(main)
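
To make the change above easier to follow: the commit replaces the plain TextIteratorStreamer/Thread pair with H2OTextIteratorStreamer (non-blocking, stoppable) plus EThread (exception-carrying). Below is a minimal, self-contained sketch of that streaming idea; PollingStreamer and the token feed are illustrative stand-ins, not part of the commit.

import queue
import time


class PollingStreamer:
    # illustrative stand-in for the H2OTextIteratorStreamer added above
    def __init__(self, timeout=None, block=True):
        self.text_queue = queue.Queue()
        self.stop_signal = None   # sentinel marking normal end of stream
        self.do_stop = False      # set by the worker thread on failure
        self.timeout = timeout
        self.block = block

    def on_finalized_text(self, text, stream_end=False):
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            try:
                if self.do_stop:
                    raise StopIteration()  # worker asked us to bail out
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                time.sleep(0.01)           # poll instead of blocking forever
        if value == self.stop_signal:
            raise StopIteration()
        return value


streamer = PollingStreamer(block=False)
for tok in ["Hello", ", ", "world"]:
    streamer.on_finalized_text(tok)
streamer.on_finalized_text("", stream_end=True)
print("".join(piece for piece in streamer))  # -> Hello, world

The key design point is the polling __next__: a worker that dies can set do_stop so the consumer loop exits, instead of hanging forever inside a blocking queue.get().
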
gradio_runner.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import functools
2
  import inspect
3
  import os
@@ -246,7 +247,11 @@ def go_gradio(**kwargs):
246
  value=kwargs['top_k'], label="Top k",
247
  info='Num. tokens to sample from'
248
  )
249
- max_beams = 8 if not is_low_mem else 1
250
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
251
  value=min(max_beams, kwargs['num_beams']), label="Beams",
252
  info="Number of searches for optimal overall probability. "
@@ -262,7 +267,9 @@ def go_gradio(**kwargs):
262
  )
263
  early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
264
  value=kwargs['early_stopping'])
265
- max_max_time = 60 * 5 if not is_low_mem else 60
 
 
266
  max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
267
  value=min(max_max_time, kwargs['max_time']), label="Max. time",
268
  info="Max. time to search optimal output.")
@@ -309,9 +316,10 @@ def go_gradio(**kwargs):
309
  model_gpu = gr.Dropdown(n_gpus_list,
310
  label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
311
  value=kwargs['gpu_id'])
312
- model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
 
313
  lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
314
- visible=kwargs['show_lora'])
315
  with gr.Row():
316
  with gr.Column(scale=50):
317
  new_model = gr.Textbox(label="New Model HF name/path")
@@ -354,15 +362,15 @@ def go_gradio(**kwargs):
354
  with gr.Column():
355
  with gr.Row():
356
  system_btn = gr.Button(value='Get System Info')
357
- system_text = gr.Textbox(label='System Info')
358
 
359
  with gr.Row():
360
  zip_btn = gr.Button("Zip")
361
- zip_text = gr.Textbox(label="Zip file name")
362
  file_output = gr.File()
363
  with gr.Row():
364
  s3up_btn = gr.Button("S3UP")
365
- s3up_text = gr.Textbox(label='S3UP result')
366
 
367
  # Get flagged data
368
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
@@ -395,12 +403,15 @@ def go_gradio(**kwargs):
395
  dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
396
  size="sm",
397
  )
 
 
398
  dark_mode_btn.click(
399
  None,
400
  None,
401
  None,
402
  _js=get_dark_js(),
403
  api_name="dark" if allow_api else None,
 
404
  )
405
 
406
  # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
@@ -415,7 +426,8 @@ def go_gradio(**kwargs):
415
 
416
  chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
417
  .then(col_chat_fun, chat, col_chat) \
418
- .then(context_fun, chat, context)
 
419
 
420
  # examples after submit or any other buttons for chat or no chat
421
  if kwargs['examples'] is not None and kwargs['show_examples']:
@@ -514,6 +526,10 @@ def go_gradio(**kwargs):
514
  if sanitize_user_prompt:
515
  from better_profanity import profanity
516
  user_message1 = profanity.censor(user_message1)
517
 
518
  history = args_list[-1]
519
  if undo and history:
@@ -541,15 +557,17 @@ def go_gradio(**kwargs):
541
  :param retry:
542
  :return:
543
  """
544
- args_list = list(args).copy()
545
  history = args_list[-1] # model_state is -2
546
  if retry and history:
547
  history.pop()
548
  if not history:
549
  print("No history", flush=True)
 
 
550
  return
551
  # ensure output will be unique to models
552
- history = history.copy()
553
  instruction1 = history[-1][0]
554
  context1 = ''
555
  if kwargs['chat_history'] > 0:
@@ -571,6 +589,8 @@ def go_gradio(**kwargs):
571
  args_list[2] = context1[-kwargs['chat_history']:]
572
  model_state1 = args_list[-2]
573
  if model_state1[0] is None or model_state1[0] == no_model_str:
 
 
574
  return
575
  args_list = args_list[:-2]
576
  fun1 = partial(evaluate,
@@ -580,19 +600,25 @@ def go_gradio(**kwargs):
580
  for output in fun1(*tuple(args_list)):
581
  bot_message = output
582
  history[-1][1] = bot_message
583
- yield history
584
  except StopIteration:
585
- yield history
586
  except RuntimeError as e:
587
  if "generator raised StopIteration" in str(e):
588
  # assume last entry was bad, undo
589
  history.pop()
590
- yield history
591
- raise
592
  except Exception as e:
593
  # put error into user input
594
- history[-1][0] = "Exception: %s" % str(e)
595
- yield history
 
 
596
  raise
597
  return
598
 
@@ -603,11 +629,11 @@ def go_gradio(**kwargs):
603
  )
604
  bot_args = dict(fn=bot,
605
  inputs=inputs_list + [model_state] + [text_output],
606
- outputs=text_output,
607
  )
608
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
609
  inputs=inputs_list + [model_state] + [text_output],
610
- outputs=text_output,
611
  )
612
  undo_user_args = dict(fn=functools.partial(user, undo=True),
613
  inputs=inputs_list + [text_output],
@@ -621,11 +647,11 @@ def go_gradio(**kwargs):
621
  )
622
  bot_args2 = dict(fn=bot,
623
  inputs=inputs_list + [model_state2] + [text_output2],
624
- outputs=text_output2,
625
  )
626
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
627
  inputs=inputs_list + [model_state2] + [text_output2],
628
- outputs=text_output2,
629
  )
630
  undo_user_args2 = dict(fn=functools.partial(user, undo=True),
631
  inputs=inputs_list + [text_output2],
@@ -636,67 +662,61 @@ def go_gradio(**kwargs):
636
  return gr.Textbox.update(value='')
637
 
638
  if kwargs['auto_score']:
639
- # in case 2nd model, consume instruction first, so can clear quickly
640
- # bot doesn't consume instruction itself, just history from user, so why works
641
- submit_event = instruction.submit(**user_args, queue=queue,
642
- api_name='instruction' if allow_api else None) \
643
- .then(**user_args2, api_name='instruction2' if allow_api else None) \
644
- .then(clear_instruct, None, instruction) \
645
- .then(clear_instruct, None, iinput) \
646
- .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
647
- .then(**score_args, api_name='instruction_bot_score' if allow_api else None, queue=queue) \
648
- .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
649
- .then(**score_args2, api_name='instruction_bot_score2' if allow_api else None, queue=queue) \
650
- .then(clear_torch_cache)
651
- submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
652
- .then(**user_args2, api_name='submit2' if allow_api else None) \
653
- .then(clear_instruct, None, instruction) \
654
- .then(clear_instruct, None, iinput) \
655
- .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
656
- .then(**score_args, api_name='submit_bot_score' if allow_api else None, queue=queue) \
657
- .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
658
- .then(**score_args2, api_name='submit_bot_score2' if allow_api else None, queue=queue) \
659
- .then(clear_torch_cache)
660
- submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
661
- .then(**user_args2, api_name='retry2' if allow_api else None) \
662
- .then(clear_instruct, None, instruction) \
663
- .then(clear_instruct, None, iinput) \
664
- .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
665
- .then(**score_args, api_name='retry_bot_score' if allow_api else None, queue=queue) \
666
- .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
667
- .then(**score_args2, api_name='retry_bot_score2' if allow_api else None, queue=queue) \
668
- .then(clear_torch_cache)
669
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
670
- .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
671
- .then(clear_instruct, None, instruction) \
672
- .then(clear_instruct, None, iinput) \
673
- .then(**score_args, api_name='undo_score' if allow_api else None) \
674
- .then(**score_args2, api_name='undo_score2' if allow_api else None)
675
  else:
676
- submit_event = instruction.submit(**user_args,
677
- api_name='instruction' if allow_api else None) \
678
- .then(**user_args2, api_name='instruction2' if allow_api else None) \
679
- .then(clear_instruct, None, instruction) \
680
- .then(clear_instruct, None, iinput) \
681
- .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
682
- .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
683
- .then(clear_torch_cache)
684
- submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
685
- .then(**user_args2, api_name='submit2' if allow_api else None) \
686
- .then(clear_instruct, None, instruction) \
687
- .then(clear_instruct, None, iinput) \
688
- .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
689
- .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
690
- .then(clear_torch_cache)
691
- submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
692
- .then(**user_args2, api_name='retry2' if allow_api else None) \
693
- .then(clear_instruct, None, instruction) \
694
- .then(clear_instruct, None, iinput) \
695
- .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
696
- .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
697
- .then(clear_torch_cache)
698
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
699
- .then(**undo_user_args2, api_name='undo2' if allow_api else None)
700
 
701
  # does both models
702
  clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
@@ -864,9 +884,12 @@ def go_gradio(**kwargs):
864
  api_name='system_info' if allow_api else None, queue=False)
865
 
866
  # don't pass text_output, don't want to clear output, just stop it
867
- # FIXME: have to click once to stop output and second time to stop GPUs going
868
  stop_btn.click(lambda: None, None, None,
869
- cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
 
 
 
870
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
871
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
872
 
@@ -887,8 +910,8 @@ def go_gradio(**kwargs):
887
 
888
 
889
  input_args_list = ['model_state']
890
- inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
891
- 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count']
892
 
893
 
894
  def get_inputs_list(inputs_dict, model_lower):
 
1
+ import copy
2
  import functools
3
  import inspect
4
  import os
 
247
  value=kwargs['top_k'], label="Top k",
248
  info='Num. tokens to sample from'
249
  )
250
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/106
251
+ if os.getenv('TESTINGFAIL'):
252
+ max_beams = 8 if not (is_low_mem or is_public) else 1
253
+ else:
254
+ max_beams = 1
255
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
256
  value=min(max_beams, kwargs['num_beams']), label="Beams",
257
  info="Number of searches for optimal overall probability. "
 
267
  )
268
  early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
269
  value=kwargs['early_stopping'])
270
+ max_max_time = 60 * 5 if not is_public else 60 * 2
271
+ if is_hf:
272
+ max_max_time = min(max_max_time, 60 * 1)
273
  max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
274
  value=min(max_max_time, kwargs['max_time']), label="Max. time",
275
  info="Max. time to search optimal output.")
 
316
  model_gpu = gr.Dropdown(n_gpus_list,
317
  label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
318
  value=kwargs['gpu_id'])
319
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
320
+ interactive=False)
321
  lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
322
+ visible=kwargs['show_lora'], interactive=False)
323
  with gr.Row():
324
  with gr.Column(scale=50):
325
  new_model = gr.Textbox(label="New Model HF name/path")
 
362
  with gr.Column():
363
  with gr.Row():
364
  system_btn = gr.Button(value='Get System Info')
365
+ system_text = gr.Textbox(label='System Info', interactive=False)
366
 
367
  with gr.Row():
368
  zip_btn = gr.Button("Zip")
369
+ zip_text = gr.Textbox(label="Zip file name", interactive=False)
370
  file_output = gr.File()
371
  with gr.Row():
372
  s3up_btn = gr.Button("S3UP")
373
+ s3up_text = gr.Textbox(label='S3UP result', interactive=False)
374
 
375
  # Get flagged data
376
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
 
403
  dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
404
  size="sm",
405
  )
406
+ # FIXME: Could add exceptions for non-chat but still streaming
407
+ exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False)
408
  dark_mode_btn.click(
409
  None,
410
  None,
411
  None,
412
  _js=get_dark_js(),
413
  api_name="dark" if allow_api else None,
414
+ queue=False,
415
  )
416
 
417
  # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
 
426
 
427
  chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
428
  .then(col_chat_fun, chat, col_chat) \
429
+ .then(context_fun, chat, context) \
430
+ .then(col_chat_fun, chat, exception_text)
431
 
432
  # examples after submit or any other buttons for chat or no chat
433
  if kwargs['examples'] is not None and kwargs['show_examples']:
 
526
  if sanitize_user_prompt:
527
  from better_profanity import profanity
528
  user_message1 = profanity.censor(user_message1)
529
+ if user_message1 in ['']:
530
+ # e.g. when user just hits enter in textbox,
531
+ # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
532
+ user_message1 = '\n'
533
 
534
  history = args_list[-1]
535
  if undo and history:
 
557
  :param retry:
558
  :return:
559
  """
560
+ args_list = copy.deepcopy(list(args))
561
  history = args_list[-1] # model_state is -2
562
  if retry and history:
563
  history.pop()
564
  if not history:
565
  print("No history", flush=True)
566
+ history = [['', None]]
567
+ yield history, ''
568
  return
569
  # ensure output will be unique to models
570
+ history = copy.deepcopy(history)
571
  instruction1 = history[-1][0]
572
  context1 = ''
573
  if kwargs['chat_history'] > 0:
 
589
  args_list[2] = context1[-kwargs['chat_history']:]
590
  model_state1 = args_list[-2]
591
  if model_state1[0] is None or model_state1[0] == no_model_str:
592
+ history = [['', None]]
593
+ yield history, ''
594
  return
595
  args_list = args_list[:-2]
596
  fun1 = partial(evaluate,
 
600
  for output in fun1(*tuple(args_list)):
601
  bot_message = output
602
  history[-1][1] = bot_message
603
+ yield history, ''
604
  except StopIteration:
605
+ yield history, ''
606
  except RuntimeError as e:
607
  if "generator raised StopIteration" in str(e):
608
  # assume last entry was bad, undo
609
  history.pop()
610
+ yield history, ''
611
+ else:
612
+ if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
613
+ history[-1][1] = ''
614
+ yield history, str(e)
615
+ raise
616
  except Exception as e:
617
  # put error into user input
618
+ ex = "Exception: %s" % str(e)
619
+ if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
620
+ history[-1][1] = ''
621
+ yield history, ex
622
  raise
623
  return
624
 
 
629
  )
630
  bot_args = dict(fn=bot,
631
  inputs=inputs_list + [model_state] + [text_output],
632
+ outputs=[text_output, exception_text],
633
  )
634
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
635
  inputs=inputs_list + [model_state] + [text_output],
636
+ outputs=[text_output, exception_text],
637
  )
638
  undo_user_args = dict(fn=functools.partial(user, undo=True),
639
  inputs=inputs_list + [text_output],
 
647
  )
648
  bot_args2 = dict(fn=bot,
649
  inputs=inputs_list + [model_state2] + [text_output2],
650
+ outputs=[text_output2, exception_text],
651
  )
652
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
653
  inputs=inputs_list + [model_state2] + [text_output2],
654
+ outputs=[text_output2, exception_text],
655
  )
656
  undo_user_args2 = dict(fn=functools.partial(user, undo=True),
657
  inputs=inputs_list + [text_output2],
 
662
  return gr.Textbox.update(value='')
663
 
664
  if kwargs['auto_score']:
665
+ score_args_submit = score_args
666
+ score_args2_submit = score_args2
667
  else:
668
+ score_args_submit = dict(fn=lambda: None, inputs=None, outputs=None)
669
+ score_args2_submit = dict(fn=lambda: None, inputs=None, outputs=None)
670
+
671
+ # in case 2nd model, consume instruction first, so can clear quickly
672
+ # bot doesn't consume instruction itself, just history from user, so why works
673
+ submit_event1a = instruction.submit(**user_args, queue=queue,
674
+ api_name='instruction' if allow_api else None)
675
+ submit_event1b = submit_event1a.then(**user_args2, api_name='instruction2' if allow_api else None)
676
+ submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \
677
+ .then(clear_instruct, None, iinput)
678
+ submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
679
+ queue=queue)
680
+ submit_event1e = submit_event1d.then(**score_args_submit, api_name='instruction_bot_score' if allow_api else None,
681
+ queue=queue)
682
+ submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
683
+ queue=queue)
684
+ submit_event1g = submit_event1f.then(**score_args2_submit,
685
+ api_name='instruction_bot_score2' if allow_api else None, queue=queue)
686
+ submit_event1h = submit_event1g.then(clear_torch_cache)
687
+
688
+ submit_event2a = submit.click(**user_args, api_name='submit' if allow_api else None)
689
+ submit_event2b = submit_event2a.then(**user_args2, api_name='submit2' if allow_api else None)
690
+ submit_event2c = submit_event2b.then(clear_instruct, None, instruction) \
691
+ .then(clear_instruct, None, iinput)
692
+ submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue)
693
+ submit_event2e = submit_event2d.then(**score_args_submit, api_name='submit_bot_score' if allow_api else None,
694
+ queue=queue)
695
+ submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue)
696
+ submit_event2g = submit_event2f.then(**score_args2_submit, api_name='submit_bot_score2' if allow_api else None,
697
+ queue=queue)
698
+ submit_event2h = submit_event2g.then(clear_torch_cache)
699
+
700
+ submit_event3a = retry.click(**user_args, api_name='retry' if allow_api else None)
701
+ submit_event3b = submit_event3a.then(**user_args2, api_name='retry2' if allow_api else None)
702
+ submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
703
+ .then(clear_instruct, None, iinput)
704
+ submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None,
705
+ queue=queue)
706
+ submit_event3e = submit_event3d.then(**score_args_submit, api_name='retry_bot_score' if allow_api else None,
707
+ queue=queue)
708
+ submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None,
709
+ queue=queue)
710
+ submit_event3g = submit_event3f.then(**score_args2_submit, api_name='retry_bot_score2' if allow_api else None,
711
+ queue=queue)
712
+ submit_event3h = submit_event3g.then(clear_torch_cache)
713
+
714
+ submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
715
+ .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
716
+ .then(clear_instruct, None, instruction) \
717
+ .then(clear_instruct, None, iinput) \
718
+ .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
719
+ .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
720
 
721
  # does both models
722
  clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
 
884
  api_name='system_info' if allow_api else None, queue=False)
885
 
886
  # don't pass text_output, don't want to clear output, just stop it
887
+ # cancel only stops outer generation, not inner generation or non-generation
888
  stop_btn.click(lambda: None, None, None,
889
+ cancels=[submit_event1d, submit_event1f,
890
+ submit_event2d, submit_event2f,
891
+ submit_event3d, submit_event3f,
892
+ submit_event_nochat],
893
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
894
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
895
 
 
910
 
911
 
912
  input_args_list = ['model_state']
913
+ inputs_kwargs_list = ['debug', 'save_dir', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
914
+ 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count', 'lora_weights']
915
 
916
 
917
  def get_inputs_list(inputs_dict, model_lower):
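
For context on the event rewiring above: splitting the long .then() chains into named events (submit_event1a through submit_event1h, etc.) is what lets stop_btn cancel only the bot-generation steps via cancels=. A minimal sketch of that pattern, assuming the gradio 3.x event API; the component names and the toy user/bot functions here are illustrative, not the app's real ones.

import time

import gradio as gr


def user(message, history):
    # consume the instruction and append a new turn
    return "", history + [[message, None]]


def bot(history):
    # generator: stream the reply into the last turn
    history[-1][1] = ""
    for ch in "streamed reply":
        history[-1][1] += ch
        time.sleep(0.05)
        yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    stop = gr.Button("Stop")
    # name each step of the chain so cancels= can target generation only
    ev_user = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    ev_bot = ev_user.then(bot, chatbot, chatbot)
    stop.click(lambda: None, None, None, cancels=[ev_bot], queue=False)

demo.queue().launch()

Cancelling only ev_bot stops the streaming step without clearing the chat output, which mirrors how stop_btn now cancels just the submit_event*d / submit_event*f events.
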
utils.py CHANGED
@@ -259,3 +259,37 @@ def wrapped_partial(func, *args, **kwargs):
259
  partial_func = functools.partial(func, *args, **kwargs)
260
  functools.update_wrapper(partial_func, func)
261
  return partial_func
262
+
263
+
264
+ class ThreadException(Exception):
265
+ pass
266
+
267
+
268
+ class EThread(threading.Thread):
269
+ # Thread that captures exceptions from its target so the caller can detect and re-raise them
270
+ def __init__(self, group=None, target=None, name=None,
271
+ args=(), kwargs=None, *, daemon=None, bucket=None):
272
+ self.bucket = bucket
273
+ self.streamer = kwargs.get('streamer')
274
+ self.exc = None
275
+ super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
276
+
277
+ def run(self):
278
+ # store the exception, if raised by the target, so join() can re-raise it in the caller
279
+ try:
280
+ super().run()
281
+ except BaseException as e:
282
+ print("thread exception: %s" % str(sys.exc_info()))
283
+ self.bucket.put(sys.exc_info())
284
+ self.exc = e
285
+ if self.streamer:
286
+ print("make stop: %s" % str(sys.exc_info()), flush=True)
287
+ self.streamer.do_stop = True
288
+
289
+ def join(self, timeout=None):
290
+ threading.Thread.join(self)
291
+ # Since join() returns in caller thread
292
+ # we re-raise the caught exception
293
+ # if any was caught
294
+ if self.exc:
295
+ raise self.exc
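
Usage sketch for the EThread added above: the class body mirrors the diff, while flaky() and the driver code are illustrative. The worker's exception is stored on the thread and in the bucket, and join() re-raises it in the caller, which is how generate.py surfaces errors from the generation thread.

import queue
import sys
import threading


class EThread(threading.Thread):
    # mirrors the class added in utils.py: capture the target's exception
    def __init__(self, group=None, target=None, name=None,
                 args=(), kwargs=None, *, daemon=None, bucket=None):
        self.bucket = bucket
        self.streamer = (kwargs or {}).get('streamer')
        self.exc = None
        super().__init__(group=group, target=target, name=name, args=args,
                         kwargs=kwargs, daemon=daemon)

    def run(self):
        try:
            super().run()
        except BaseException as e:
            self.bucket.put(sys.exc_info())  # full traceback for the consumer
            self.exc = e
            if self.streamer:
                self.streamer.do_stop = True  # unblock a polling streamer loop

    def join(self, timeout=None):
        threading.Thread.join(self, timeout)
        if self.exc:
            raise self.exc  # re-raise in the caller's thread


def flaky():
    raise RuntimeError("simulated GPU OOM")


bucket = queue.Queue()
t = EThread(target=flaky, bucket=bucket)
t.start()
try:
    t.join()  # the worker's RuntimeError surfaces here
except RuntimeError as e:
    exc_info = bucket.get()  # (type, value, traceback) is also in the bucket
    print("caught from thread:", e)
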