yentinglin committed on
Commit
4d31b4c
·
1 Parent(s): cff2810

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -108
app.py CHANGED
@@ -1,29 +1,125 @@
1
- import time
2
  import os
3
  import gradio as gr
4
  from text_generation import Client
5
  from conversation import get_default_conv_template
6
  from transformers import AutoTokenizer
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
10
  client = Client(endpoint_url, timeout=120)
11
  eos_token = "</s>"
12
- max_prompt_length = 4096 - 512 - 10
 
 
 
13
 
14
  tokenizer = AutoTokenizer.from_pretrained("yentinglin/Taiwan-LLaMa-v1.0")
15
 
16
  with gr.Blocks() as demo:
 
 
17
  chatbot = gr.Chatbot()
18
- msg = gr.Textbox()
19
- clear = gr.Button("Clear")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def user(user_message, history):
22
  return "", history + [[user_message, None]]
23
 
24
- def bot(history):
 
25
  conv = get_default_conv_template("vicuna").copy()
26
  roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # map human to USER and gpt to ASSISTANT
 
27
  for user, bot in history:
28
  conv.append_message(roles['human'], user)
29
  conv.append_message(roles["gpt"], bot)
@@ -31,123 +127,105 @@ with gr.Blocks() as demo:
31
  prompt_tokens = tokenizer.encode(msg)
32
  length_of_prompt = len(prompt_tokens)
33
  if length_of_prompt > max_prompt_length:
34
- msg = tokenizer.decode(prompt_tokens[-max_prompt_length+1:])
35
 
36
  history[-1][1] = ""
37
  for response in client.generate_stream(
38
  msg,
39
- max_new_tokens=512,
 
 
 
40
  ):
41
  if not response.token.special:
42
  character = response.token.text
43
  history[-1][1] += character
44
  yield history
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
48
- conv = get_default_conv_template("vicuna").copy()
49
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # map human to USER and gpt to ASSISTANT
50
- for user, bot in history:
51
- conv.append_message(roles['human'], user)
52
- conv.append_message(roles["gpt"], bot)
53
- msg = conv.get_prompt()
54
 
55
- for response in client.generate_stream(
56
- msg,
57
- max_new_tokens=max_new_token,
58
- top_p=top_p,
59
- temperature=temperature,
60
- do_sample=do_sample,
61
- ):
62
- history[-1][1] = ""
63
- # if not response.token.special:
64
- character = response.token.text
65
- history[-1][1] += character
66
- print(history[-1][1])
67
- time.sleep(0.05)
68
- yield history
69
 
70
 
71
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
72
- bot, chatbot, chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  clear.click(lambda: None, None, chatbot, queue=False)
75
-
76
- demo.queue()
77
- demo.launch()
78
 
