Zenithwang committed on
Commit 34d79f8
1 Parent(s): 0298010

Update app.py

Files changed (1)
  1. app.py +39 -34
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
+import traceback

 model_path = 'infly/OpenCoder-8B-Instruct'

@@ -43,42 +44,46 @@ system_prompt = f"<|im_start|>{system_role}\n{system_prompt}<|im_end|>"
 def predict(message, history):
     # history = []
     # history_transformer_format = history + [[message, ""]]
-    stop = StopOnTokens()
-
-    # Formatting the input for the model.
-    # messages = system_prompt + sft_end_token.join([sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + item[0], f"\n{sft_start_token}{assistant_role}\n" + item[1]])
-    #     for item in history_transformer_format])
-
-    model_messages = []
-    print(f'history: {history}')
-    for i, item in enumerate(history):
-        model_messages.append({"role": user_role, "content": item[0]})
-        model_messages.append({"role": assistant_role, "content": item[1]})
-
-    model_messages.append({"role": user_role, "content": message})
+    try:
+        stop = StopOnTokens()

-    print(f'model_messages: {model_messages}')
-
-    print(f'model_final_inputs: {tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, tokenize=False)}', flash=True)
-    model_inputs = tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, return_tensors="pt").to(device)
-    # model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+        # Formatting the input for the model.
+        # messages = system_prompt + sft_end_token.join([sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + item[0], f"\n{sft_start_token}{assistant_role}\n" + item[1]])
+        #     for item in history_transformer_format])
+
+        model_messages = []
+        print(f'history: {history}')
+        for i, item in enumerate(history):
+            model_messages.append({"role": user_role, "content": item[0]})
+            model_messages.append({"role": assistant_role, "content": item[1]})
+
+        model_messages.append({"role": user_role, "content": message})
+
+        print(f'model_messages: {model_messages}')

-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=False,
-        # stopping_criteria=StoppingCriteriaList([stop])
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()  # Starting the generation in a separate thread.
-    partial_message = ""
-    for new_token in streamer:
-        partial_message += new_token
-        if sft_end_token in partial_message:  # Breaking the loop if the stop token is generated.
-            break
-        yield partial_message
+        print(f'model_final_inputs: {tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, tokenize=False)}', flush=True)
+        model_inputs = tokenizer.apply_chat_template(model_messages, add_generation_prompt=True, return_tensors="pt").to(device)
+        # model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+        generate_kwargs = dict(
+            model_inputs,
+            streamer=streamer,
+            max_new_tokens=1024,
+            do_sample=False,
+            # stopping_criteria=StoppingCriteriaList([stop])
+        )
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()  # Starting the generation in a separate thread.
+        partial_message = ""
+        for new_token in streamer:
+            partial_message += new_token
+            if sft_end_token in partial_message:  # Breaking the loop if the stop token is generated.
+                break
+            yield partial_message
+
+    except Exception as e:
+        print(traceback.format_exc())


 css = """