baohuynhbk14 committed on
Commit d9991d4 · 1 Parent(s): 412554a

Refactor image processing functions for better performance and clarity in utils.py

Files changed (1)
  1. app.py +271 -590
app.py CHANGED
@@ -1,622 +1,303 @@
- import spaces
- import argparse
- from ast import parse
- import datetime
- import json
- import os
- import time
- import hashlib
- import re
  import torch
  import gradio as gr
- import requests
- import random
- from filelock import FileLock
- from io import BytesIO
- from PIL import Image, ImageDraw, ImageFont
- from models import load_image
- from constants import LOGDIR
- from utils import (
-     build_logger,
-     server_error_msg,
-     violates_moderation,
-     moderation_msg,
-     load_image_from_base64,
-     get_log_filename,
- )
  from threading import Thread
- import traceback
- # import torch
- from conversation import Conversation
- from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
  import subprocess

  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

  torch.set_default_device('cuda')

- logger = build_logger("gradio_web_server", "gradio_web_server.log")
-
- headers = {"User-Agent": "Vintern-Chat Client"}
-
- no_change_btn = gr.Button()
- enable_btn = gr.Button(interactive=True)
- disable_btn = gr.Button(interactive=False)
-
-
- @spaces.GPU(duration=10)
- def make_zerogpu_happy():
-     pass
-
-
- def write2file(path, content):
-     lock = FileLock(f"{path}.lock")
-     with lock:
-         with open(path, "a") as fout:
-             fout.write(content)
-
-
- get_window_url_params = """
- function() {
-     const params = new URLSearchParams(window.location.search);
-     url_params = Object.fromEntries(params);
-     console.log(url_params);
-     return url_params;
- }
- """
-
-
- def init_state(state=None):
-     if state is not None:
-         del state
-     return Conversation()
-
-
- def vote_last_response(state, liked, request: gr.Request):
-     conv_data = {
-         "tstamp": round(time.time(), 4),
-         "like": liked,
-         "model": 'Vintern-1B-v3',
-         "state": state.dict(),
-         "ip": request.client.host,
-     }
-     write2file(get_log_filename(), json.dumps(conv_data) + "\n")
-
-
- def upvote_last_response(state, request: gr.Request):
-     logger.info(f"upvote. ip: {request.client.host}")
-     vote_last_response(state, True, request)
-     textbox = gr.MultimodalTextbox(value=None, interactive=True)
-     return (textbox,) + (disable_btn,) * 3
-
-
- def downvote_last_response(state, request: gr.Request):
-     logger.info(f"downvote. ip: {request.client.host}")
-     vote_last_response(state, False, request)
-     textbox = gr.MultimodalTextbox(value=None, interactive=True)
-     return (textbox,) + (disable_btn,) * 3
-
-
- def vote_selected_response(
-     state, request: gr.Request, data: gr.LikeData
- ):
-     logger.info(
-         f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
-     )
-     conv_data = {
-         "tstamp": round(time.time(), 4),
-         "like": data.liked,
-         "index": data.index,
-         "model": 'Vintern-1B-v3',
-         "state": state.dict(),
-         "ip": request.client.host,
-     }
-     write2file(get_log_filename(), json.dumps(conv_data) + "\n")
-     return
-
-
- def flag_last_response(state, request: gr.Request):
-     logger.info(f"flag. ip: {request.client.host}")
-     vote_last_response(state, "flag", request)
-     textbox = gr.MultimodalTextbox(value=None, interactive=True)
-     return (textbox,) + (disable_btn,) * 3
-
-
- def regenerate(state, image_process_mode, request: gr.Request):
-     logger.info(f"regenerate. ip: {request.client.host}")
-     # state.messages[-1][-1] = None
-     state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)
-     prev_human_msg = state.messages[-2]
-     if type(prev_human_msg[1]) in (tuple, list):
-         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-     state.skip_next = False
-     textbox = gr.MultimodalTextbox(value=None, interactive=True)
-     return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
-
-
- def clear_history(request: gr.Request):
-     logger.info(f"clear_history. ip: {request.client.host}")
-     state = init_state()
-     textbox = gr.MultimodalTextbox(value=None, interactive=True)
-     return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
-
-
- def add_text(state, message, system_prompt, request: gr.Request):
-     print(f"state: {state}")
-     if not state:
-         state = init_state()
-     images = message.get("files", [])
-     text = message.get("text", "").strip()
-     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
-     # import pdb; pdb.set_trace()
-     textbox = gr.MultimodalTextbox(value=None, interactive=False)
-     if len(text) <= 0 and len(images) == 0:
-         state.skip_next = True
-         return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
-     if args.moderate:
-         flagged = violates_moderation(text)
-         if flagged:
-             state.skip_next = True
-             textbox = gr.MultimodalTextbox(
-                 value={"text": moderation_msg}, interactive=True
-             )
-             return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
-     images = [Image.open(path).convert("RGB") for path in images]
-
-     if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
-         state = init_state(state)
-     state.set_system_message(system_prompt)
-     state.append_message(Conversation.USER, text, images)
-     state.skip_next = False
-     return (state, state.to_gradio_chatbot(), textbox) + (
-         disable_btn,
-     ) * 5
- model_name = "5CD-AI/Vintern-1B-v3_5"
  model = AutoModel.from_pretrained(
-     model_name,
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=True,
      trust_remote_code=True,
  ).eval().cuda()
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
-
- @spaces.GPU
- def http_bot(
-     state,
-     temperature,
-     top_p,
-     repetition_penalty,
-     max_new_tokens,
-     max_input_tiles,
-     request: gr.Request,
- ):
-     logger.info(f"http_bot. ip: {request.client.host}")
-     start_tstamp = time.time()
-     if hasattr(state, "skip_next") and state.skip_next:
-         # This generate call is skipped due to invalid inputs
-         yield (
-             state,
-             state.to_gradio_chatbot(),
-             gr.MultimodalTextbox(interactive=False),
-         ) + (no_change_btn,) * 5
-         return
-
-     if model is None:
-         # state.messages[-1][-1] = server_error_msg
-         state.update_message(Conversation.ASSISTANT, server_error_msg)
-         yield (
-             state,
-             state.to_gradio_chatbot(),
-             gr.MultimodalTextbox(interactive=False),
-             disable_btn,
-             disable_btn,
-             disable_btn,
-             enable_btn,
-             enable_btn,
-         )
-         return
-
-     all_images = state.get_images(source=state.USER)
-     all_image_paths = [state.save_image(image) for image in all_images]
-
-     state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
-     yield (
-         state,
-         state.to_gradio_chatbot(),
-         gr.MultimodalTextbox(interactive=False),
-     ) + (disable_btn,) * 5
-
-     try:
-         # Stream output
-         # response = requests.post(worker_addr, json=pload, headers=headers, stream=True, timeout=300)
-         print(f"all_image_paths: {all_image_paths}")

-         pixel_values = load_image(all_image_paths[0], max_num=6).to(torch.bfloat16).cuda()
-         print(f"pixel_values: {pixel_values}")
-         generation_config = dict(max_new_tokens=700, do_sample=False, num_beams=3, repetition_penalty=2.5)
-         message = state.get_user_message(source=state.USER)
-         print(f"######################")
-         print(f"message: {message}")
          if pixel_values is not None:
-             question = '<image>\n'+message
          else:
-             question = message
-         print("Model: ", model)
-         print("Tokenizer: ", tokenizer)
-         print("Question: ", question)
          response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
-         print(f"AI response: {response}")
-
-
-         # streamer = TextIteratorStreamer(
-         #     tokenizer, skip_prompt=True, skip_special_tokens=True
-         # )
-         # generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-         # thread = Thread(target=model.generate, kwargs=generation_kwargs)
-         # thread.start()

-         # response = "This is a test response"
-         buffer = ""
-         for new_text in response:
-             buffer += new_text
-             # Remove <|im_end|> or similar tokens from the output
-             buffer = buffer.replace("<|im_end|>", "")
-
-             state.update_message(Conversation.ASSISTANT, buffer + state.streaming_placeholder, None)
-             yield (
-                 state,
-                 state.to_gradio_chatbot(),
-                 gr.MultimodalTextbox(interactive=False),
-             ) + (disable_btn,) * 5
-
-     except Exception as e:
-         logger.error(f"Error in http_bot: {e}")
-         traceback.print_exc()
-         state.update_message(Conversation.ASSISTANT, server_error_msg, None)
-         yield (
-             state,
-             state.to_gradio_chatbot(),
-             gr.MultimodalTextbox(interactive=True),
-         ) + (
-             disable_btn,
-             disable_btn,
-             disable_btn,
-             enable_btn,
-             enable_btn,
-         )
-         return
-
-     ai_response = state.return_last_message()
-
-     logger.info(f"==== response ====\n{ai_response}")
-
-     state.end_of_current_turn()
-
-     yield (
-         state,
-         state.to_gradio_chatbot(),
-         gr.MultimodalTextbox(interactive=True),
-     ) + (enable_btn,) * 5
-
-     finish_tstamp = time.time()
-     logger.info(f"{buffer}")
-     data = {
-         "tstamp": round(finish_tstamp, 4),
-         "like": None,
-         "model": model_name,
-         "start": round(start_tstamp, 4),
-         "finish": round(start_tstamp, 4),
-         "state": state.dict(),
-         "images": all_image_paths,
-         "ip": request.client.host,
-     }
-     write2file(get_log_filename(), json.dumps(data) + "\n")
-
- # <h1 style="font-size: 28px; font-weight: bold;">Expanding Performance Boundaries of Open-Source Multimodal Models with Model, Data, and Test-Time Scaling</h1>
- title_html = """
- <div style="text-align: center;">
- <img src="https://lh3.googleusercontent.com/pw/AP1GczMmW-aFQ4dNaR_LCAllh4UZLLx9fTZ1ITHeGVMWx-1bwlIWz4VsWJSGb3_9C7CQfvboqJH41y2Sbc5ToC9ZmKeV4-buf_DEevIMU0HtaLWgHAPOqBiIbG6LaE8CvDqniLZzvB9UX8TR_-YgvYzPFt2z=w1472-h832-s-no-gm?authuser=0" style="height: 100; width: 100%;">
- <p>Vintern-1B: An Efficient Multimodal Large Language Model for Vietnamese</p>
- <a href="https://huggingface.co/papers/2408.12480">[📖 Vintern Paper]</a>
- <a href="https://huggingface.co/5CD-AI">[🤗 5CD-AI Huggingface]</a>
- </div>
- """
-
-
- tos_markdown = """
- ### Terms of use
- By using this service, users are required to agree to the following terms:
- It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
- """
-
-
- # .gradio-container {margin: 5px 10px 0 10px !important};
- block_css = """
- .gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
- #buttons button {
-     min-width: min(120px,100%);
  }
-
- .gradient-text {
-     font-size: 28px;
-     width: auto;
-     font-weight: bold;
-     background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
-     background-clip: text;
-     -webkit-background-clip: text;
-     color: transparent;
  }
-
- .plain-text {
-     font-size: 22px;
-     width: auto;
-     font-weight: bold;
  }
  """

- # js = """
- # function createWaveAnimation() {
- #     const text = document.getElementById('text');
- #     var i = 0;
- #     setInterval(function() {
- #         const colors = [
- #             'red, orange, yellow, green, blue, indigo, violet, purple',
- #             'orange, yellow, green, blue, indigo, violet, purple, red',
- #             'yellow, green, blue, indigo, violet, purple, red, orange',
- #             'green, blue, indigo, violet, purple, red, orange, yellow',
- #             'blue, indigo, violet, purple, red, orange, yellow, green',
- #             'indigo, violet, purple, red, orange, yellow, green, blue',
- #             'violet, purple, red, orange, yellow, green, blue, indigo',
- #             'purple, red, orange, yellow, green, blue, indigo, violet',
- #         ];
- #         const angle = 45;
- #         const colorIndex = i % colors.length;
- #         text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
- #         text.style.webkitBackgroundClip = 'text';
- #         text.style.backgroundClip = 'text';
- #         text.style.color = 'transparent';
- #         text.style.fontSize = '28px';
- #         text.style.width = 'auto';
- #         text.textContent = 'Vintern-1B';
- #         text.style.fontWeight = 'bold';
- #         i += 1;
- #     }, 200);
- #     const params = new URLSearchParams(window.location.search);
- #     url_params = Object.fromEntries(params);
- #     // console.log(url_params);
- #     // console.log('hello world...');
- #     // console.log(window.location.search);
- #     // console.log('hello world...');
- #     // alert(window.location.search)
- #     // alert(url_params);
- #     return url_params;
- # }
-
- # """

- def build_demo():
-     textbox = gr.MultimodalTextbox(
-         interactive=True,
-         file_types=["image", "video"],
-         placeholder="Enter message or upload file...",
-         show_label=False,
      )
-
-     with gr.Blocks(
-         title="Vintern-Chat",
-         theme=gr.themes.Default(),
-         css=block_css,
-     ) as demo:
-         state = gr.State()
-
-         with gr.Row():
-             with gr.Column(scale=2):
-                 # gr.Image('./gallery/logo-47b364d3.jpg')
-                 gr.HTML(title_html)
-
-                 with gr.Accordion("Settings", open=False) as setting_row:
-                     system_prompt = gr.Textbox(
-                         value="请尽可能详细地回答用户的问题。",
-                         label="System Prompt",
-                         interactive=True,
-                     )
-                     temperature = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.0,
-                         value=0.2,
-                         step=0.1,
-                         interactive=True,
-                         label="Temperature",
-                     )
-                     top_p = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.0,
-                         value=0.7,
-                         step=0.1,
-                         interactive=True,
-                         label="Top P",
-                     )
-                     repetition_penalty = gr.Slider(
-                         minimum=1.0,
-                         maximum=1.5,
-                         value=1.1,
-                         step=0.02,
-                         interactive=True,
-                         label="Repetition penalty",
-                     )
-                     max_output_tokens = gr.Slider(
-                         minimum=0,
-                         maximum=4096,
-                         value=1024,
-                         step=64,
-                         interactive=True,
-                         label="Max output tokens",
-                     )
-                     max_input_tiles = gr.Slider(
-                         minimum=1,
-                         maximum=32,
-                         value=12,
-                         step=1,
-                         interactive=True,
-                         label="Max input tiles (control the image size)",
-                     )
-                 examples = gr.Examples(
-                     examples=[
-                         [
-                             {
-                                 "files": [
-                                     "gallery/14.jfif",
-                                 ],
-                                 "text": "Please help me analyze this picture.",
-                             }
-                         ],
-                         [
-                             {
-                                 "files": [
-                                     "gallery/1-2.PNG",
-                                 ],
-                                 "text": "Implement this flow chart using python",
-                             }
-                         ],
-                         [
-                             {
-                                 "files": [
-                                     "gallery/15.PNG",
-                                 ],
-                                 "text": "Please help me analyze this picture.",
-                             }
-                         ],
-                     ],
-                     inputs=[textbox],
-                 )
-
-             with gr.Column(scale=8):
-                 chatbot = gr.Chatbot(
-                     elem_id="chatbot",
-                     label="Vintern",
-                     height=580,
-                     show_copy_button=True,
-                     show_share_button=True,
-                     avatar_images=[
-                         "assets/human.png",
-                         "assets/assistant.png",
-                     ],
-                     bubble_full_width=False,
-                 )
-                 with gr.Row():
-                     with gr.Column(scale=8):
-                         textbox.render()
-                     with gr.Column(scale=1, min_width=50):
-                         submit_btn = gr.Button(value="Send", variant="primary")
-                 with gr.Row(elem_id="buttons") as button_row:
-                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
-                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
-                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-                     regenerate_btn = gr.Button(
-                         value="🔄 Regenerate", interactive=False
-                     )
-                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
-
-         gr.Markdown(tos_markdown)
-         url_params = gr.JSON(visible=False)
-
-         # Register listeners
-         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-         upvote_btn.click(
-             upvote_last_response,
-             [state],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-         )
-         downvote_btn.click(
-             downvote_last_response,
-             [state],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-         )
-         chatbot.like(
-             vote_selected_response,
-             [state],
-             [],
-         )
-         flag_btn.click(
-             flag_last_response,
-             [state],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-         )
-         regenerate_btn.click(
-             regenerate,
-             [state, system_prompt],
-             [state, chatbot, textbox] + btn_list,
-         ).then(
-             http_bot,
-             [
-                 state,
-                 temperature,
-                 top_p,
-                 repetition_penalty,
-                 max_output_tokens,
-                 max_input_tiles,
-             ],
-             [state, chatbot, textbox] + btn_list,
-         )
-         clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
-
-         textbox.submit(
-             add_text,
-             [state, textbox, system_prompt],
-             [state, chatbot, textbox] + btn_list,
-         ).then(
-             http_bot,
-             [
-                 state,
-                 temperature,
-                 top_p,
-                 repetition_penalty,
-                 max_output_tokens,
-                 max_input_tiles,
-             ],
-             [state, chatbot, textbox] + btn_list,
-         )
-         submit_btn.click(
-             add_text,
-             [state, textbox, system_prompt],
-             [state, chatbot, textbox] + btn_list,
-         ).then(
-             http_bot,
-             [
-                 state,
-                 temperature,
-                 top_p,
-                 repetition_penalty,
-                 max_output_tokens,
-                 max_input_tiles,
-             ],
-             [state, chatbot, textbox] + btn_list,
-         )
-
-     return demo
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--host", type=str, default="0.0.0.0")
-     parser.add_argument("--port", type=int, default=7860)
-     parser.add_argument("--concurrency-count", type=int, default=10)
-     parser.add_argument("--share", action="store_true")
-     parser.add_argument("--moderate", action="store_true")
-     args = parser.parse_args()
-     logger.info(f"args: {args}")
-
-     logger.info(args)
-     demo = build_demo()
-     demo.queue(api_open=False).launch(
-         server_name=args.host,
-         server_port=args.port,
-         share=args.share,
-         max_threads=args.concurrency_count,
      )
+ import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
  import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoModel, AutoTokenizer
+ from PIL import Image, ExifTags
+
  from threading import Thread
+ import re
+ import time
+ from PIL import Image
+ import torch
+ import spaces
  import subprocess
+ import os

  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

  torch.set_default_device('cuda')

+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
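+ # Per-tile transform: force RGB, bicubic-resize to input_size, then apply
+ # standard ImageNet mean/std normalization.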
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
+
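+ # Choose the tile grid (cols, rows) whose aspect ratio best matches the image;
+ # ties go to the larger grid when the image has enough area to fill it.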
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
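+ # Dynamic tiling: resize the image to the chosen grid, cut it into
+ # image_size x image_size tiles, and optionally append a global thumbnail.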
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
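+ # For example, a 1000x500 image (aspect ratio 2.0) selects the (2, 1) grid:
+ # it is resized to 896x448 and cut into two 448x448 tiles, and with
+ # use_thumbnail=True a third 448x448 global view is appended.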
+ def correct_image_orientation(image_path):
+     # Open the image
+     image = Image.open(image_path)
+
+     # Check the Exif data (if any)
+     try:
+         exif = image._getexif()
+         if exif is not None:
+             for tag, value in exif.items():
+                 if ExifTags.TAGS.get(tag) == "Orientation":
+                     # Fix the orientation based on the Orientation tag
+                     if value == 3:
+                         image = image.rotate(180, expand=True)
+                     elif value == 6:
+                         image = image.rotate(-90, expand=True)
+                     elif value == 8:
+                         image = image.rotate(90, expand=True)
+                     break
+     except Exception as e:
+         print("Could not process Exif data:", e)

+     return image
+
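+ # End-to-end loader: fix Exif orientation, tile with dynamic_preprocess, and
+ # stack the transformed tiles into a (num_tiles, 3, input_size, input_size) tensor.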
+ def load_image(image_file, input_size=448, max_num=12):
+     image = correct_image_orientation(image_file).convert('RGB')
+     print("Image size: ", image.size)
+     transform = build_transform(input_size=input_size)
+     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+     pixel_values = [transform(image) for image in images]
+     pixel_values = torch.stack(pixel_values)
+     return pixel_values

  model = AutoModel.from_pretrained(
+     "5CD-AI/Vintern-1B-v3_5",
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=True,
      trust_remote_code=True,
  ).eval().cuda()
+ tokenizer = AutoTokenizer.from_pretrained("5CD-AI/Vintern-1B-v3_5", trust_remote_code=True, use_fast=False)

+ @spaces.GPU
+ def chat(message, history):
+     print("history", history)
+     print("message", message)
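+
+     # The demo supports a single image, supplied with the first user turn;
+     # follow-up turns reuse that image and replay the chat history.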
+     if len(history) != 0 and len(message["files"]) != 0:
+         # chat is a generator, so the early-exit message must be yielded, not returned
+         yield """Chúng tôi hiện chỉ hỗ trợ 1 ảnh ở đầu ngữ cảnh! Vui lòng tạo mới cuộc trò chuyện.
+ We currently only support one image at the start of the context! Please start a new conversation."""
+         return
+
+     if len(history) == 0 and len(message["files"]) != 0:
+         if "path" in message["files"][0]:
+             test_image = message["files"][0]["path"]
+         else:
+             test_image = message["files"][0]
+         pixel_values = load_image(test_image, max_num=6).to(torch.bfloat16).cuda()
+     elif len(history) == 0 and len(message["files"]) == 0:
+         pixel_values = None
+     elif history[0][0][0] is not None and os.path.isfile(history[0][0][0]):
+         test_image = history[0][0][0]
+         pixel_values = load_image(test_image, max_num=6).to(torch.bfloat16).cuda()
+     else:
+         pixel_values = None
+
+     generation_config = dict(max_new_tokens=700, do_sample=False, num_beams=3, repetition_penalty=2.5)
+
+     if len(history) == 0:
          if pixel_values is not None:
+             question = '<image>\n' + message["text"]
          else:
+             question = message["text"]
          response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
+     else:
+         conv_history = []
+         if history[0][0][0] is not None and os.path.isfile(history[0][0][0]):
+             start_index = 1
+         else:
+             start_index = 0
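+
+         # Rebuild the model.chat-style history: when the first pair holds the
+         # image file, skip it and prepend '<image>\n' to the first replayed turn.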
+         for i, chat_pair in enumerate(history[start_index:]):
+             if i == 0 and start_index == 1:
+                 conv_history.append(tuple(['<image>\n' + chat_pair[0], chat_pair[1]]))
+             else:
+                 conv_history.append(tuple(chat_pair))
+
+         print("conv_history", conv_history)
+         question = message["text"]
+         response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=conv_history, return_history=True)
+
+     print(f'User: {question}\nAssistant: {response}')
+
+     # return response
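+     # Simulated streaming: the full response is already generated; yield it
+     # back character by character to animate typing in the UI.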
+     buffer = ""
+     for new_text in response:
+         buffer += new_text
+         generated_text_without_prompt = buffer[:]
+         time.sleep(0.02)
+         yield generated_text_without_prompt
+
+ CSS = """
+ #component-10 {
+     height: 70dvh !important;
+     transform-origin: top; /* make sure the element expands from the top down */
+     border-style: solid;
+     overflow: hidden;
+     flex-grow: 1;
+     min-width: min(160px, 100%);
+     border-width: var(--block-border-width);
  }
+ /* Make sure images inside buttons render correctly for buttons with the given aria-label */
+ button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] img.svelte-1pijsyv {
+     width: 100%;
+     object-fit: contain;
+     height: 100%;
+     border-radius: 13px; /* round the image corners */
+     max-width: 50vw; /* cap the image width */
  }
+ /* Set the button height and allow text selection only for buttons with the given aria-label */
+ button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] {
+     user-select: text;
+     text-align: left;
+     height: 300px;
+ }
+ /* Round the corners and cap the width of images outside the avatar container */
+ .message-wrap.svelte-1lcyrx4 > div.svelte-1lcyrx4 .svelte-1lcyrx4:not(.avatar-container) img {
+     border-radius: 13px;
+     max-width: 50vw;
+ }
+ .message-wrap.svelte-1lcyrx4 .message.svelte-1lcyrx4 img {
+     margin: var(--size-2);
+     max-height: 500px;
+ }
+ .image-preview-close-button {
+     position: relative; /* in case positioning is needed */
+     width: 5%; /* button width */
+     height: 5%; /* button height */
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     padding: 0; /* avoid interference from the default padding */
+     border: none; /* optional: remove the border */
+     background: none; /* optional: remove the background */
+ }
+ .example-image-container.svelte-9pi8y1 {
+     width: calc(var(--size-8) * 5);
+     height: calc(var(--size-8) * 5);
+     border-radius: var(--radius-lg);
+     overflow: hidden;
+     position: relative;
+     margin-bottom: var(--spacing-lg);
  }
  """

+ js = """
+ function forceLightTheme() {
+     const url = new URL(window.location);
+     // Set __theme to light if it is not already
+     if (url.searchParams.get('__theme') !== 'light') {
+         url.searchParams.set('__theme', 'light');
+         // Update the URL without reloading the page
+         window.history.replaceState({}, '', url.href);
+     }
+     // Make sure the document always applies the light theme
+     document.documentElement.setAttribute('data-theme', 'light');
+ }
+ """
+ from transformers import pipeline

+ pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16, device="cuda:0")

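+ # Whisper transcribes in the detected source language ("task": "transcribe",
+ # as opposed to "translate"), processing the audio in 30-second chunks.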
+ @spaces.GPU
+ def transcribe_speech(filepath):
+     output = pipe(
+         filepath,
+         max_new_tokens=256,
+         generate_kwargs={
+             "task": "transcribe",
+         },
+         chunk_length_s=30,
+         batch_size=1,
      )
+     return output["text"]
+
+ demo = gr.Blocks(css=CSS, js=js, theme='NoCrypt/miku')
+
+ with demo:
+     chat_demo_interface = gr.ChatInterface(
+         fn=chat,
+         description="""**Vintern-1B-v3.5** is the latest model in the Vintern series, bringing major improvements over v2 across all benchmarks. This **continuously fine-tuned version** enhances Vietnamese capabilities while retaining strong English performance. It excels at OCR, text recognition, and Vietnam-specific document understanding.""",
+         examples=[{"text": "Hãy viết một email giới thiệu sản phẩm trong ảnh.", "files": ["./demo_3.jpg"]},
+                   {"text": "Trích xuất các thông tin từ ảnh trả về markdown.", "files": ["./demo_1.jpg"]},
+                   {"text": "Bạn là nhân viên marketing chuyên nghiệp. Hãy viết một bài quảng cáo dài trên mạng xã hội giới thiệu về cửa hàng.", "files": ["./demo_2.jpg"]},
+                   {"text": "Trích xuất thông tin kiện hàng trong ảnh và trả về dạng JSON.", "files": ["./demo_4.jpg"]}],
+         title="❄️ Vintern-1B-v3.5 Demo ❄️",
+         multimodal=True,
+         css=CSS,
+         js=js,
+         theme='NoCrypt/miku'
      )
+
+     # mic_transcribe = gr.Interface(
+     #     fn=transcribe_speech,
+     #     inputs=gr.Audio(sources="microphone", type="filepath", editable=False),
+     #     outputs=gr.components.Textbox(),
+     # )
+
+ # chat_demo_interface.queue()
+ demo.queue().launch()