Tmawn commited on
Commit
c5a8908
·
verified ·
1 Parent(s): 97da3ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -5,7 +5,7 @@ from huggingface_hub import InferenceClient
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
  """
7
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
 
10
  def respond(
11
  message,
@@ -14,6 +14,9 @@ def respond(
14
  max_tokens,
15
  temperature,
16
  top_p,
 
 
 
17
  ):
18
  messages = [{"role": "system", "content": system_message}]
19
 
@@ -33,12 +36,16 @@ def respond(
33
  stream=True,
34
  temperature=temperature,
35
  top_p=top_p,
 
 
 
36
  ):
37
  token = message.choices[0].delta.content
38
 
39
  response += token
40
  yield response
41
 
 
42
  """
43
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
  """
@@ -55,9 +62,12 @@ demo = gr.ChatInterface(
55
  step=0.05,
56
  label="Top-p (nucleus sampling)",
57
  ),
 
 
 
58
  ],
59
  )
60
 
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
  """
7
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
+
9
 
10
  def respond(
11
  message,
 
14
  max_tokens,
15
  temperature,
16
  top_p,
17
+ repetition_penalty,
18
+ top_k,
19
+ truncate,
20
  ):
21
  messages = [{"role": "system", "content": system_message}]
22
 
 
36
  stream=True,
37
  temperature=temperature,
38
  top_p=top_p,
39
+ repetition_penalty=repetition_penalty,
40
+ top_k=top_k,
41
+ truncate=truncate,
42
  ):
43
  token = message.choices[0].delta.content
44
 
45
  response += token
46
  yield response
47
 
48
+
49
  """
50
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
51
  """
 
62
  step=0.05,
63
  label="Top-p (nucleus sampling)",
64
  ),
65
+ gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty"),
66
+ gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
67
+ gr.Slider(minimum=1, maximum=2000, value=1000, step=1, label="Truncate"),
68
  ],
69
  )
70
 
71
 
72
  if __name__ == "__main__":
73
+ demo.launch()