savage1221 committed on
Commit e21b443 Β· verified Β· 1 Parent(s): 19a6bcf

Update app.py

Files changed (1)
  1. app.py +39 -394
app.py CHANGED
@@ -1,401 +1,46 @@
- import os
-
- from transformers import AutoTokenizer
- # from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
- # from optimum.intel.openvino import OVModelForCausalLM
- from transformers import AutoConfig, AutoTokenizer
  import gradio as gr
- import time
- from threading import Thread
-
- from transformers import (
-     TextIteratorStreamer,
-     StoppingCriteria,
-     StoppingCriteriaList,
-     GenerationConfig,
- )
- # model_name = "openai-community/gpt2-large"
- # model_dir = "F:\\phi3\\openvinomodel\\phi3\\int4"
- # model_name = "savage1221/lora-fine"
- # save_name = model_name.split("/")[-1] + "_openvino"
- # precision = "f32"
-
-
- # quantization_config = OVWeightQuantizationConfig(
- #     bits=4,
- #     sym=False,
- #     group_size=128,
- #     ratio=0.6,
- #     trust_remote_code=True,
- # )
-
- # ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}
-
- # device = "gpu"
-
-
- # load_kwargs = {
- #     "device": device,
- #     "ov_config": {
- #         "PERFORMANCE_HINT": "LATENCY",
- #         # "INFERENCE_PRECISION_HINT": precision,
- #         "CACHE_DIR": os.path.join(save_name, "model_cache"),  # OpenVINO will use this directory as cache
- #     },
- #     "compile": False,
- #     "quantization_config": quantization_config,
- #     "trust_remote_code": True,
- #     # ov_config = ov_config
- # }
-
- # # Check whether the model was already exported
- # saved = os.path.exists(save_name)
-
- # model = OVModelForCausalLM.from_pretrained(
- #     # model_name
- #     model_name if not saved else save_name,
- #     export=not saved,
- #     **load_kwargs,
- # )
- # model = OVModelForCausalLM.from_pretrained(
- #     model_name,
- #     device='GPU.0',
- #     ov_config=ov_config,
- #     config=AutoConfig.from_pretrained(model_name, trust_remote_code=True),
- #     trust_remote_code=True,
- # )
-
- # # Load tokenizer to be used with the model
- # tokenizer = AutoTokenizer.from_pretrained(model_name if not saved else save_name)
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # # Save the exported model locally
- # if not saved:
- #     model.save_pretrained(save_name)
- #     tokenizer.save_pretrained(save_name)
-
- # # TODO Optional: export to huggingface/hub
-
- # model_size = os.stat(os.path.join(save_name, "openvino_model.bin")).st_size / 1024 ** 3
- # print(f'Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB')
-
- #####################################################################
-
- # Load model directly
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("savage1221/lora-fine", trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained("savage1221/lora-fine", trust_remote_code=True)
-
-
- # Copied and modified from https://github.com/bigcode-project/bigcode-evaluation-harness/blob/main/bigcode_eval/generation.py#L13
- class SuffixCriteria(StoppingCriteria):
-     def __init__(self, start_length, eof_strings, tokenizer, check_fn=None):
-         self.start_length = start_length
-         self.eof_strings = eof_strings
-         self.tokenizer = tokenizer
-         if check_fn is None:
-             check_fn = lambda decoded_generation: any(
-                 [decoded_generation.endswith(stop_string) for stop_string in self.eof_strings]
-             )
-         self.check_fn = check_fn
-
-     def __call__(self, input_ids, scores, **kwargs):
-         """Returns True if generated sequence ends with any of the stop strings"""
-         decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
-         return all([self.check_fn(decoded_generation) for decoded_generation in decoded_generations])
-
-
- def is_partial_stop(output, stop_str):
-     """Check whether the output contains a partial stop str."""
-     for i in range(0, min(len(output), len(stop_str))):
-         if stop_str.startswith(output[-i:]):
-             return True
-     return False
-
-
-
- # Set the chat template on the tokenizer. The chat template implements the simple template of
- # User: content
- # Assistant: content
- # ...
- # Read more about chat templates here https://huggingface.co/docs/transformers/main/en/chat_templating
- tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-
-
- # def prepare_history_for_model(history):
- #     """
- #     Converts the history to a tokenized prompt in the format expected by the model.
- #     Params:
- #       history: dialogue history
- #     Returns:
- #       Tokenized prompt
- #     """
- #     messages = []
- #     for idx, (user_msg, model_msg) in enumerate(history):
- #         # skip the last assistant message if it's empty, the tokenizer will do the formatting
- #         if idx == len(history) - 1 and not model_msg:
- #             messages.append({"role": "User", "content": user_msg})
- #             break
- #         if user_msg:
- #             messages.append({"role": "User", "content": user_msg})
- #         if model_msg:
- #             messages.append({"role": "Assistant", "content": model_msg})
- #     input_token = tokenizer.apply_chat_template(
- #         messages,
- #         add_generation_prompt=True,
- #         tokenize=True,
- #         return_tensors="pt",
- #         return_dict=True
- #     )
- #     return input_token
-


- def prepare_history_for_model(history):
-     """
-     Converts the history to a tokenized prompt in the format expected by the model.
-     Params:
-       history: dialogue history
-     Returns:
-       Tokenized prompt
-     """
-     messages = []
-
-     # Add instruction
-     instruction = "Generate quotes for AWS RDS services"
-     messages.append({"role": "Instruction", "content": instruction})
-
-     for idx, (user_msg, model_msg) in enumerate(history):
-         # Assuming the user message contains the product information
-         if user_msg:
-             messages.append({"role": "Input", "content": user_msg})
-
-         # Skip the last assistant message if it's empty
-         if idx == len(history) - 1 and not model_msg:
-             break
-
-         if model_msg:
-             messages.append({"role": "Output", "content": model_msg})
-
-     input_token = tokenizer.apply_chat_template(
-         messages,
-         add_generation_prompt=True,
-         tokenize=True,
-         return_tensors="pt",
-         return_dict=True
-     )
-     return input_token
-
-
-
- def generate(history, temperature, max_new_tokens, top_p, repetition_penalty, assisted):
-     """
-     Generates the assistant's response given the chatbot history and generation parameters
-
-     Params:
-       history: conversation history formatted in pairs of user and assistant messages `[user_message, assistant_message]`
-       temperature: parameter to control the level of creativity in AI-generated text.
-         By adjusting the `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse.
-       max_new_tokens: the maximum number of tokens we allow the model to generate as a response.
-       top_p: parameter to control the range of tokens considered by the AI model based on their cumulative probability.
-       repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
-       assisted: boolean parameter to enable/disable assisted generation with speculative decoding.
-     Yields:
-       Updated history and generation status.
-     """
-     start = time.perf_counter()
-     # Construct the input message string for the model by concatenating the current system message and conversation history
-     # Tokenize the messages string
-     inputs = prepare_history_for_model(history)
-     input_length = inputs['input_ids'].shape[1]
-     # Truncate input in case it is too long.
-     # TODO improve this
-     if input_length > 2000:
-         history = [history[-1]]
-         inputs = prepare_history_for_model(history)
-         input_length = inputs['input_ids'].shape[1]
-
-     prompt_char = "β–Œ"
-     history[-1][1] = prompt_char
-     yield history, "Status: Generating...", *([gr.update(interactive=False)] * 4)
-
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     # Create a stopping criteria to prevent the model from playing the role of the user as well.
-     stop_str = ["\nUser:", "\nAssistant:", "\nRules:", "\nQuestion:"]
-     stopping_criteria = StoppingCriteriaList([SuffixCriteria(input_length, stop_str, tokenizer)])
-     # Prepare input for generate
-     generation_config = GenerationConfig(
-         max_new_tokens=max_new_tokens,
-         do_sample=temperature > 0.0,
-         temperature=temperature if temperature > 0.0 else 1.0,
-         repetition_penalty=repetition_penalty,
-         top_p=top_p,
-         eos_token_id=[tokenizer.eos_token_id],
-         pad_token_id=tokenizer.eos_token_id,
-     )
-     generate_kwargs = dict(
-         streamer=streamer,
-         generation_config=generation_config,
-         stopping_criteria=stopping_criteria,
-     ) | inputs
-
-     if assisted:
-         target_generate = stateless_model.generate
-         generate_kwargs["assistant_model"] = asst_model
-     else:
-         target_generate = model.generate
-
-     t1 = Thread(target=target_generate, kwargs=generate_kwargs)
-     t1.start()
-
-     # Initialize an empty string to store the generated text.
-     partial_text = ""
-     for new_text in streamer:
-         partial_text += new_text
-         history[-1][1] = partial_text + prompt_char
-         for s in stop_str:
-             if (pos := partial_text.rfind(s)) != -1:
-                 break
-         if pos != -1:
-             partial_text = partial_text[:pos]
-             break
-         elif any([is_partial_stop(partial_text, s) for s in stop_str]):
-             continue
-         yield history, "Status: Generating...", *([gr.update(interactive=False)] * 4)
-     history[-1][1] = partial_text
-     generation_time = time.perf_counter() - start
-     yield history, f'Generation time: {generation_time:.2f} sec', *([gr.update(interactive=True)] * 4)
-
-
- #############################################################
-
-
- # model.compile()
-
-
- try:
-     demo.close()
- except:
-     pass
-
-
- EXAMPLES = [
-     ["What is OpenVINO?"],
-     ["Can you explain to me briefly what is Python programming language?"],
-     ["Explain the plot of Cinderella in a sentence."],
-     ["Write a Python function to perform binary search over a sorted list. Use markdown to write code"],
-     ["Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?"],
- ]
-
-
- def add_user_text(message, history):
-     """
-     Add user's message to chatbot history
-
-     Params:
-       message: current user message
-       history: conversation history
-     Returns:
-       Updated history, clears user message and status
-     """
-     # Append current user message to history with a blank assistant message which will be generated by the model
-     history.append([message, None])
-     return ('', history)
-
-
- def prepare_for_regenerate(history):
-     """
-     Delete last assistant message to prepare for regeneration
-
-     Params:
-       history: conversation history
-     Returns:
-       updated history
-     """
-     history[-1][1] = None
-     return history
-

- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown('<h1 style="text-align: center;">Chat with Phi-3 on Meteor Lake iGPU</h1>')
-     chatbot = gr.Chatbot()
-     with gr.Row():
-         assisted = gr.Checkbox(value=False, label="Assisted Generation", scale=10)
-         msg = gr.Textbox(placeholder="Enter message here...", show_label=False, autofocus=True, scale=75)
-         status = gr.Textbox("Status: Idle", show_label=False, max_lines=1, scale=15)
-     with gr.Row():
-         submit = gr.Button("Submit", variant='primary')
-         regenerate = gr.Button("Regenerate")
-         clear = gr.Button("Clear")
-     with gr.Accordion("Advanced Options:", open=False):
-         with gr.Row():
-             with gr.Column():
-                 temperature = gr.Slider(
-                     label="Temperature",
-                     value=0.0,
-                     minimum=0.0,
-                     maximum=1.0,
-                     step=0.05,
-                     interactive=True,
-                 )
-                 max_new_tokens = gr.Slider(
-                     label="Max new tokens",
-                     value=512,
-                     minimum=0,
-                     maximum=1024,
-                     step=32,
-                     interactive=True,
-                 )
-             with gr.Column():
-                 top_p = gr.Slider(
-                     label="Top-p (nucleus sampling)",
-                     value=1.0,
-                     minimum=0.0,
-                     maximum=1.0,
-                     step=0.05,
-                     interactive=True,
-                 )
-                 repetition_penalty = gr.Slider(
-                     label="Repetition penalty",
-                     value=1.0,
-                     minimum=1.0,
-                     maximum=2.0,
-                     step=0.1,
-                     interactive=True,
-                 )
-     gr.Examples(
-         EXAMPLES, inputs=msg, label="Click on any example and press the 'Submit' button"
-     )

-     # Sets the generate function to be triggered when the user submits a new message
-     gr.on(
-         triggers=[submit.click, msg.submit],
-         fn=add_user_text,
-         inputs=[msg, chatbot],
-         outputs=[msg, chatbot],
-         queue=False,
-     ).then(
-         fn=generate,
-         inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],
-         outputs=[chatbot, status, msg, submit, regenerate, clear],
-         concurrency_limit=1,
-         queue=True
-     )
-     regenerate.click(
-         fn=prepare_for_regenerate,
-         inputs=chatbot,
-         outputs=chatbot,
-         queue=True,
-         concurrency_limit=1
-     ).then(
-         fn=generate,
-         inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],
-         outputs=[chatbot, status, msg, submit, regenerate, clear],
-         concurrency_limit=1,
-         queue=True
-     )
-     clear.click(fn=lambda: (None, "Status: Idle"), inputs=None, outputs=[chatbot, status], queue=False)


- demo.launch()
 
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  import gradio as gr
+
+ torch.random.manual_seed(0)
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "savage1221/lora-fine",
+     device_map="cuda",
+     torch_dtype="auto",
+     trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained("savage1221/lora-fine", trust_remote_code=True)
+
+ instruction = "Generate quotes for AWS RDS services"
+
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+ )
+
+ generation_args = {
+     "max_new_tokens": 500,
+     "return_full_text": False,
+     "temperature": 0.9,
+     "do_sample": True,
+     "top_k": 50,
+     "top_p": 0.95,
+     "num_return_sequences": 1,
+ }
+
+ def predict_price(input_data):
+     prompt = f"{instruction}\nInput: {input_data}\nOutput:"
+     output = pipe(prompt, **generation_args)
+     return output[0]['generated_text']
+
+ interface = gr.Interface(
+     fn=predict_price,
+     inputs=gr.Textbox(lines=7, label="Enter product information"),
+     outputs=gr.Textbox(label="Predicted price"),
+     title="Product Price Prediction",
+     description="Enter product information to predict the product price",
+ )
+
+ interface.launch()
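
For reference, the prediction path added in this commit can be smoke-tested outside Gradio before calling interface.launch(). This is a minimal illustrative sketch, not part of the commit; the sample input string is a hypothetical RDS description, and it simply builds the prompt the same way predict_price() does:

    # Illustrative sketch (not from the commit): exercise the text-generation
    # pipeline directly, mirroring how predict_price() constructs its prompt.
    sample_input = "db.r5.large, Multi-AZ, 100 GB gp3 storage, 1-year reserved term"  # hypothetical input
    prompt = f"{instruction}\nInput: {sample_input}\nOutput:"
    result = pipe(prompt, **generation_args)
    print(result[0]["generated_text"])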