jiuhai committed on
Commit 01f98ad
1 Parent(s): 54c78cf
Files changed (42)
  1. app.py +394 -394
  2. llava/__pycache__/__init__.cpython-310.pyc +0 -0
  3. llava/__pycache__/constants.cpython-310.pyc +0 -0
  4. llava/__pycache__/conversation.cpython-310.pyc +0 -0
  5. llava/__pycache__/mm_utils.cpython-310.pyc +0 -0
  6. llava/__pycache__/utils.cpython-310.pyc +0 -0
  7. llava/model/__pycache__/__init__.cpython-310.pyc +0 -0
  8. llava/model/__pycache__/builder.cpython-310.pyc +0 -0
  9. llava/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
  10. llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
  11. llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
  12. llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc +0 -0
  13. llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  14. llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  15. llava/model/multimodal_encoder/__pycache__/imagebind.cpython-310.pyc +0 -0
  16. llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-310.pyc +0 -0
  17. llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  18. llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-310.pyc +0 -0
  19. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-310.pyc +0 -0
  20. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-310.pyc +0 -0
  21. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-310.pyc +0 -0
  22. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-310.pyc +0 -0
  23. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-310.pyc +0 -0
  24. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-310.pyc +0 -0
  25. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-310.pyc +0 -0
  26. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-310.pyc +0 -0
  27. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-310.pyc +0 -0
  28. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-310.pyc +0 -0
  29. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
  30. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-310.pyc +0 -0
  31. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-310.pyc +0 -0
  32. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
  33. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-310.pyc +0 -0
  34. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-310.pyc +0 -0
  35. llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-310.pyc +0 -0
  36. llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-310.pyc +0 -0
  37. llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-310.pyc +0 -0
  38. llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-310.pyc +0 -0
  39. llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-310.pyc +0 -0
  40. llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  41. llava/train/__pycache__/llava_trainer.cpython-310.pyc +0 -0
  42. llava/train/__pycache__/train.cpython-310.pyc +0 -0