79
- #
80
- # with gr.Blocks() as demo:
81
- # chatbot = gr.Chatbot()
82
- # with gr.Row():
83
- # with gr.Column(scale=4):
84
- # with gr.Column(scale=12):
85
- # user_input = gr.Textbox(
86
- # show_label=False,
87
- # placeholder="Shift + Enter傳送...",
88
- # lines=10).style(
89
- # container=False)
90
- # with gr.Column(min_width=32, scale=1):
91
- # submitBtn = gr.Button("Submit", variant="primary")
92
- # with gr.Column(scale=1):
93
- # emptyBtn = gr.Button("Clear History")
94
- # max_new_token = gr.Slider(
95
- # 1,
96
- # 1024,
97
- # value=128,
98
- # step=1.0,
99
- # label="Maximum New Token Length",
100
- # interactive=True)
101
- # top_p = gr.Slider(0, 1, value=0.9, step=0.01,
102
- # label="Top P", interactive=True)
103
- # temperature = gr.Slider(
104
- # 0,
105
- # 1,
106
- # value=0.5,
107
- # step=0.01,
108
- # label="Temperature",
109
- # interactive=True)
110
- # top_k = gr.Slider(1, 40, value=40, step=1,
111
- # label="Top K", interactive=True)
112
- # do_sample = gr.Checkbox(
113
- # value=True,
114
- # label="Do Sample",
115
- # info="use random sample strategy",
116
- # interactive=True)
117
- # repetition_penalty = gr.Slider(
118
- # 1.0,
119
- # 3.0,
120
- # value=1.1,
121
- # step=0.1,
122
- # label="Repetition Penalty",
123
- # interactive=True)
124
- #
125
- # params = [user_input, chatbot]
126
- # predict_params = [
127
- # chatbot,
128
- # max_new_token,
129
- # top_p,
130
- # temperature,
131
- # top_k,
132
- # do_sample,
133
- # repetition_penalty]
134
- #
135
- # submitBtn.click(
136
- # generate_response,
137
- # [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
138
- # [chatbot],
139
- # queue=False
140
- # )
141
- #
142
- # user_input.submit(
143
- # generate_response,
144
- # [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
145
- # [chatbot],
146
- # queue=False
147
- # )
148
- #
149
- # submitBtn.click(lambda: None, [], [user_input])
150
- #
151
- # emptyBtn.click(lambda: chatbot.reset(), outputs=[chatbot], show_progress=True)
152
- #
153
- # demo.launch()
 
 
1
  import os
2
  import gradio as gr
3
  from text_generation import Client
4
  from conversation import get_default_conv_template
5
  from transformers import AutoTokenizer
6
+ DESCRIPTION = """
7
+ # Language Models for Taiwanese Culture
8
 
9
+ Taiwan-LLaMa is a fine-tuned model specifically designed for traditional Chinese applications. It is built upon the LLaMa 2 architecture and includes a pretraining phase with over 5 billion tokens and fine-tuning with over 490k multi-turn conversational data in Traditional Chinese.
10
+
11
+ ## Key Features
12
+
13
+ 1. **Traditional Chinese Support**: The model is fine-tuned to understand and generate text in Traditional Chinese, making it suitable for Taiwanese culture and related applications.
14
+
15
+ 2. **Instruction-Tuned**: Further fine-tuned on conversational data to offer context-aware and instruction-following responses.
16
+
17
+ 3. **Performance on Vicuna Benchmark**: Taiwan-LLaMa's relative performance on Vicuna Benchmark is measured against models like GPT-4 and ChatGPT. It's particularly optimized for Taiwanese culture.
18
+
19
+ 4. **Flexible Customization**: Advanced options for controlling the model's behavior like system prompt, temperature, top-p, and top-k are available in the demo.
20
+
21
+ ## Model Versions
22
+
23
+ Different versions of Taiwan-LLaMa are available:
24
+
25
+ - **Taiwan-LLaMa v1.0 (This demo)**: Optimized for Taiwanese Culture
26
+ - **Taiwan-LLaMa v0.9**: Partial instruction set
27
+ - **Taiwan-LLaMa v0.0**: No Traditional Chinese pretraining
28
+
29
+ The models can be accessed from the provided links in the Hugging Face repository.
30
+
31
+ Try out the demo to interact with Taiwan-LLaMa and experience its capabilities in handling Traditional Chinese!
32
+ """
33
+
34
+ LICENSE = """
35
+ ## Licenses
36
+
37
+ - Code is licensed under Apache 2.0 License.
38
+ - Models are licensed under the LLAMA 2 Community License.
39
+ - By using this model, you agree to the terms and conditions specified in the license.
40
+ - By using this demo, you agree to share your input utterances with us to improve the model.
41
+
42
+ ## Acknowledgements
43
+
44
+ Taiwan-LLaMa project acknowledges the efforts of the [Meta LLaMa team](https://github.com/facebookresearch/llama) and [Vicuna team](https://github.com/lm-sys/FastChat) in democratizing large language models.
45
+ """
46
+
47
+ DEFAULT_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose."
48
 
