ShieldX committed on
Commit 22da7c8
Parent: 804edcf

Update app.py

Files changed (1):
    app.py +16 -8
app.py CHANGED
@@ -21,7 +21,7 @@ examples = ["I have been feeling more and more down for over a month. I have sta
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
+        stop_ids = [1, 2]
         for stop_id in stop_ids:
             if input_ids[0][-1] == stop_id:
                 return True
@@ -36,9 +36,16 @@ def predict(message, history):
 
     messages = "".join(["".join([sys_msg + "\n###USER:"+item[0], "\n###ASSISTANT:"+item[1]]) #curr_system_message +
                 for item in history_transformer_format])
+
+    # def format_prompt(q):
+    #     return f"""{sys_msg}
+    #     ###USER: {q}
+    #     ###ASSISTANT:"""
+
+    # messages = format_prompt(message)
 
     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=False)
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
@@ -46,8 +53,10 @@ def predict(message, history):
         do_sample=True,
         top_p=0.95,
         top_k=1000,
-        temperature=1.0,
+        temperature=0.2,
         num_beams=1,
+        eos_token_id=[tokenizer.eos_token_id],
+        pad_token_id=tokenizer.eos_token_id,
         stopping_criteria=StoppingCriteriaList([stop])
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -55,13 +64,12 @@ def predict(message, history):
 
     partial_message = ""
     for new_token in streamer:
-        if new_token != '#':
+        if new_token != '<':
+            # if "#" in new_token:
+            #     break
+            # else:
             partial_message += new_token
             yield partial_message
-        else:
-            print("new token = #")
-            partial_message += new_token
-            yield partial_message
 
 
 gr.ChatInterface(
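For reference, the messages string built in this diff flattens Gradio's chat history (a list of [user, assistant] pairs, with the pending turn's assistant slot left empty) into the ###USER/###ASSISTANT transcript the model expects. A standalone sketch of just that step, with sys_msg and the history contents invented for illustration:

# Sketch of the prompt assembly done in predict(); sys_msg and history
# are placeholders, not the Space's actual values.
sys_msg = "You are a helpful assistant."

history = [
    ["Hi there", "Hello! How can I help?"],
    ["I feel low", ""],  # empty assistant slot: the model completes it
]

messages = "".join(
    "".join([sys_msg + "\n###USER:" + user, "\n###ASSISTANT:" + assistant])
    for user, assistant in history
)
print(messages)

Note that this join re-emits sys_msg before every turn, exactly as the original expression does; the commented-out format_prompt in the diff would instead build a single-turn prompt from only the current message.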
 
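And a minimal, self-contained sketch of the streaming pattern the rest of the file relies on: a custom StoppingCriteria, a TextIteratorStreamer drained by a generator, and model.generate() running on a background Thread. The gpt2 checkpoint, the stream_reply helper, and the generation settings here are illustrative placeholders, not the Space's actual configuration.

from threading import Thread

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList,
                          TextIteratorStreamer)

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)


class StopOnTokens(StoppingCriteria):
    """Stop as soon as the last sampled token is one of the given ids.

    The right ids depend entirely on the tokenizer; Llama-style vocabularies
    use 1 and 2 for <s>/</s>, which is presumably what the [1, 2] in the
    diff targets.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1].item() in self.stop_ids


def stream_reply(prompt):
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0,
                                    skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,                         # input_ids + attention_mask
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad warning
        stopping_criteria=StoppingCriteriaList(
            [StopOnTokens([tokenizer.eos_token_id])]),
    )
    # generate() blocks until completion, so it runs on a worker thread while
    # this generator drains the streamer and yields the growing reply.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message


if __name__ == "__main__":
    for partial in stream_reply("Hello, how are you?"):
        print(partial)  # each iteration prints the reply so far

Running generate() on its own thread is what makes incremental output possible: the call blocks until generation finishes, so the UI-facing generator can only yield partial text while another thread feeds tokens into the streamer's queue.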