app.py CHANGED
@@ -1,360 +1,39 @@
1
- import gradio as gr
2
- import os
3
- import torch
4
- import spaces
5
-
6
- from llava import conversation as conversation_lib
7
- from llava.constants import IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
- from llava.conversation import conv_templates, SeparatorStyle
9
- from llava.model.builder import load_pretrained_model
10
- from llava.utils import disable_torch_init
11
- from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
12
-
13
- from PIL import Image
14
- import argparse
15
-
16
-
17
- from transformers import TextIteratorStreamer
18
- from threading import Thread
19
-
20
- import subprocess
21
- # Install flash attention, skipping CUDA build if necessary
22
- subprocess.run(
23
- "pip install flash-attn --no-build-isolation",
24
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
25
- shell=True,
26
- )
27
-
28
- # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
29
- no_change_btn = gr.Button()
30
- enable_btn = gr.Button(interactive=True)
31
- disable_btn = gr.Button(interactive=False)
32
-
33
- argparser = argparse.ArgumentParser()
34
- argparser.add_argument("--model-path", default="umd-vt-nyu/clip-evaclip-und-gen-sft", type=str)
35
- argparser.add_argument("--model-base", type=str, default=None)
36
- argparser.add_argument("--num-gpus", type=int, default=1)
37
- argparser.add_argument("--conv-mode", type=str, default="llama3")
38
- argparser.add_argument("--temperature", type=float, default=0.2)
39
- argparser.add_argument("--max-new-tokens", type=int, default=64)
40
- argparser.add_argument("--num_frames", type=int, default=16)
41
- argparser.add_argument("--load-8bit", action="store_true")
42
- argparser.add_argument("--load-4bit", action="store_true")
43
- argparser.add_argument("--debug", action="store_true")
44
-
45
- args = argparser.parse_args()
46
- model_path = args.model_path
47
- conv_mode = args.conv_mode
48
- filt_invalid="cut"
49
- model_name = get_model_name_from_path(args.model_path)
50
- model_name = 'clip-evaclip-und-gen-sft'
51
- model_kwargs = {
52
- "use_cache": False,
53
- "trust_remote_code": True,
54
- "torch_dtype": torch.bfloat16,
55
- "attn_implementation": "sdpa"
56
- }
57
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map="cuda:0", **model_kwargs)
58
- our_chatbot = None
59
-
60
- def upvote_last_response(state):
61
- return ("",) + (disable_btn,) * 3
62
-
63
-
64
- def downvote_last_response(state):
65
- return ("",) + (disable_btn,) * 3
66
-
67
-
68
- def flag_last_response(state):
69
- return ("",) + (disable_btn,) * 3
70
-
71
- def clear_history():
72
- state =conv_templates[conv_mode].copy()
73
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
74
-
75
- def add_text(state, imagebox, textbox, image_process_mode):
76
- if state is None:
77
- state = conv_templates[conv_mode].copy()
78
-
79
- if imagebox is not None:
80
- textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
81
- image = Image.open(imagebox).convert('RGB')
82
-
83
- if imagebox is not None:
84
- textbox = (textbox, image, image_process_mode)
85
-
86
- state.append_message(state.roles[0], textbox)
87
- state.append_message(state.roles[1], None)
88
-
89
- yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
90
-
91
- def delete_text(state, image_process_mode):
92
- state.messages[-1][-1] = None
93
- prev_human_msg = state.messages[-2]
94
- if type(prev_human_msg[1]) in (tuple, list):
95
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
96
- yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
97
-
98
- def regenerate(state, image_process_mode):
99
- state.messages[-1][-1] = None
100
- prev_human_msg = state.messages[-2]
101
- if type(prev_human_msg[1]) in (tuple, list):
102
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
103
- state.skip_next = False
104
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
105
-
106
- @spaces.GPU
107
- def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
108
- prompt = state.get_prompt()
109
- images = state.get_images(return_pil=True)
110
- #prompt, image_args = process_image(prompt, images)
111
-
112
- ori_prompt = prompt
113
- num_image_tokens = 0
114
-
115
- if images is not None and len(images) > 0:
116
- if len(images) > 0:
117
- if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
118
- raise ValueError("Number of images does not match number of <image> tokens in prompt")
119
-
120
- #images = [load_image_from_base64(image) for image in images]
121
- image_sizes = [image.size for image in images]
122
-
123
- images = process_images(images, image_processor, model.config)
124
-
125
- if type(images) is list:
126
- images = [image.to(model.device, dtype=torch.float16) for image in images]
127
- else:
128
- images = images.to(model.device, dtype=torch.float16)
129
- else:
130
- images = None
131
- image_sizes = None
132
- image_args = {"images": images, "image_sizes": image_sizes}
133
- else:
134
- images = None
135
- image_args = {}
136
-
137
- max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
138
- max_new_tokens = 512
139
- do_sample = True if temperature > 0.001 else False
140
- stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
141
-
142
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX, return_tensors='pt').unsqueeze(0).to(model.device)
143
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
144
-
145
- max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
146
-
147
- if max_new_tokens < 1:
148
- # yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
149
- return
150
-
151
- thread = Thread(target=model.generate, kwargs=dict(
152
- inputs=input_ids,
153
- do_sample=do_sample,
154
- temperature=temperature,
155
- top_p=top_p,
156
- max_new_tokens=max_new_tokens,
157
- streamer=streamer,
158
- use_cache=True,
159
- pad_token_id=tokenizer.eos_token_id,
160
- **image_args
161
- ))
162
- thread.start()
163
- generated_text = ''
164
- for new_text in streamer:
165
- generated_text += new_text
166
- if generated_text.endswith(stop_str):
167
- generated_text = generated_text[:-len(stop_str)]
168
- state.messages[-1][-1] = generated_text
169
- yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
170
-
171
- yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
172
-
173
- torch.cuda.empty_cache()
174
-
175
- txt = gr.Textbox(
176
- scale=4,
177
- show_label=False,
178
- placeholder="Enter text and press enter.",
179
- container=False,
180
- )
181
-
182
-
183
- title_markdown = ("""
184
- # Florence-llama
185
- [[Code](TBD)] [[Model](TBD)] | 📚 [[Arxiv](TBD)]]
186
- """)
187
-
188
- # title_markdown = ("""
189
- # # Florence-llama
190
- # """)
191
-
192
- tos_markdown = ("""
193
- ### Terms of use
194
- By using this service, users are required to agree to the following terms:
195
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
196
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
197
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
198
- """)
199
-
200
-
201
- learn_more_markdown = ("""
202
- ### License
203
- The service is a research preview intended for non-commercial use only, subject to the. Please contact us if you find any potential violation.
204
- """)
205
-
206
- block_css = """
207
- #buttons button {
208
- min-width: min(120px,100%);
209
- }
210
- """
211
-
212
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
213
- with gr.Blocks(title="llava", theme=gr.themes.Default(), css=block_css) as demo:
214
- state = gr.State()
215
-
216
- gr.Markdown(title_markdown)
217
-
218
- with gr.Row():
219
- with gr.Column(scale=3):
220
- imagebox = gr.Image(label="Input Image", type="filepath")
221
- image_process_mode = gr.Radio(
222
- ["Crop", "Resize", "Pad", "Default"],
223
- value="Default",
224
- label="Preprocess for non-square image", visible=False)
225
-
226
- cur_dir = os.path.dirname(os.path.abspath(__file__))
227
- # gr.Examples(examples=[
228
- # [f"{cur_dir}/assets/health-insurance.png", "Under which circumstances do I need to be enrolled in mandatory health insurance if I am an international student?"],
229
- # [f"{cur_dir}/assets/leasing-apartment.png", "I don't have any 3rd party renter's insurance now. Do I need to get one for myself?"],
230
- # [f"{cur_dir}/assets/nvidia.jpeg", "Who is the person in the middle?"],
231
- # [f"{cur_dir}/assets/animal-compare.png", "Are these two pictures showing the same kind of animal?"],
232
- # [f"{cur_dir}/assets/georgia-tech.jpeg", "Where is this photo taken?"]
233
- # ], inputs=[imagebox, textbox], cache_examples=False)
234
-
235
-
236
-
237
-
238
- gr.Examples(examples=[
239
- [f"{cur_dir}/assets/animal-compare.png", "Provide a detailed description of the given image."]
240
- ], inputs=[imagebox, textbox], cache_examples=False)
241
-
242
- with gr.Accordion("Parameters", open=False) as parameter_row:
243
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
244
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
245
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
246
-
247
- with gr.Column(scale=8):
248
- chatbot = gr.Chatbot(
249
- elem_id="chatbot",
250
- label="llava Chatbot",
251
- height=650,
252
- layout="panel",
253
- )
254
- with gr.Row():
255
- with gr.Column(scale=8):
256
- textbox.render()
257
- with gr.Column(scale=1, min_width=50):
258
- submit_btn = gr.Button(value="Send", variant="primary")
259
- with gr.Row(elem_id="buttons") as button_row:
260
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
261
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
262
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
263
- stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
264
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
265
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
266
-
267
- gr.Markdown(tos_markdown)
268
- gr.Markdown(learn_more_markdown)
269
- url_params = gr.JSON(visible=False)
270
-
271
- # Register listeners
272
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
273
- upvote_btn.click(
274
- upvote_last_response,
275
- [state],
276
- [textbox, upvote_btn, downvote_btn, flag_btn]
277
- )
278
- downvote_btn.click(
279
- downvote_last_response,
280
- [state],
281
- [textbox, upvote_btn, downvote_btn, flag_btn]
282
- )
283
- flag_btn.click(
284
- flag_last_response,
285
- [state],
286
- [textbox, upvote_btn, downvote_btn, flag_btn]
287
- )
288
-
289
- clear_btn.click(
290
- clear_history,
291
- None,
292
- [state, chatbot, textbox, imagebox] + btn_list,
293
- queue=False
294
- )
295
-
296
- regenerate_btn.click(
297
- delete_text,
298
- [state, image_process_mode],
299
- [state, chatbot, textbox, imagebox] + btn_list,
300
- ).then(
301
- generate,
302
- [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
303
- [state, chatbot, textbox, imagebox] + btn_list,
304
- )
305
- textbox.submit(
306
- add_text,
307
- [state, imagebox, textbox, image_process_mode],
308
- [state, chatbot, textbox, imagebox] + btn_list,
309
- ).then(
310
- generate,
311
- [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
312
- [state, chatbot, textbox, imagebox] + btn_list,
313
- )
314
-
315
- submit_btn.click(
316
- add_text,
317
- [state, imagebox, textbox, image_process_mode],
318
- [state, chatbot, textbox, imagebox] + btn_list,
319
- ).then(
320
- generate,
321
- [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
322
- [state, chatbot, textbox, imagebox] + btn_list,
323
- )
324
-
325
- demo.queue(
326
- status_update_rate=10,
327
- api_open=False
328
- ).launch()
329
-
330
-
331
-
332
-
333
-
334
-
335
-
336
-
337
-
338
  # import gradio as gr
339
  # import os
340
  # import torch
341
- # import argparse
342
- # from transformers import TextIteratorStreamer
343
- # from threading import Thread
344
- # from PIL import Image
345
  # from llava import conversation as conversation_lib
346
- # from llava.constants import *
347
  # from llava.conversation import conv_templates, SeparatorStyle
348
  # from llava.model.builder import load_pretrained_model
349
  # from llava.utils import disable_torch_init
350
  # from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
351
- # from diffusers import DiffusionPipeline
352
 
353
- # # Define paths and configurations
354
- # # diffusion_path = "/export/jchen169/hub/models--BAAI--Emu2-Gen/snapshots/a41a2dcd777a68225dddc72c7213b064ee06f4a0"
355
 
356
  # argparser = argparse.ArgumentParser()
357
- # argparser.add_argument("--model-path", default="umd-vt-nyu/clip-evaclip-und-gen-sft-3v", type=str)
 
 
358
  # argparser.add_argument("--conv-mode", type=str, default="llama3")
359
  # argparser.add_argument("--temperature", type=float, default=0.2)
360
  # argparser.add_argument("--max-new-tokens", type=int, default=64)
@@ -362,48 +41,45 @@ demo.queue(
362
  # argparser.add_argument("--load-8bit", action="store_true")
363
  # argparser.add_argument("--load-4bit", action="store_true")
364
  # argparser.add_argument("--debug", action="store_true")
365
- # args = argparser.parse_args()
366
 
367
- # # Load LLaVA model
368
- # disable_torch_init()
 
 
369
  # model_name = get_model_name_from_path(args.model_path)
370
- # tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, None, model_name)
 
371
  # our_chatbot = None
372
 
373
- # # Load Diffusion model for image generation
374
- # pipe = DiffusionPipeline.from_pretrained(
375
- # 'BAAI/Emu2-Gen',
376
- # custom_pipeline="pipeline_llava_gen",
377
- # torch_dtype=torch.bfloat16,
378
- # use_safetensors=True,
379
- # variant="bf16",
380
- # multimodal_encoder=model,
381
- # tokenizer=tokenizer,
382
- # )
383
- # pipe.vae.to("cuda:0")
384
- # pipe.unet.to("cuda:0")
385
- # pipe.safety_checker.to("cuda:0")
386
-
387
  # def upvote_last_response(state):
388
  # return ("",) + (disable_btn,) * 3
389
 
 
390
  # def downvote_last_response(state):
391
  # return ("",) + (disable_btn,) * 3
392
 
 
393
  # def flag_last_response(state):
394
  # return ("",) + (disable_btn,) * 3
395
 
396
  # def clear_history():
397
- # state = conv_templates[conv_mode].copy()
398
  # return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
399
 
400
  # def add_text(state, imagebox, textbox, image_process_mode):
401
  # if state is None:
402
  # state = conv_templates[conv_mode].copy()
403
-
404
  # if imagebox is not None:
405
  # textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
406
  # image = Image.open(imagebox).convert('RGB')
 
407
  # if imagebox is not None:
408
  # textbox = (textbox, image, image_process_mode)
409
 
@@ -412,9 +88,27 @@ demo.queue(
412
 
413
  # yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
414
 
415
  # def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
416
  # prompt = state.get_prompt()
417
  # images = state.get_images(return_pil=True)
 
 
418
  # ori_prompt = prompt
419
  # num_image_tokens = 0
420
 
@@ -423,8 +117,11 @@ demo.queue(
423
  # if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
424
  # raise ValueError("Number of images does not match number of <image> tokens in prompt")
425
 
 
426
  # image_sizes = [image.size for image in images]
 
427
  # images = process_images(images, image_processor, model.config)
 
428
  # if type(images) is list:
429
  # images = [image.to(model.device, dtype=torch.float16) for image in images]
430
  # else:
@@ -444,9 +141,11 @@ demo.queue(
444
 
445
  # input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX, return_tensors='pt').unsqueeze(0).to(model.device)
446
  # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
447
  # max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
448
 
449
  # if max_new_tokens < 1:
 
450
  # return
451
 
452
  # thread = Thread(target=model.generate, kwargs=dict(
@@ -470,25 +169,51 @@ demo.queue(
470
  # yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
471
 
472
  # yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
 
473
  # torch.cuda.empty_cache()
474
 
475
- # def add_template(prompt):
476
- # conv = conv_templates['llama3'].copy()
477
- # conv.append_message(conv.roles[0], prompt[0])
478
- # conv.append_message(conv.roles[1], None)
479
- # prompt = conv.get_prompt()
480
481
 
482
 
483
- # def generate_image(prompt):
484
- # prompt = add_template(prompt)
485
- # gen_img = pipe(prompt, guidance_scale=3.0)
486
- # return gen_img.image
 
487
 
488
- # # Interface setup
489
- # with gr.Blocks(title="LLaVA Chatbot with Image Generation") as demo:
490
  # state = gr.State()
491
- # gr.Markdown("# LLaVA Chatbot with Image Generation")
 
492
 
493
  # with gr.Row():
494
  # with gr.Column(scale=3):
@@ -497,24 +222,299 @@ demo.queue(
497
  # ["Crop", "Resize", "Pad", "Default"],
498
  # value="Default",
499
  # label="Preprocess for non-square image", visible=False)
500
- # temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
501
- # top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
502
- # max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
503
- # with gr.Column(scale=8):
504
- # chatbot = gr.Chatbot(label="LLaVA Chatbot", height=650, layout="panel")
505
- # textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
506
- # submit_btn = gr.Button(value="Send", variant="primary")
507
 
508
- # with gr.Row() as button_row:
509
  # clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
510
-
511
- # # Define actions
512
  # submit_btn.click(
513
- # lambda state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens: (
514
- # generate_image([textbox]) if "generate image" in textbox.lower() else add_text(
515
- # state, imagebox, textbox, image_process_mode)),
 
 
516
  # [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
517
- # [state, chatbot, textbox, imagebox]
518
  # )
 
519
 
520
- # demo.queue(status_update_rate=10, api_open=False).launch()
1
  # import gradio as gr
2
  # import os
3
  # import torch
4
+ # import spaces
5
+
 
 
6
  # from llava import conversation as conversation_lib
7
+ # from llava.constants import IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
  # from llava.conversation import conv_templates, SeparatorStyle
9
  # from llava.model.builder import load_pretrained_model
10
  # from llava.utils import disable_torch_init
11
  # from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
 
12
 
13
+ # from PIL import Image
14
+ # import argparse
15
+
16
+
17
+ # from transformers import TextIteratorStreamer
18
+ # from threading import Thread
19
+
20
+ # import subprocess
21
+ # # Install flash attention, skipping CUDA build if necessary
22
+ # subprocess.run(
23
+ # "pip install flash-attn --no-build-isolation",
24
+ # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
25
+ # shell=True,
26
+ # )
27
+
28
+ # # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
29
+ # no_change_btn = gr.Button()
30
+ # enable_btn = gr.Button(interactive=True)
31
+ # disable_btn = gr.Button(interactive=False)
32
 
33
  # argparser = argparse.ArgumentParser()
34
+ # argparser.add_argument("--model-path", default="umd-vt-nyu/clip-evaclip-und-gen-sft", type=str)
35
+ # argparser.add_argument("--model-base", type=str, default=None)
36
+ # argparser.add_argument("--num-gpus", type=int, default=1)
37
  # argparser.add_argument("--conv-mode", type=str, default="llama3")
38
  # argparser.add_argument("--temperature", type=float, default=0.2)
39
  # argparser.add_argument("--max-new-tokens", type=int, default=64)
 
41
  # argparser.add_argument("--load-8bit", action="store_true")
42
  # argparser.add_argument("--load-4bit", action="store_true")
43
  # argparser.add_argument("--debug", action="store_true")
 
44
 
45
+ # args = argparser.parse_args()
46
+ # model_path = args.model_path
47
+ # conv_mode = args.conv_mode
48
+ # filt_invalid="cut"
49
  # model_name = get_model_name_from_path(args.model_path)
50
+ # model_name = 'clip-evaclip-und-gen-sft'
51
+ # model_kwargs = {
52
+ # "use_cache": False,
53
+ # "trust_remote_code": True,
54
+ # "torch_dtype": torch.bfloat16,
55
+ # "attn_implementation": "sdpa"
56
+ # }
57
+ # tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map="cuda:0", **model_kwargs)
58
  # our_chatbot = None
59
 
 
60
  # def upvote_last_response(state):
61
  # return ("",) + (disable_btn,) * 3
62
 
63
+
64
  # def downvote_last_response(state):
65
  # return ("",) + (disable_btn,) * 3
66
 
67
+
68
  # def flag_last_response(state):
69
  # return ("",) + (disable_btn,) * 3
70
 
71
  # def clear_history():
72
+ # state =conv_templates[conv_mode].copy()
73
  # return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
74
 
75
  # def add_text(state, imagebox, textbox, image_process_mode):
76
  # if state is None:
77
  # state = conv_templates[conv_mode].copy()
78
+
79
  # if imagebox is not None:
80
  # textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
81
  # image = Image.open(imagebox).convert('RGB')
82
+
83
  # if imagebox is not None:
84
  # textbox = (textbox, image, image_process_mode)
85
 
 
88
 
89
  # yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
90
 
91
+ # def delete_text(state, image_process_mode):
92
+ # state.messages[-1][-1] = None
93
+ # prev_human_msg = state.messages[-2]
94
+ # if type(prev_human_msg[1]) in (tuple, list):
95
+ # prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
96
+ # yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
97
+
98
+ # def regenerate(state, image_process_mode):
99
+ # state.messages[-1][-1] = None
100
+ # prev_human_msg = state.messages[-2]
101
+ # if type(prev_human_msg[1]) in (tuple, list):
102
+ # prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
103
+ # state.skip_next = False
104
+ # return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
105
+
106
+ # @spaces.GPU
107
  # def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
108
  # prompt = state.get_prompt()
109
  # images = state.get_images(return_pil=True)
110
+ # #prompt, image_args = process_image(prompt, images)
111
+
112
  # ori_prompt = prompt
113
  # num_image_tokens = 0
114
 
 
117
  # if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
118
  # raise ValueError("Number of images does not match number of <image> tokens in prompt")
119
 
120
+ # #images = [load_image_from_base64(image) for image in images]
121
  # image_sizes = [image.size for image in images]
122
+
123
  # images = process_images(images, image_processor, model.config)
124
+
125
  # if type(images) is list:
126
  # images = [image.to(model.device, dtype=torch.float16) for image in images]
127
  # else:
 
141
 
142
  # input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX, return_tensors='pt').unsqueeze(0).to(model.device)
143
  # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
144
+
145
  # max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
146
 
147
  # if max_new_tokens < 1:
148
+ # # yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
149
  # return
150
 
151
  # thread = Thread(target=model.generate, kwargs=dict(
 
169
  # yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
170
 
171
  # yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
172
+
173
  # torch.cuda.empty_cache()
174
 
175
+ # txt = gr.Textbox(
176
+ # scale=4,
177
+ # show_label=False,
178
+ # placeholder="Enter text and press enter.",
179
+ # container=False,
180
+ # )
181
+
182
+
183
+ # title_markdown = ("""
184
+ # # Florence-llama
185
+ # [[Code](TBD)] [[Model](TBD)] | 📚 [[Arxiv](TBD)]]
186
+ # """)
187
+
188
+ # # title_markdown = ("""
189
+ # # # Florence-llama
190
+ # # """)
191
 
192
+ # tos_markdown = ("""
193
+ # ### Terms of use
194
+ # By using this service, users are required to agree to the following terms:
195
+ # The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
196
+ # Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
197
+ # For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
198
+ # """)
199
+
200
+
201
+ # learn_more_markdown = ("""
202
+ # ### License
203
+ # The service is a research preview intended for non-commercial use only, subject to the. Please contact us if you find any potential violation.
204
+ # """)
205
 
206
+ # block_css = """
207
+ # #buttons button {
208
+ # min-width: min(120px,100%);
209
+ # }
210
+ # """
211
 
212
+ # textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
213
+ # with gr.Blocks(title="llava", theme=gr.themes.Default(), css=block_css) as demo:
214
  # state = gr.State()
215
+
216
+ # gr.Markdown(title_markdown)
217
 
218
  # with gr.Row():
219
  # with gr.Column(scale=3):
 
222
  # ["Crop", "Resize", "Pad", "Default"],
223
  # value="Default",
224
  # label="Preprocess for non-square image", visible=False)
225
+
226
+ # cur_dir = os.path.dirname(os.path.abspath(__file__))
227
+ # # gr.Examples(examples=[
228
+ # # [f"{cur_dir}/assets/health-insurance.png", "Under which circumstances do I need to be enrolled in mandatory health insurance if I am an international student?"],
229
+ # # [f"{cur_dir}/assets/leasing-apartment.png", "I don't have any 3rd party renter's insurance now. Do I need to get one for myself?"],
230
+ # # [f"{cur_dir}/assets/nvidia.jpeg", "Who is the person in the middle?"],
231
+ # # [f"{cur_dir}/assets/animal-compare.png", "Are these two pictures showing the same kind of animal?"],
232
+ # # [f"{cur_dir}/assets/georgia-tech.jpeg", "Where is this photo taken?"]
233
+ # # ], inputs=[imagebox, textbox], cache_examples=False)
234
 
235
+
236
+
237
+
238
+ # gr.Examples(examples=[
239
+ # [f"{cur_dir}/assets/animal-compare.png", "Provide a detailed description of the given image."]
240
+ # ], inputs=[imagebox, textbox], cache_examples=False)
241
+
242
+ # with gr.Accordion("Parameters", open=False) as parameter_row:
243
+ # temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
244
+ # top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
245
+ # max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
246
+
247
+ # with gr.Column(scale=8):
248
+ # chatbot = gr.Chatbot(
249
+ # elem_id="chatbot",
250
+ # label="llava Chatbot",
251
+ # height=650,
252
+ # layout="panel",
253
+ # )
254
+ # with gr.Row():
255
+ # with gr.Column(scale=8):
256
+ # textbox.render()
257
+ # with gr.Column(scale=1, min_width=50):
258
+ # submit_btn = gr.Button(value="Send", variant="primary")
259
+ # with gr.Row(elem_id="buttons") as button_row:
260
+ # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
261
+ # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
262
+ # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
263
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
264
+ # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
265
  # clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
266
+
267
+ # gr.Markdown(tos_markdown)
268
+ # gr.Markdown(learn_more_markdown)
269
+ # url_params = gr.JSON(visible=False)
270
+
271
+ # # Register listeners
272
+ # btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
273
+ # upvote_btn.click(
274
+ # upvote_last_response,
275
+ # [state],
276
+ # [textbox, upvote_btn, downvote_btn, flag_btn]
277
+ # )
278
+ # downvote_btn.click(
279
+ # downvote_last_response,
280
+ # [state],
281
+ # [textbox, upvote_btn, downvote_btn, flag_btn]
282
+ # )
283
+ # flag_btn.click(
284
+ # flag_last_response,
285
+ # [state],
286
+ # [textbox, upvote_btn, downvote_btn, flag_btn]
287
+ # )
288
+
289
+ # clear_btn.click(
290
+ # clear_history,
291
+ # None,
292
+ # [state, chatbot, textbox, imagebox] + btn_list,
293
+ # queue=False
294
+ # )
295
+
296
+ # regenerate_btn.click(
297
+ # delete_text,
298
+ # [state, image_process_mode],
299
+ # [state, chatbot, textbox, imagebox] + btn_list,
300
+ # ).then(
301
+ # generate,
302
+ # [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
303
+ # [state, chatbot, textbox, imagebox] + btn_list,
304
+ # )
305
+ # textbox.submit(
306
+ # add_text,
307
+ # [state, imagebox, textbox, image_process_mode],
308
+ # [state, chatbot, textbox, imagebox] + btn_list,
309
+ # ).then(
310
+ # generate,
311
+ # [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
312
+ # [state, chatbot, textbox, imagebox] + btn_list,
313
+ # )
314
+
315
  # submit_btn.click(
316
+ # add_text,
317
+ # [state, imagebox, textbox, image_process_mode],
318
+ # [state, chatbot, textbox, imagebox] + btn_list,
319
+ # ).then(
320
+ # generate,
321
  # [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
322
+ # [state, chatbot, textbox, imagebox] + btn_list,
323
  # )
324
+
325
+ # demo.queue(
326
+ # status_update_rate=10,
327
+ # api_open=False
328
+ # ).launch()
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+ import gradio as gr
339
+ import os
340
+ import torch
341
+ import argparse
342
+ from transformers import TextIteratorStreamer
343
+ from threading import Thread
344
+ from PIL import Image
345
+ from llava import conversation as conversation_lib
346
+ from llava.constants import *
347
+ from llava.conversation import conv_templates, SeparatorStyle
348
+ from llava.model.builder import load_pretrained_model
349
+ from llava.utils import disable_torch_init
350
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
351
+ from diffusers import DiffusionPipeline
352
+
353
+ # Define paths and configurations
354
+ # diffusion_path = "/export/jchen169/hub/models--BAAI--Emu2-Gen/snapshots/a41a2dcd777a68225dddc72c7213b064ee06f4a0"
355
+
356
+ argparser = argparse.ArgumentParser()
357
+ argparser.add_argument("--model-path", default="umd-vt-nyu/clip-evaclip-und-gen-sft-3v", type=str)
358
+ argparser.add_argument("--conv-mode", type=str, default="llama3")
359
+ argparser.add_argument("--temperature", type=float, default=0.2)
360
+ argparser.add_argument("--max-new-tokens", type=int, default=64)
361
+ argparser.add_argument("--num_frames", type=int, default=16)
362
+ argparser.add_argument("--load-8bit", action="store_true")
363
+ argparser.add_argument("--load-4bit", action="store_true")
364
+ argparser.add_argument("--debug", action="store_true")
365
+ args = argparser.parse_args()
366
+
367
+ # Load LLaVA model
368
+ disable_torch_init()
369
+ model_name = get_model_name_from_path(args.model_path)
370
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, None, model_name)
371
+ our_chatbot = None
372
+
373
+ # Load Diffusion model for image generation
374
+ pipe = DiffusionPipeline.from_pretrained(
375
+ 'BAAI/Emu2-Gen',
376
+ custom_pipeline="pipeline_llava_gen",
377
+ torch_dtype=torch.bfloat16,
378
+ use_safetensors=True,
379
+ variant="bf16",
380
+ multimodal_encoder=model,
381
+ tokenizer=tokenizer,
382
+ )
383
+ pipe.vae.to("cuda:0")
384
+ pipe.unet.to("cuda:0")
385
+ pipe.safety_checker.to("cuda:0")
386
+
387
+ def upvote_last_response(state):
388
+ return ("",) + (disable_btn,) * 3
389
+
390
+ def downvote_last_response(state):
391
+ return ("",) + (disable_btn,) * 3
392
+
393
+ def flag_last_response(state):
394
+ return ("",) + (disable_btn,) * 3
395
+
396
+ def clear_history():
397
+ state = conv_templates[conv_mode].copy()
398
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
399
+
400
+ def add_text(state, imagebox, textbox, image_process_mode):
401
+ if state is None:
402
+ state = conv_templates[conv_mode].copy()
403
+
404
+ if imagebox is not None:
405
+ textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
406
+ image = Image.open(imagebox).convert('RGB')
407
+ if imagebox is not None:
408
+ textbox = (textbox, image, image_process_mode)
409
+
410
+ state.append_message(state.roles[0], textbox)
411
+ state.append_message(state.roles[1], None)
412
+
413
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
414
+
415
+ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
416
+ prompt = state.get_prompt()
417
+ images = state.get_images(return_pil=True)
418
+ ori_prompt = prompt
419
+ num_image_tokens = 0
420
+
421
+ if images is not None and len(images) > 0:
422
+ if len(images) > 0:
423
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
424
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
425
+
426
+ image_sizes = [image.size for image in images]
427
+ images = process_images(images, image_processor, model.config)
428
+ if type(images) is list:
429
+ images = [image.to(model.device, dtype=torch.float16) for image in images]
430
+ else:
431
+ images = images.to(model.device, dtype=torch.float16)
432
+ else:
433
+ images = None
434
+ image_sizes = None
435
+ image_args = {"images": images, "image_sizes": image_sizes}
436
+ else:
437
+ images = None
438
+ image_args = {}
439
+
440
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
441
+ max_new_tokens = 512
442
+ do_sample = True if temperature > 0.001 else False
443
+ stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
444
+
445
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX, return_tensors='pt').unsqueeze(0).to(model.device)
446
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
447
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
448
+
449
+ if max_new_tokens < 1:
450
+ return
451
+
452
+ thread = Thread(target=model.generate, kwargs=dict(
453
+ inputs=input_ids,
454
+ do_sample=do_sample,
455
+ temperature=temperature,
456
+ top_p=top_p,
457
+ max_new_tokens=max_new_tokens,
458
+ streamer=streamer,
459
+ use_cache=True,
460
+ pad_token_id=tokenizer.eos_token_id,
461
+ **image_args
462
+ ))
463
+ thread.start()
464
+ generated_text = ''
465
+ for new_text in streamer:
466
+ generated_text += new_text
467
+ if generated_text.endswith(stop_str):
468
+ generated_text = generated_text[:-len(stop_str)]
469
+ state.messages[-1][-1] = generated_text
470
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
471
+
472
+ yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
473
+ torch.cuda.empty_cache()
474
+
475
+ def add_template(prompt):
476
+ conv = conv_templates['llama3'].copy()
477
+ conv.append_message(conv.roles[0], prompt[0])
478
+ conv.append_message(conv.roles[1], None)
479
+ prompt = conv.get_prompt()
480
+ return [prompt]
481
+
482
+
483
+ def generate_image(prompt):
484
+ prompt = add_template(prompt)
485
+ gen_img = pipe(prompt, guidance_scale=3.0)
486
+ return gen_img.image
487
+
488
+ # Interface setup
489
+ with gr.Blocks(title="LLaVA Chatbot with Image Generation") as demo:
490
+ state = gr.State()
491
+ gr.Markdown("# LLaVA Chatbot with Image Generation")
492
+
493
+ with gr.Row():
494
+ with gr.Column(scale=3):
495
+ imagebox = gr.Image(label="Input Image", type="filepath")
496
+ image_process_mode = gr.Radio(
497
+ ["Crop", "Resize", "Pad", "Default"],
498
+ value="Default",
499
+ label="Preprocess for non-square image", visible=False)
500
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
501
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
502
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
503
+ with gr.Column(scale=8):
504
+ chatbot = gr.Chatbot(label="LLaVA Chatbot", height=650, layout="panel")
505
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
506
+ submit_btn = gr.Button(value="Send", variant="primary")
507
+
508
+ with gr.Row() as button_row:
509
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
510
+
511
+ # Define actions
512
+ submit_btn.click(
513
+ lambda state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens: (
514
+ generate_image([textbox]) if "generate image" in textbox.lower() else add_text(
515
+ state, imagebox, textbox, image_process_mode)),
516
+ [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
517
+ [state, chatbot, textbox, imagebox]
518
+ )
519
 
520
+ demo.queue(status_update_rate=10, api_open=False).launch()
llava/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/llava/__pycache__/__init__.cpython-310.pyc and b/llava/__pycache__/__init__.cpython-310.pyc differ
 
llava/__pycache__/constants.cpython-310.pyc CHANGED
Binary files a/llava/__pycache__/constants.cpython-310.pyc and b/llava/__pycache__/constants.cpython-310.pyc differ
 
llava/__pycache__/conversation.cpython-310.pyc CHANGED
Binary files a/llava/__pycache__/conversation.cpython-310.pyc and b/llava/__pycache__/conversation.cpython-310.pyc differ
 
llava/__pycache__/mm_utils.cpython-310.pyc CHANGED
Binary files a/llava/__pycache__/mm_utils.cpython-310.pyc and b/llava/__pycache__/mm_utils.cpython-310.pyc differ
 
llava/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/llava/__pycache__/utils.cpython-310.pyc and b/llava/__pycache__/utils.cpython-310.pyc differ
 
llava/model/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/llava/model/__pycache__/__init__.cpython-310.pyc and b/llava/model/__pycache__/__init__.cpython-310.pyc differ
 
llava/model/__pycache__/builder.cpython-310.pyc CHANGED
Binary files a/llava/model/__pycache__/builder.cpython-310.pyc and b/llava/model/__pycache__/builder.cpython-310.pyc differ
 
llava/model/__pycache__/llava_arch.cpython-310.pyc CHANGED
Binary files a/llava/model/__pycache__/llava_arch.cpython-310.pyc and b/llava/model/__pycache__/llava_arch.cpython-310.pyc differ
 
llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc CHANGED
Binary files a/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc and b/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc differ
 
llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc CHANGED
Binary files a/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc and b/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc differ
 
llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc CHANGED
Binary files a/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc and b/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc differ
 
llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc and b/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc differ
 
llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc and b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc differ
 
llava/model/multimodal_encoder/__pycache__/imagebind.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/__pycache__/imagebind.cpython-310.pyc and b/llava/model/multimodal_encoder/__pycache__/imagebind.cpython-310.pyc differ
 
llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-310.pyc and b/llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-310.pyc differ
 
llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc and b/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-310.pyc differ
 
llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-310.pyc and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-310.pyc differ
 
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-310.pyc and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-310.pyc differ
 
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-310.pyc and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-310.pyc differ
 
llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-310.pyc and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-310.pyc differ
 
llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-310.pyc and b/llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-310.pyc differ
 
llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc CHANGED
Binary files a/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc and b/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc differ
 
llava/train/__pycache__/llava_trainer.cpython-310.pyc CHANGED
Binary files a/llava/train/__pycache__/llava_trainer.cpython-310.pyc and b/llava/train/__pycache__/llava_trainer.cpython-310.pyc differ
 
llava/train/__pycache__/train.cpython-310.pyc CHANGED
Binary files a/llava/train/__pycache__/train.cpython-310.pyc and b/llava/train/__pycache__/train.cpython-310.pyc differ