49
  endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
50
  client = Client(endpoint_url, timeout=120)
51
  eos_token = "</s>"
52
+ MAX_MAX_NEW_TOKENS = 1024
53
+ DEFAULT_MAX_NEW_TOKENS = 1024
54
+
55
+ max_prompt_length = 4096 - MAX_MAX_NEW_TOKENS - 10
56
 
57
  tokenizer = AutoTokenizer.from_pretrained("yentinglin/Taiwan-LLaMa-v1.0")
58
 
59
  with gr.Blocks() as demo:
60
+ gr.Markdown(DESCRIPTION)
61
+
62
  chatbot = gr.Chatbot()
63
+ with gr.Row():
64
+ msg = gr.Textbox(
65
+ container=False,
66
+ show_label=False,
67
+ placeholder='Type a message...',
68
+ scale=10,
69
+ )
70
+ submit_button = gr.Button('Submit',
71
+ variant='primary',
72
+ scale=1,
73
+ min_width=0)
74
+
75
+ with gr.Row():
76
+ retry_button = gr.Button('🔄 Retry', variant='secondary')
77
+ undo_button = gr.Button('↩️ Undo', variant='secondary')
78
+ clear = gr.Button('🗑️ Clear', variant='secondary')
79
+
80
+ saved_input = gr.State()
81
+
82
+ with gr.Accordion(label='Advanced options', open=False):
83
+ system_prompt = gr.Textbox(label='System prompt',
84
+ value=DEFAULT_SYSTEM_PROMPT,
85
+ lines=6)
86
+ max_new_tokens = gr.Slider(
87
+ label='Max new tokens',
88
+ minimum=1,
89
+ maximum=MAX_MAX_NEW_TOKENS,
90
+ step=1,
91
+ value=DEFAULT_MAX_NEW_TOKENS,
92
+ )
93
+ temperature = gr.Slider(
94
+ label='Temperature',
95
+ minimum=0.1,
96
+ maximum=1.0,
97
+ step=0.1,
98
+ value=0.7,
99
+ )
100
+ top_p = gr.Slider(
101
+ label='Top-p (nucleus sampling)',
102
+ minimum=0.05,
103
+ maximum=1.0,
104
+ step=0.05,
105
+ value=0.9,
106
+ )
107
+ top_k = gr.Slider(
108
+ label='Top-k',
109
+ minimum=1,
110
+ maximum=1000,
111
+ step=1,
112
+ value=50,
113
+ )
114
 
115
def user(user_message, history):
    """Queue the submitted message as a new, not-yet-answered chat turn.

    Wired to both the textbox submit event and the Submit button, with
    outputs ``[msg, chatbot]``.

    Args:
        user_message: Text the user just submitted.
        history: Current chat history as a list of ``[user, bot]`` pairs.
            ``None`` (no chat value yet) is treated as an empty history
            instead of raising ``TypeError`` on concatenation.

    Returns:
        A ``(textbox_value, history)`` tuple: an empty string that clears the
        input box, and a new history list ending with
        ``[user_message, None]``. The incoming list is not modified.
    """
    # The reply slot is None until the bot step starts filling it in.
    return "", (history or []) + [[user_message, None]]
117
 
118
+
119
+ def bot(history, max_new_tokens, temperature, top_p, top_k, system_prompt):
120
  conv = get_default_conv_template("vicuna").copy()
121
  roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # map human to USER and gpt to ASSISTANT
122
+ conv.system = system_prompt
123
  for user, bot in history:
124
  conv.append_message(roles['human'], user)
125
  conv.append_message(roles["gpt"], bot)
 
127
  prompt_tokens = tokenizer.encode(msg)
128
  length_of_prompt = len(prompt_tokens)
129
  if length_of_prompt > max_prompt_length:
130
+ msg = tokenizer.decode(prompt_tokens[-max_prompt_length + 1:])
131
 
132
  history[-1][1] = ""
133
  for response in client.generate_stream(
134
  msg,
135
+ max_new_tokens=max_new_tokens,
136
+ temperature=temperature,
137
+ top_p=top_p,
138
+ top_k=top_k,
139
  ):
140
  if not response.token.special:
141
  character = response.token.text
142
  history[-1][1] += character
143
  yield history
144
 
145
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
146
+ fn=bot,
147
+ inputs=[
148
+ chatbot,
149
+ max_new_tokens,
150
+ temperature,
151
+ top_p,
152
+ top_k,
153
+ system_prompt,
154
+ ],
155
+ outputs=chatbot
156
+ )
157
+ submit_button.click(
158
+ user, [msg, chatbot], [msg, chatbot], queue=False
159
+ ).then(
160
+ fn=bot,
161
+ inputs=[
162
+ chatbot,
163
+ max_new_tokens,
164
+ temperature,
165
+ top_p,
166
+ top_k,
167
+ system_prompt,
168
+ ],
169
+ outputs=chatbot
170
+ )
171
 
 
 
 
 
 
 
 
172
 
173
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent chat turn and return its user message.

    Used as the first step of the Retry and Undo button chains, with outputs
    ``[chatbot, saved_input]``.

    Args:
        history: Chat history as a sequence of ``(user, bot)`` pairs. An
            empty history or ``None`` is accepted and yields an empty result.

    Returns:
        A ``(history, message)`` tuple: a new list without the last turn, and
        the user message of the removed turn (``''`` when there was no turn
        or its message was empty/``None``). The incoming list is left
        untouched rather than popped in place.
    """
    if not history:
        # Nothing to remove: empty chat, or no chat value at all.
        return [], ''
    message, _ = history[-1]
    # Slice instead of pop() so the caller's list is never mutated.
    return history[:-1], message or ''
 
 
 
 
 
 
 
180
 
181
 
182
def display_input(message: str,
                  history: list[list[str]]) -> list[list[str]]:
    """Re-open a chat turn for ``message`` with an empty reply slot.

    Used in the Retry chain after the previous turn has been removed, so the
    saved user message is shown again before the bot step regenerates the
    answer.

    Args:
        message: User message to display again.
        history: Current chat history as ``[user, bot]`` pairs; ``None`` is
            treated as an empty history.

    Returns:
        A new history list ending with ``[message, '']``. The pair is a list
        (matching what ``user`` produces) rather than a tuple, because the
        bot step assigns to ``history[-1][1]`` in place. The incoming list is
        not modified.
    """
    return (history or []) + [[message, '']]
186
+
187
+ retry_button.click(
188
+ fn=delete_prev_fn,
189
+ inputs=chatbot,
190
+ outputs=[chatbot, saved_input],
191
+ api_name=False,
192
+ queue=False,
193
+ ).then(
194
+ fn=display_input,
195
+ inputs=[saved_input, chatbot],
196
+ outputs=chatbot,
197
+ api_name=False,
198
+ queue=False,
199
+ ).then(
200
+ fn=bot,
201
+ inputs=[
202
+ chatbot,
203
+ max_new_tokens,
204
+ temperature,
205
+ top_p,
206
+ top_k,
207
+ system_prompt,
208
+ ],
209
+ outputs=chatbot,
210
  )
211
+
212
+ undo_button.click(
213
+ fn=delete_prev_fn,
214
+ inputs=chatbot,
215
+ outputs=[chatbot, saved_input],
216
+ api_name=False,
217
+ queue=False,
218
+ ).then(
219
+ fn=lambda x: x,
220
+ inputs=[saved_input],
221
+ outputs=msg,
222
+ api_name=False,
223
+ queue=False,
224
+ )
225
+
226
  clear.click(lambda: None, None, chatbot, queue=False)
 
 
 
227
 
228
+ gr.Markdown(LICENSE)
229
+
230
+ demo.queue(max_size=128)
231
+ demo.launch()