jadechoghari commited on
Commit
151137d
β€’
1 Parent(s): e3feb3f

working app

Browse files
8b23f327b90b6211049acd36e3f99975.jpg DELETED
Binary file (24.4 kB)
 
=0.26.0 DELETED
@@ -1,39 +0,0 @@
1
- Collecting accelerate
2
- Downloading accelerate-1.0.1-py3-none-any.whl (330 kB)
3
- ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 330.9/330.9 kB 10.6 MB/s eta 0:00:00
4
- Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/site-packages (from accelerate) (2.5.0)
5
- Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/site-packages (from accelerate) (2.1.2)
6
- Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/site-packages (from accelerate) (0.26.0)
7
- Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/site-packages (from accelerate) (6.0.2)
8
- Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from accelerate) (24.1)
9
- Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/site-packages (from accelerate) (0.4.5)
10
- Requirement already satisfied: psutil in /usr/local/lib/python3.10/site-packages (from accelerate) (5.9.8)
11
- Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)
12
- Requirement already satisfied: requests in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)
13
- Requirement already satisfied: filelock in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)
14
- Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)
15
- Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)
16
- Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
17
- Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.0)
18
- Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.6.1.9)
19
- Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.21.5)
20
- Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.3.1.170)
21
- Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.4)
22
- Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
23
- Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
24
- Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (9.1.0.70)
25
- Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
26
- Requirement already satisfied: networkx in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.4.1)
27
- Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.127)
28
- Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.4.5.8)
29
- Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.1)
30
- Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.2.1.3)
31
- Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.5.147)
32
- Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)
33
- Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)
34
- Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)
35
- Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)
36
- Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.2.3)
37
- Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)
38
- Installing collected packages: accelerate
39
- Successfully installed accelerate-1.0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__pycache__/cli.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
__pycache__/constants.cpython-310.pyc CHANGED
Binary files a/__pycache__/constants.cpython-310.pyc and b/__pycache__/constants.cpython-310.pyc differ
 
__pycache__/gradio_css.cpython-310.pyc CHANGED
Binary files a/__pycache__/gradio_css.cpython-310.pyc and b/__pycache__/gradio_css.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,525 +1,135 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
-
7
  import gradio as gr
8
- import requests
9
  from inference import inference_and_run
10
- from conversation import (default_conversation, conv_templates,
11
- SeparatorStyle)
12
-
13
- LOGDIR = "."
14
- from utils import (build_logger, server_error_msg,
15
- violates_moderation, moderation_msg)
16
- import hashlib
17
  import spaces
 
 
 
18
 
19
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
20
-
21
- headers = {"User-Agent": "LLaVA Client"}
22
-
23
- no_change_btn = gr.Button()
24
- enable_btn = gr.Button(interactive=True)
25
- disable_btn = gr.Button(interactive=False)
26
-
27
- priority = {
28
- "vicuna-13b": "aaaaaaa",
29
- "koala-13b": "aaaaaab",
30
- }
31
-
32
-
33
- def get_conv_log_filename():
34
- t = datetime.datetime.now()
35
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
36
- return name
37
-
38
-
39
- def get_model_list():
40
- # ret = requests.post(args.controller_url + "/refresh_all_workers")
41
- # assert ret.status_code == 200
42
- # ret = requests.post(args.controller_url + "/list_models")
43
- # models = ret.json()["models"]
44
- # models.sort(key=lambda x: priority.get(x, x))
45
- # logger.info(f"Models: {models}")
46
- # return models
47
- models = ["jadechoghari/Ferret-UI-Gemma2b"]
48
- logger.info(f"Models: {models}")
49
- return models
50
-
51
- get_window_url_params = """
52
- function() {
53
- const params = new URLSearchParams(window.location.search);
54
- url_params = Object.fromEntries(params);
55
- console.log(url_params);
56
- return url_params;
57
- }
58
- """
59
-
60
-
61
- def load_demo(url_params, request: gr.Request):
62
-
63
- dropdown_update = gr.Dropdown(visible=True)
64
- if "model" in url_params:
65
- model = url_params["model"]
66
- if model in models:
67
- dropdown_update = gr.Dropdown(value=model, visible=True)
68
-
69
- state = default_conversation.copy()
70
- return state, dropdown_update
71
-
72
-
73
- def load_demo_refresh_model_list(request: gr.Request):
74
- models = get_model_list()
75
- state = default_conversation.copy()
76
- dropdown_update = gr.Dropdown(
77
- choices=models,
78
- value=models[0] if len(models) > 0 else ""
79
- )
80
- return state, dropdown_update
81
-
82
-
83
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
- with open(get_conv_log_filename(), "a") as fout:
85
- data = {
86
- "tstamp": round(time.time(), 4),
87
- "type": vote_type,
88
- "model": model_selector,
89
- "state": state.dict(),
90
- "ip": request.client.host,
91
- }
92
- fout.write(json.dumps(data) + "\n")
93
-
94
-
95
- def upvote_last_response(state, model_selector, request: gr.Request):
96
- vote_last_response(state, "upvote", model_selector, request)
97
- return ("",) + (disable_btn,) * 3
98
-
99
-
100
- def downvote_last_response(state, model_selector, request: gr.Request):
101
- vote_last_response(state, "downvote", model_selector, request)
102
- return ("",) + (disable_btn,) * 3
103
-
104
-
105
- def flag_last_response(state, model_selector, request: gr.Request):
106
- vote_last_response(state, "flag", model_selector, request)
107
- return ("",) + (disable_btn,) * 3
108
-
109
-
110
- def regenerate(state, image_process_mode, request: gr.Request):
111
- state.messages[-1][-1] = None
112
- prev_human_msg = state.messages[-2]
113
- if type(prev_human_msg[1]) in (tuple, list):
114
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
115
- state.skip_next = False
116
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
117
-
118
-
119
- def clear_history(request: gr.Request):
120
- state = default_conversation.copy()
121
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
122
-
123
-
124
- def add_text(state, text, image, image_process_mode, request: gr.Request):
125
- if len(text) <= 0 and image is None:
126
- state.skip_next = True
127
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
128
- if args.moderate:
129
- flagged = violates_moderation(text)
130
- if flagged:
131
- state.skip_next = True
132
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
133
- no_change_btn,) * 5
134
-
135
- text = text[:1536] # Hard cut-off
136
- if image is not None:
137
- text = text[:1200] # Hard cut-off for images
138
- if '<image>' not in text:
139
- # text = '<Image><image></Image>' + text
140
- text = text + '\n<image>'
141
- text = (text, image, image_process_mode)
142
- state = default_conversation.copy()
143
- state.append_message(state.roles[0], text)
144
- state.append_message(state.roles[1], None)
145
- state.skip_next = False
146
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
147
 
148
  @spaces.GPU()
149
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
150
- start_tstamp = time.time()
151
- model_name = model_selector
152
-
153
- if state.skip_next:
154
- # This generate call is skipped due to invalid inputs
155
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
156
- return
157
-
158
- if len(state.messages) == state.offset + 2:
159
- # First round of conversation
160
- if "llava" in model_name.lower():
161
- if 'llama-2' in model_name.lower():
162
- template_name = "llava_llama_2"
163
- elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
164
- if 'orca' in model_name.lower():
165
- template_name = "mistral_orca"
166
- elif 'hermes' in model_name.lower():
167
- template_name = "chatml_direct"
168
- else:
169
- template_name = "mistral_instruct"
170
- elif 'llava-v1.6-34b' in model_name.lower():
171
- template_name = "chatml_direct"
172
- elif "v1" in model_name.lower():
173
- if 'mmtag' in model_name.lower():
174
- template_name = "v1_mmtag"
175
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
176
- template_name = "v1_mmtag"
177
- else:
178
- template_name = "llava_v1"
179
- elif "mpt" in model_name.lower():
180
- template_name = "mpt"
181
- else:
182
- if 'mmtag' in model_name.lower():
183
- template_name = "v0_mmtag"
184
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
185
- template_name = "v0_mmtag"
186
- else:
187
- template_name = "llava_v0"
188
- elif "mpt" in model_name:
189
- template_name = "mpt_text"
190
- elif "llama-2" in model_name:
191
- template_name = "llama_2"
192
- elif "gemma" in model_name.lower():
193
- template_name = "ferret_gemma_instruct"
194
- print("conv mode to gemma")
195
- else:
196
- template_name = "vicuna_v1"
197
- new_state = conv_templates[template_name].copy()
198
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
199
- new_state.append_message(new_state.roles[1], None)
200
- state = new_state
201
-
202
- # # Query worker address
203
- # controller_url = args.controller_url
204
- # ret = requests.post(controller_url + "/get_worker_address",
205
- # json={"model": model_name})
206
- # worker_addr = ret.json()["address"]
207
- # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
208
-
209
-
210
-
211
- # # No available worker
212
- # if worker_addr == "":
213
- # state.messages[-1][-1] = server_error_msg
214
- # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
215
- # return
216
-
217
- # Construct prompt
218
- prompt = state.get_prompt()
219
  dir_path = "./"
220
- all_images = state.get_images(return_pil=True)
221
- all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
222
- for image, hash in zip(all_images, all_image_hash):
223
- t = datetime.datetime.now()
224
- # dir_path = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}")
225
- # filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
226
- # filename = os.path.join(dir_path, f"{hash}.jpg")
227
- filename = os.path.join(dir_path, f"{hash}.jpg")
228
- if not os.path.isfile(filename):
229
- os.makedirs(os.path.dirname(filename), exist_ok=True)
230
- image.save(filename)
231
-
232
- # Make requests
233
- pload = {
234
- "model": model_name,
235
- "prompt": prompt,
236
- "temperature": float(temperature),
237
- "top_p": float(top_p),
238
- "max_new_tokens": min(int(max_new_tokens), 1536),
239
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
240
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
241
- }
242
- logger.info(f"==== request ====\n{pload}")
243
-
244
- pload['images'] = state.get_images()
245
 
246
- state.messages[-1][-1] = "β–Œ"
247
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
248
-
249
- try:
250
- # Stream output
251
- # response = requests.post(worker_addr + "/worker_generate_stream",
252
- # headers=headers, json=pload, stream=True, timeout=10)
253
- stop = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
254
- #TODO: define inference and run function
255
- extracted_texts = inference_and_run(
 
 
 
 
 
 
 
256
  image_path=filename, # double check this
257
  image_dir=dir_path,
258
  prompt=prompt,
259
- model_path=model_name,
260
- conv_mode="ferret_gemma_instruct", # Default mode from the original function
261
- temperature=temperature,
262
- top_p=top_p,
263
- max_new_tokens=max_new_tokens,
264
- stop=stop # Assuming we want to process the image
265
  )
266
- response = extracted_texts
267
- logger.info(f"This is the respone {response}")
268
- delay=0.01
269
- # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
270
- # if chunk:
271
- # data = json.loads(chunk.decode())
272
- # if data["error_code"] == 0:
273
- # output = data["text"][len(prompt):].strip()
274
- # state.messages[-1][-1] = output + "β–Œ"
275
- # yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
276
- # else:
277
- # output = data["text"] + f" (error_code: {data['error_code']})"
278
- # state.messages[-1][-1] = output
279
- # yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
280
- # return
281
- # time.sleep(0.03)
282
- text = response[0]
283
- output = "" # Will hold the progressively built output
284
- for i, char in enumerate(text):
285
- output += char
286
- state.messages[-1][-1] = output + "β–Œ" # Add cursor β–Œ at the end
287
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
288
- time.sleep(delay) # Control typing speed with delay
289
- except requests.exceptions.RequestException as e:
290
- state.messages[-1][-1] = server_error_msg
291
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
292
- return
293
-
294
- state.messages[-1][-1] = state.messages[-1][-1][:-1]
295
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
296
-
297
- finish_tstamp = time.time()
298
- logger.info(f"{output}")
299
-
300
- with open(get_conv_log_filename(), "a") as fout:
301
- data = {
302
- "tstamp": round(finish_tstamp, 4),
303
- "type": "chat",
304
- "model": model_name,
305
- "start": round(start_tstamp, 4),
306
- "finish": round(finish_tstamp, 4),
307
- "state": state.dict(),
308
- "images": all_image_hash,
309
- "ip": request.client.host,
310
- }
311
- fout.write(json.dumps(data) + "\n")
312
-
313
- title_markdown = ("""
314
- # πŸŒ‹ LLaVA: Large Language and Vision Assistant
315
- [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | πŸ“š [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)]
316
- """)
317
-
318
- tos_markdown = ("""
319
- ### Terms of use
320
- By using this service, users are required to agree to the following terms:
321
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
322
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
323
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
324
- """)
325
-
326
-
327
- learn_more_markdown = ("""
328
- ### License
329
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
330
- """)
331
-
332
- block_css = """
333
-
334
- #buttons button {
335
- min-width: min(120px,100%);
336
- }
337
-
338
  """
339
 
340
- def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
341
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
342
- with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
343
- state = gr.State()
344
-
345
- if not embed_mode:
346
- gr.Markdown(title_markdown)
347
-
348
- with gr.Row():
349
- models = [
350
- "jadechoghari/Ferret-UI-Gemma2b"
351
- ]
352
- with gr.Column(scale=3):
353
- with gr.Row(elem_id="model_selector_row"):
354
- model_selector = gr.Dropdown(
355
- choices=models,
356
- value=models[0] if len(models) > 0 else "",
357
- interactive=True,
358
- show_label=False,
359
- container=False)
360
-
361
- imagebox = gr.Image(type="pil")
362
- image_process_mode = gr.Radio(
363
- ["Crop", "Resize", "Pad", "Default"],
364
- value="Default",
365
- label="Preprocess for non-square image", visible=False)
366
-
367
- if cur_dir is None:
368
- cur_dir = os.path.dirname(os.path.abspath(__file__))
369
- gr.Examples(examples=[
370
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
371
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
372
- ], inputs=[imagebox, textbox])
373
-
374
- with gr.Accordion("Parameters", open=False) as parameter_row:
375
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
376
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
377
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
378
-
379
- with gr.Column(scale=8):
380
- chatbot = gr.Chatbot(
381
- elem_id="chatbot",
382
- label="LLaVA Chatbot",
383
- height=650,
384
- layout="panel",
385
- )
386
- with gr.Row():
387
- with gr.Column(scale=8):
388
- textbox.render()
389
- with gr.Column(scale=1, min_width=50):
390
- submit_btn = gr.Button(value="Send", variant="primary")
391
- with gr.Row(elem_id="buttons") as button_row:
392
- upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
393
- downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
394
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
395
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
396
- regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False)
397
- clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
398
-
399
- if not embed_mode:
400
- gr.Markdown(tos_markdown)
401
- gr.Markdown(learn_more_markdown)
402
- url_params = gr.JSON(visible=False)
403
-
404
- # Register listeners
405
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
406
- upvote_btn.click(
407
- upvote_last_response,
408
- [state, model_selector],
409
- [textbox, upvote_btn, downvote_btn, flag_btn]
410
- )
411
- downvote_btn.click(
412
- downvote_last_response,
413
- [state, model_selector],
414
- [textbox, upvote_btn, downvote_btn, flag_btn]
415
- )
416
- flag_btn.click(
417
- flag_last_response,
418
- [state, model_selector],
419
- [textbox, upvote_btn, downvote_btn, flag_btn]
420
- )
421
-
422
- regenerate_btn.click(
423
- regenerate,
424
- [state, image_process_mode],
425
- [state, chatbot, textbox, imagebox] + btn_list
426
- ).then(
427
- http_bot,
428
- [state, model_selector, temperature, top_p, max_output_tokens],
429
- [state, chatbot] + btn_list,
430
- concurrency_limit=concurrency_count
431
- )
432
-
433
- clear_btn.click(
434
- clear_history,
435
- None,
436
- [state, chatbot, textbox, imagebox] + btn_list,
437
- queue=False
438
- )
439
-
440
- textbox.submit(
441
- add_text,
442
- [state, textbox, imagebox, image_process_mode],
443
- [state, chatbot, textbox, imagebox] + btn_list,
444
- queue=False
445
- ).then(
446
- http_bot,
447
- [state, model_selector, temperature, top_p, max_output_tokens],
448
- [state, chatbot] + btn_list,
449
- concurrency_limit=concurrency_count
450
- )
451
-
452
- submit_btn.click(
453
- add_text,
454
- [state, textbox, imagebox, image_process_mode],
455
- [state, chatbot, textbox, imagebox] + btn_list
456
- ).then(
457
- http_bot,
458
- [state, model_selector, temperature, top_p, max_output_tokens],
459
- [state, chatbot] + btn_list,
460
- concurrency_limit=concurrency_count
461
- )
462
-
463
- if args.model_list_mode == "once":
464
- demo.load(
465
- load_demo,
466
- [url_params],
467
- [state, model_selector],
468
- js=get_window_url_params
469
- )
470
- elif args.model_list_mode == "reload":
471
- demo.load(
472
- load_demo_refresh_model_list,
473
- None,
474
- [state, model_selector],
475
- queue=False
476
- )
477
- else:
478
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
479
-
480
- return demo
481
-
482
-
483
- # if __name__ == "__main__":
484
- # parser = argparse.ArgumentParser()
485
- # parser.add_argument("--port", type=int, default=7860) # You can still specify the port
486
- # parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
487
- # parser.add_argument("--concurrency-count", type=int, default=16)
488
- # parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
489
- # parser.add_argument("--share", action="store_true")
490
- # parser.add_argument("--moderate", action="store_true")
491
- # parser.add_argument("--embed", action="store_true")
492
- # args = parser.parse_args()
493
- # # models = get_model_list()
494
- # demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
495
- # demo.queue(api_open=False).launch(
496
- # server_port=args.port, # Specify the port if needed
497
- # share=True,
498
- # debug=True # All other functionalities like sharing still work
499
- # )
500
- if __name__ == "__main__":
501
- parser = argparse.ArgumentParser()
502
- parser.add_argument("--host", type=str, default="0.0.0.0")
503
- parser.add_argument("--port", type=int)
504
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
505
- parser.add_argument("--concurrency-count", type=int, default=16)
506
- parser.add_argument("--model-list-mode", type=str, default="once",
507
- choices=["once", "reload"])
508
- parser.add_argument("--share", action="store_true")
509
- parser.add_argument("--moderate", action="store_true")
510
- parser.add_argument("--embed", action="store_true")
511
- args = parser.parse_args()
512
- logger.info(f"args: {args}")
513
-
514
- models = get_model_list()
515
-
516
- logger.info(args)
517
- demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
518
- demo.queue(
519
- api_open=False
520
- ).launch(
521
- server_name=args.host,
522
- server_port=args.port,
523
- share=True,
524
- debug=True
525
- )
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from inference import inference_and_run
 
 
 
 
 
 
 
3
  import spaces
4
+ import os
5
+ import re
6
+ import shutil
7
 
8
+ model_name = 'Ferret-UI'
9
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  @spaces.GPU()
12
+ def inference_with_gradio(chatbot, image, prompt, model_path, box=None):
13
+ dir_path = os.path.dirname(image)
14
+ # image_path = image
15
+ # Define the directory where you want to save the image (current directory)
16
+ filename = os.path.basename(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  dir_path = "./"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Create the new path for the file (in the current directory)
20
+ image_path = os.path.join(dir_path, filename)
21
+ shutil.copy(image, image_path)
22
+ print("filename path: ", filename)
23
+ if "gemma" in model_path.lower():
24
+ conv_mode = "ferret_gemma_instruct"
25
+ else:
26
+ conv_mode = "ferret_llama_3"
27
+
28
+ # inference_text = inference_and_run(
29
+ # image_path=image_path,
30
+ # prompt=prompt,
31
+ # conv_mode=conv_mode,
32
+ # model_path=model_path,
33
+ # box=box
34
+ # )
35
+ inference_text = inference_and_run(
36
  image_path=filename, # double check this
37
  image_dir=dir_path,
38
  prompt=prompt,
39
+ model_path="jadechoghari/Ferret-UI-Gemma2b",
40
+ conv_mode=conv_mode, # Default mode from the original function
41
+ # temperature=temperature,
42
+ # top_p=top_p,
43
+ # max_new_tokens=max_new_tokens,
44
+ # stop=stop # Assuming we want to process the image
45
  )
46
+
47
+ # print("done, now appending", inference_text)
48
+ # chatbot.append((prompt, inference_text))
49
+ # return chatbot
50
+ # Convert inference_text to string if it's not already
51
+ if isinstance(inference_text, (list, tuple)):
52
+ inference_text = str(inference_text[0])
53
+
54
+ # Update chatbot history with new message pair
55
+ new_history = chatbot.copy() if chatbot else []
56
+ new_history.append((prompt, inference_text))
57
+ return new_history
58
+
59
+ def submit_chat(chatbot, text_input):
60
+ response = ''
61
+ chatbot.append((text_input, response))
62
+ return chatbot, ''
63
+
64
+ def clear_chat():
65
+ return [], None, ""
66
+
67
+ with open(f"{cur_dir}/logo.svg", "r", encoding="utf-8") as svg_file:
68
+ svg_content = svg_file.read()
69
+ font_size = "2.5em"
70
+ svg_content = re.sub(r'(<svg[^>]*)(>)', rf'\1 height="{font_size}" style="vertical-align: middle; display: inline-block;"\2', svg_content)
71
+ html = f"""
72
+ <p align="center" style="font-size: {font_size}; line-height: 1;">
73
+ <span style="display: inline-block; vertical-align: middle;">{svg_content}</span>
74
+ <span style="display: inline-block; vertical-align: middle;">{model_name}</span>
75
+ </p>
76
+ <center><font size=3><b>{model_name}</b> Demo: Upload an image, provide a prompt, and get insights using advanced AI models. <a href='https://huggingface.co/jadechoghari/Ferret-UI-Gemma2b'>😊 Huggingface</a></font></center>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
 
79
+ latex_delimiters_set = [{
80
+ "left": "\\(",
81
+ "right": "\\)",
82
+ "display": False
83
+ }, {
84
+ "left": "\\begin{equation}",
85
+ "right": "\\end{equation}",
86
+ "display": True
87
+ }, {
88
+ "left": "\\begin{align}",
89
+ "right": "\\end{align}",
90
+ "display": True
91
+ }]
92
+
93
+ # Set up UI components
94
+ image_input = gr.Image(label="Upload Image", type="filepath", height=350)
95
+ text_input = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
96
+ model_dropdown = gr.Dropdown(choices=[
97
+ "jadechoghari/Ferret-UI-Gemma2b",
98
+ "jadechoghari/Ferret-UI-Llama8b",
99
+ ], label="Model Path", value="jadechoghari/Ferret-UI-Gemma2b")
100
+
101
+ bounding_box_input = gr.Textbox(placeholder="Optional bounding box (x1, y1, x2, y2)", label="Bounding Box (optional)")
102
+ chatbot = gr.Chatbot(label="Chat with Ferret-UI", height=400, show_copy_button=True, latex_delimiters=latex_delimiters_set)
103
+
104
+ with gr.Blocks(title=model_name, theme=gr.themes.Ocean()) as demo:
105
+ gr.HTML(html)
106
+ with gr.Row():
107
+ with gr.Column(scale=3):
108
+ # gr.Examples(
109
+ # examples=[
110
+ # ["appstore_reminders.png", "Describe the image in details", "jadechoghari/Ferret-UI-Gemma2b", None],
111
+ # ["appstore_reminders.png", "What's inside the selected region?", "jadechoghari/Ferret-UI-Gemma2b", "189, 906, 404, 970"],
112
+ # ["appstore_reminders.png", "Where is the Game Tab?", "jadechoghari/Ferret-UI-Gemma2b", None],
113
+ # ],
114
+ # inputs=[image_input, text_input, model_dropdown, bounding_box_input]
115
+ # )
116
+ image_input.render()
117
+ text_input.render()
118
+ model_dropdown.render()
119
+ bounding_box_input.render()
120
+ with gr.Column(scale=7):
121
+ chatbot.render()
122
+ with gr.Row():
123
+ send_btn = gr.Button("Send", variant="primary")
124
+ clear_btn = gr.Button("Clear", variant="secondary")
125
+
126
+ send_click_event = send_btn.click(
127
+ inference_with_gradio, [chatbot, image_input, text_input, model_dropdown, bounding_box_input], chatbot
128
+ ).then(submit_chat, [chatbot, text_input], [chatbot, text_input])
129
+ submit_event = text_input.submit(
130
+ inference_with_gradio, [chatbot, image_input, text_input, model_dropdown, bounding_box_input], chatbot
131
+ ).then(submit_chat, [chatbot, text_input], [chatbot, text_input])
132
+
133
+ clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input, bounding_box_input])
134
+
135
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.pyi DELETED
@@ -1,797 +0,0 @@
1
- '''
2
- Usage:
3
-
4
- python -m ferret.serve.gradio_web_server --controller http://localhost:10000 --add_region_feature
5
- '''
6
- import argparse
7
- import datetime
8
- import json
9
- import os
10
- import time
11
-
12
- import gradio as gr
13
- import requests
14
-
15
- from conversation import (default_conversation, conv_templates,
16
- SeparatorStyle)
17
- from constants import LOGDIR
18
- from utils import (build_logger, server_error_msg,
19
- violates_moderation, moderation_msg)
20
- import hashlib
21
- # Added
22
- import re
23
- from copy import deepcopy
24
- from PIL import ImageDraw, ImageFont
25
- from gradio import processing_utils
26
- import numpy as np
27
- import torch
28
- import torch.nn.functional as F
29
- from scipy.ndimage import binary_dilation, binary_erosion
30
- import pdb
31
- from gradio_css import code_highlight_css
32
- import spaces
33
-
34
- from inference import inference_and_run
35
-
36
- DEFAULT_REGION_REFER_TOKEN = "[region]"
37
- DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
38
-
39
-
40
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
41
-
42
- headers = {"User-Agent": "FERRET Client"}
43
-
44
- no_change_btn = gr.Button
45
- enable_btn = gr.Button(interactive=True)
46
- disable_btn = gr.Button(interactive=False)
47
-
48
- priority = {
49
- "vicuna-13b": "aaaaaaa",
50
- "koala-13b": "aaaaaab",
51
- }
52
-
53
- VOCAB_IMAGE_W = 1000 # 224
54
- VOCAB_IMAGE_H = 1000 # 224
55
-
56
- def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
57
- if mask is not None:
58
- assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
59
- coor_mask = torch.zeros((raw_w, raw_h))
60
- # Assume it samples a point.
61
- if len(coor) == 2:
62
- # Define window size
63
- span = 5
64
- # Make sure the window does not exceed array bounds
65
- x_min = max(0, coor[0] - span)
66
- x_max = min(raw_w, coor[0] + span + 1)
67
- y_min = max(0, coor[1] - span)
68
- y_max = min(raw_h, coor[1] + span + 1)
69
- coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
70
- assert (coor_mask==1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
71
- elif len(coor) == 4:
72
- # Box input or Sketch input.
73
- coor_mask = torch.zeros((raw_w, raw_h))
74
- coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
75
- if mask is not None:
76
- coor_mask = coor_mask * mask
77
- # coor_mask = torch.from_numpy(coor_mask)
78
- # pdb.set_trace()
79
- assert len(coor_mask.nonzero()) != 0
80
- return coor_mask.tolist()
81
-
82
-
83
- def draw_box(coor, region_mask, region_ph, img, input_mode):
84
- colors = ["red"]
85
- draw = ImageDraw.Draw(img)
86
- font = ImageFont.truetype("./DejaVuSans.ttf", size=18)
87
- if input_mode == 'Box':
88
- draw.rectangle([coor[0], coor[1], coor[2], coor[3]], outline=colors[0], width=4)
89
- draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
90
- draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
91
- elif input_mode == 'Point':
92
- r = 8
93
- leftUpPoint = (coor[0]-r, coor[1]-r)
94
- rightDownPoint = (coor[0]+r, coor[1]+r)
95
- twoPointList = [leftUpPoint, rightDownPoint]
96
- draw.ellipse(twoPointList, outline=colors[0], width=4)
97
- draw.rectangle([coor[0], coor[1], coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[1] + int(font.size * 1.2)], outline=colors[0], fill=colors[0], width=4)
98
- draw.text([coor[0] + int(font.size * 0.2), coor[1]], region_ph, font=font, fill=(255,255,255))
99
- elif input_mode == 'Sketch':
100
- draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
101
- draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
102
- # Use morphological operations to find the boundary
103
- mask = np.array(region_mask)
104
- dilated = binary_dilation(mask, structure=np.ones((3,3)))
105
- eroded = binary_erosion(mask, structure=np.ones((3,3)))
106
- boundary = dilated ^ eroded # XOR operation to find the difference between dilated and eroded mask
107
- # Loop over the boundary and paint the corresponding pixels
108
- for i in range(boundary.shape[0]):
109
- for j in range(boundary.shape[1]):
110
- if boundary[i, j]:
111
- # This is a pixel on the boundary, paint it red
112
- draw.point((i, j), fill=colors[0])
113
- else:
114
- NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
115
- return img
116
-
117
-
118
- def get_conv_log_filename():
119
- t = datetime.datetime.now()
120
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
121
- return name
122
-
123
-
124
- # TODO: return model manually just one for now called "jadechoghari/Ferret-UI-Gemma2b"
125
- def get_model_list():
126
- # ret = requests.post(args.controller_url + "/refresh_all_workers")
127
- # assert ret.status_code == 200
128
- # ret = requests.post(args.controller_url + "/list_models")
129
- # models = ret.json()["models"]
130
- # models.sort(key=lambda x: priority.get(x, x))
131
- # logger.info(f"Models: {models}")
132
- # return models
133
- models = ["jadechoghari/Ferret-UI-Gemma2b"]
134
- logger.info(f"Models: {models}")
135
- return models
136
-
137
-
138
- get_window_url_params = """
139
- function() {
140
- const params = new URLSearchParams(window.location.search);
141
- url_params = Object.fromEntries(params);
142
- console.log(url_params);
143
- return url_params;
144
- }
145
- """
146
-
147
-
148
- def load_demo(url_params, request: gr.Request):
149
- # logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
150
-
151
- dropdown_update = gr.Dropdown(visible=True)
152
- if "model" in url_params:
153
- model = url_params["model"]
154
- if model in models:
155
- dropdown_update = gr.Dropdown(
156
- value=model, visible=True)
157
-
158
- state = default_conversation.copy()
159
- print("state", state)
160
- return (state,
161
- dropdown_update,
162
- gr.Chatbot(visible=True),
163
- gr.Textbox(visible=True),
164
- gr.Button(visible=True),
165
- gr.Row(visible=True),
166
- gr.Accordion(visible=True))
167
-
168
-
169
- def load_demo_refresh_model_list(request: gr.Request):
170
- # logger.info(f"load_demo. ip: {request.client.host}")
171
- models = get_model_list()
172
- state = default_conversation.copy()
173
- return (state, gr.Dropdown(
174
- choices=models,
175
- value=models[0] if len(models) > 0 else ""),
176
- gr.Chatbot(visible=True),
177
- gr.Textbox(visible=True),
178
- gr.Button(visible=True),
179
- gr.Row(visible=True),
180
- gr.Accordion(visible=True))
181
-
182
-
183
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
184
- with open(get_conv_log_filename(), "a") as fout:
185
- data = {
186
- "tstamp": round(time.time(), 4),
187
- "type": vote_type,
188
- "model": model_selector,
189
- "state": state.dict(),
190
- "ip": request.client.host,
191
- }
192
- fout.write(json.dumps(data) + "\n")
193
-
194
-
195
- def upvote_last_response(state, model_selector, request: gr.Request):
196
- vote_last_response(state, "upvote", model_selector, request)
197
- return ("",) + (disable_btn,) * 3
198
-
199
-
200
- def downvote_last_response(state, model_selector, request: gr.Request):
201
- vote_last_response(state, "downvote", model_selector, request)
202
- return ("",) + (disable_btn,) * 3
203
-
204
-
205
- def flag_last_response(state, model_selector, request: gr.Request):
206
- vote_last_response(state, "flag", model_selector, request)
207
- return ("",) + (disable_btn,) * 3
208
-
209
-
210
- def regenerate(state, image_process_mode, request: gr.Request):
211
- state.messages[-1][-1] = None
212
- prev_human_msg = state.messages[-2]
213
- if type(prev_human_msg[1]) in (tuple, list):
214
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
215
- state.skip_next = False
216
- return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
217
-
218
-
219
- def clear_history(request: gr.Request):
220
- state = default_conversation.copy()
221
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
222
- (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
223
-
224
-
225
- def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
226
- ratio_w = image_w * 1.0 / default_wh
227
- ratio_h = image_h * 1.0 / default_wh
228
-
229
- new_box = [int(box[0] * ratio_w), int(box[1] * ratio_h), \
230
- int(box[2] * ratio_w), int(box[3] * ratio_h)]
231
- return new_box
232
-
233
-
234
- def show_location(sketch_pad, chatbot):
235
- image = sketch_pad['image']
236
- img_w, img_h = image.size
237
- new_bboxes = []
238
- old_bboxes = []
239
- # chatbot[0] is image.
240
- text = chatbot[1:]
241
- for round_i in text:
242
- human_input = round_i[0]
243
- model_output = round_i[1]
244
- # TODO: Difference: vocab representation.
245
- # pattern = r'\[x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?), x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?)\]'
246
- pattern = r'\[(\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?)\]'
247
- matches = re.findall(pattern, model_output)
248
- for match in matches:
249
- x1, y1, x2, y2 = map(int, match)
250
- new_box = resize_bbox([x1, y1, x2, y2], img_w, img_h)
251
- new_bboxes.append(new_box)
252
- old_bboxes.append([x1, y1, x2, y2])
253
-
254
- set_old_bboxes = sorted(set(map(tuple, old_bboxes)), key=list(map(tuple, old_bboxes)).index)
255
- list_old_bboxes = list(map(list, set_old_bboxes))
256
-
257
- set_bboxes = sorted(set(map(tuple, new_bboxes)), key=list(map(tuple, new_bboxes)).index)
258
- list_bboxes = list(map(list, set_bboxes))
259
-
260
- output_image = deepcopy(image)
261
- draw = ImageDraw.Draw(output_image)
262
- #TODO: change from local to online path
263
- font = ImageFont.truetype("./DejaVuSans.ttf", 28)
264
- for i in range(len(list_bboxes)):
265
- x1, y1, x2, y2 = list_old_bboxes[i]
266
- x1_new, y1_new, x2_new, y2_new = list_bboxes[i]
267
- obj_string = '[obj{}]'.format(i)
268
- for round_i in text:
269
- model_output = round_i[1]
270
- model_output = model_output.replace('[{}, {}, {}, {}]'.format(x1, y1, x2, y2), obj_string)
271
- round_i[1] = model_output
272
- draw.rectangle([(x1_new, y1_new), (x2_new, y2_new)], outline="red", width=3)
273
- draw.text((x1_new+2, y1_new+5), obj_string[1:-1], fill="red", font=font)
274
-
275
- return (output_image, [chatbot[0]] + text, disable_btn)
276
-
277
-
278
- def add_text(state, text, image_process_mode, original_image, sketch_pad, request: gr.Request):
279
- print("add text called!")
280
-
281
-
282
- image = sketch_pad['image']
283
- print("text", text, "and : ", len(text))
284
- print("Image path", original_image)
285
-
286
- if len(text) <= 0 and image is None:
287
- state.skip_next = True
288
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
289
- if args.moderate:
290
- flagged = violates_moderation(text)
291
- if flagged:
292
- state.skip_next = True
293
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
294
- no_change_btn,) * 5
295
-
296
- text = text[:1536] # Hard cut-off
297
- if original_image is None:
298
- assert image is not None
299
- original_image = image.copy()
300
- print('No location, copy original image in add_text')
301
-
302
- if image is not None:
303
- if state.first_round:
304
- text = text[:1200] # Hard cut-off for images
305
- if '<image>' not in text:
306
- # text = '<Image><image></Image>' + text
307
- text = text + '\n<image>'
308
- text = (text, original_image, image_process_mode)
309
- if len(state.get_images(return_pil=True)) > 0:
310
- new_state = default_conversation.copy()
311
- new_state.first_round = False
312
- state=new_state
313
- print('First round add image finsihed.')
314
-
315
- state.append_message(state.roles[0], text)
316
- state.append_message(state.roles[1], None)
317
- state.skip_next = False
318
- return (state, state.to_gradio_chatbot(), "", original_image) + (disable_btn,) * 5
319
-
320
-
321
- def post_process_code(code):
322
- sep = "\n```"
323
- if sep in code:
324
- blocks = code.split(sep)
325
- if len(blocks) % 2 == 1:
326
- for i in range(1, len(blocks), 2):
327
- blocks[i] = blocks[i].replace("\\_", "_")
328
- code = sep.join(blocks)
329
- return code
330
-
331
-
332
- def find_indices_in_order(str_list, STR):
333
- indices = []
334
- i = 0
335
- while i < len(STR):
336
- for element in str_list:
337
- if STR[i:i+len(element)] == element:
338
- indices.append(str_list.index(element))
339
- i += len(element) - 1
340
- break
341
- i += 1
342
- return indices
343
-
344
-
345
- def format_region_prompt(prompt, refer_input_state):
346
- # Find regions in prompts and assign corresponding region masks
347
- refer_input_state['region_masks_in_prompts'] = []
348
- indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
349
- refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
350
-
351
- # Find regions in prompts and replace with real coordinates and region feature token.
352
- for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
353
- prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
354
- return prompt
355
-
356
- @spaces.GPU()
357
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_input_state, request: gr.Request):
358
- # def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
359
- start_tstamp = time.time()
360
- model_name = model_selector
361
-
362
- if state.skip_next:
363
- # This generate call is skipped due to invalid inputs
364
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
365
- return
366
-
367
- print("state messages: ", state.messages)
368
- if len(state.messages) == state.offset + 2:
369
- # First round of conversation
370
- # template_name = 'ferret_v1'
371
- template_name = 'ferret_gemma_instruct'
372
- # Below is LLaVA's original templates.
373
- # if "llava" in model_name.lower():
374
- # if 'llama-2' in model_name.lower():
375
- # template_name = "llava_llama_2"
376
- # elif "v1" in model_name.lower():
377
- # if 'mmtag' in model_name.lower():
378
- # template_name = "v1_mmtag"
379
- # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
380
- # template_name = "v1_mmtag"
381
- # else:
382
- # template_name = "llava_v1"
383
- # elif "mpt" in model_name.lower():
384
- # template_name = "mpt"
385
- # else:
386
- # if 'mmtag' in model_name.lower():
387
- # template_name = "v0_mmtag"
388
- # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
389
- # template_name = "v0_mmtag"
390
- # else:
391
- # template_name = "llava_v0"
392
- # elif "mpt" in model_name:
393
- # template_name = "mpt_text"
394
- # elif "llama-2" in model_name:
395
- # template_name = "llama_2"
396
- # else:
397
- # template_name = "vicuna_v1"
398
- new_state = conv_templates[template_name].copy()
399
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
400
- new_state.append_message(new_state.roles[1], None)
401
- state = new_state
402
- state.first_round = False
403
-
404
- # # Query worker address
405
- # controller_url = args.controller_url
406
- # ret = requests.post(controller_url + "/get_worker_address",
407
- # json={"model": model_name})
408
- # worker_addr = ret.json()["address"]
409
- # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
410
-
411
- # No available worker
412
- # if worker_addr == "":
413
- # state.messages[-1][-1] = server_error_msg
414
- # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
415
- # return
416
-
417
- # Construct prompt
418
- prompt = state.get_prompt()
419
- if args.add_region_feature:
420
- prompt = format_region_prompt(prompt, refer_input_state)
421
-
422
- all_images = state.get_images(return_pil=True)
423
- all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
424
- for image, hash in zip(all_images, all_image_hash):
425
- t = datetime.datetime.now()
426
- # fishy can remove it
427
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
428
- if not os.path.isfile(filename):
429
- os.makedirs(os.path.dirname(filename), exist_ok=True)
430
- image.save(filename)
431
-
432
- # Make requests
433
- pload = {
434
- "model": model_name,
435
- "prompt": prompt,
436
- "temperature": float(temperature),
437
- "top_p": float(top_p),
438
- "max_new_tokens": min(int(max_new_tokens), 1536),
439
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
440
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
441
- }
442
- logger.info(f"==== request ====\n{pload}")
443
- if args.add_region_feature:
444
- pload['region_masks'] = refer_input_state['region_masks_in_prompts']
445
- logger.info(f"==== add region_masks_in_prompts to request ====\n")
446
-
447
- pload['images'] = state.get_images()
448
- print(f'Input Prompt: {prompt}')
449
- print("all_image_hash", all_image_hash)
450
-
451
- state.messages[-1][-1] = "β–Œ"
452
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
453
-
454
- try:
455
- # Stream output
456
- stop = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
457
- #TODO: define inference and run function
458
- results, extracted_texts = inference_and_run(
459
- image_path=all_image_hash[0], # double check this
460
- prompt=prompt,
461
- model_path=model_name,
462
- conv_mode="ferret_gemma_instruct", # Default mode from the original function
463
- temperature=temperature,
464
- top_p=top_p,
465
- max_new_tokens=max_new_tokens,
466
- stop=stop # Assuming we want to process the image
467
- )
468
-
469
- # response = requests.post(worker_addr + "/worker_generate_stream",
470
- # headers=headers, json=pload, stream=True, timeout=10)
471
- response = extracted_texts
472
- logger.info(f"This is the respone {response}")
473
-
474
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
475
- if chunk:
476
- data = json.loads(chunk.decode())
477
- if data["error_code"] == 0:
478
- output = data["text"][len(prompt):].strip()
479
- output = post_process_code(output)
480
- state.messages[-1][-1] = output + "β–Œ"
481
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
482
- else:
483
- output = data["text"] + f" (error_code: {data['error_code']})"
484
- state.messages[-1][-1] = output
485
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
486
- return
487
- time.sleep(0.03)
488
- except requests.exceptions.RequestException as e:
489
- state.messages[-1][-1] = server_error_msg
490
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
491
- return
492
-
493
- state.messages[-1][-1] = state.messages[-1][-1][:-1]
494
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
495
-
496
- finish_tstamp = time.time()
497
- logger.info(f"{output}")
498
-
499
- with open(get_conv_log_filename(), "a") as fout:
500
- data = {
501
- "tstamp": round(finish_tstamp, 4),
502
- "type": "chat",
503
- "model": model_name,
504
- "start": round(start_tstamp, 4),
505
- "finish": round(start_tstamp, 4),
506
- "state": state.dict(),
507
- "images": all_image_hash,
508
- "ip": request.client.host,
509
- }
510
- fout.write(json.dumps(data) + "\n")
511
-
512
- title_markdown = ("""
513
- # 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
514
- """)
515
- # [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
516
-
517
- tos_markdown = ("""
518
- ### Terms of use
519
- By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
520
- """)
521
-
522
-
523
- learn_more_markdown = ("""
524
- ### License
525
- The service is a research preview intended for non-commercial use only
526
- """)
527
-
528
-
529
- css = code_highlight_css + """
530
- pre {
531
- white-space: pre-wrap; /* Since CSS 2.1 */
532
- white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
533
- white-space: -pre-wrap; /* Opera 4-6 */
534
- white-space: -o-pre-wrap; /* Opera 7 */
535
- word-wrap: break-word; /* Internet Explorer 5.5+ */
536
- }
537
- """
538
-
539
- Instructions = '''
540
- Instructions:
541
- 1. Select a 'Referring Input Type'
542
- 2. Draw on the image to refer to a region/point.
543
- 3. Copy the region id from 'Referring Input Type' to refer to a region in your chat.
544
- '''
545
- from gradio.events import Dependency
546
-
547
- class ImageMask(gr.components.Image):
548
- """
549
- Sets: source="canvas", tool="sketch"
550
- """
551
-
552
- is_template = True
553
-
554
- def __init__(self, **kwargs):
555
- super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
556
-
557
- def preprocess(self, x):
558
- return super().preprocess(x)
559
- from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
560
- from gradio.blocks import Block
561
- if TYPE_CHECKING:
562
- from gradio.components import Timer
563
-
564
-
565
- def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
566
- if type(input) == dict:
567
- image = deepcopy(input['image'])
568
- mask = deepcopy(input['mask'])
569
- else:
570
- mask = deepcopy(input)
571
-
572
- # W, H -> H, W, 3
573
- image_new = np.asarray(image)
574
- img_height = image_new.shape[0]
575
- img_width = image_new.shape[1]
576
-
577
- # W, H, 4 -> H, W
578
- mask_new = np.asarray(mask)[:,:,0].copy()
579
- mask_new = torch.from_numpy(mask_new)
580
- mask_new = (F.interpolate(mask_new.unsqueeze(0).unsqueeze(0), (img_height, img_width), mode='bilinear') > 0)
581
- mask_new = mask_new[0, 0].transpose(1, 0).long()
582
-
583
- if len(refer_input_state['masks']) == 0:
584
- last_mask = torch.zeros_like(mask_new)
585
- else:
586
- last_mask = refer_input_state['masks'][-1]
587
-
588
- diff_mask = mask_new - last_mask
589
- if torch.all(diff_mask == 0):
590
- print('Init Uploading Images.')
591
- return (refer_input_state, refer_text_show, image)
592
- else:
593
- refer_input_state['masks'].append(mask_new)
594
-
595
- if input_mode == 'Point':
596
- nonzero_points = diff_mask.nonzero()
597
- nonzero_points_avg_x = torch.median(nonzero_points[:, 0])
598
- nonzero_points_avg_y = torch.median(nonzero_points[:, 1])
599
- sampled_coor = [nonzero_points_avg_x, nonzero_points_avg_y]
600
- # pdb.set_trace()
601
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
602
- elif input_mode == 'Box' or input_mode == 'Sketch':
603
- # pdb.set_trace()
604
- x1x2 = diff_mask.max(1)[0].nonzero()[:, 0]
605
- y1y2 = diff_mask.max(0)[0].nonzero()[:, 0]
606
- y1, y2 = y1y2.min(), y1y2.max()
607
- x1, x2 = x1x2.min(), x1x2.max()
608
- # pdb.set_trace()
609
- sampled_coor = [x1, y1, x2, y2]
610
- if input_mode == 'Box':
611
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
612
- else:
613
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height, mask=diff_mask)
614
- else:
615
- raise NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
616
-
617
- # TODO(haoxuan): Hack img_size to be 224 here, need to make it a argument.
618
- if len(sampled_coor) == 2:
619
- point_x = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
620
- point_y = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
621
- cur_region_coordinates = f'[{int(point_x)}, {int(point_y)}]'
622
- elif len(sampled_coor) == 4:
623
- point_x1 = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
624
- point_y1 = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
625
- point_x2 = int(VOCAB_IMAGE_W * sampled_coor[2] / img_width)
626
- point_y2 = int(VOCAB_IMAGE_H * sampled_coor[3] / img_height)
627
- cur_region_coordinates = f'[{int(point_x1)}, {int(point_y1)}, {int(point_x2)}, {int(point_y2)}]'
628
-
629
- cur_region_id = len(refer_input_state['region_placeholder_tokens'])
630
- cur_region_token = DEFAULT_REGION_REFER_TOKEN.split(']')[0] + str(cur_region_id) + ']'
631
- refer_input_state['region_placeholder_tokens'].append(cur_region_token)
632
- refer_input_state['region_coordinates'].append(cur_region_coordinates)
633
- refer_input_state['region_masks'].append(cur_region_masks)
634
- assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
635
- refer_text_show.append((cur_region_token, ''))
636
-
637
- # Show Parsed Referring.
638
- imagebox_refer = draw_box(sampled_coor, cur_region_masks, \
639
- cur_region_token, imagebox_refer, input_mode)
640
-
641
- return (refer_input_state, refer_text_show, imagebox_refer)
642
-
643
- def build_demo(embed_mode):
644
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False)
645
- with gr.Blocks(title="FERRET", theme=gr.themes.Base(), css=css) as demo:
646
- state = gr.State()
647
-
648
- if not embed_mode:
649
- gr.Markdown(title_markdown)
650
- gr.Markdown(Instructions)
651
-
652
- with gr.Row():
653
- with gr.Column(scale=4):
654
- with gr.Row(elem_id="model_selector_row"):
655
- model_selector = gr.Dropdown(
656
- choices=models,
657
- value=models[0] if len(models) > 0 else "",
658
- interactive=True,
659
- show_label=False,
660
- container=False)
661
-
662
- original_image = gr.Image(type="pil", visible=False)
663
- image_process_mode = gr.Radio(
664
- ["Raw+Processor", "Crop", "Resize", "Pad"],
665
- value="Raw+Processor",
666
- label="Preprocess for non-square image",
667
- visible=False)
668
-
669
- # Added for any-format input.
670
- sketch_pad = ImageMask(label="Image & Sketch", type="pil", elem_id="img2text")
671
- refer_input_mode = gr.Radio(
672
- ["Point", "Box", "Sketch"],
673
- value="Point",
674
- label="Referring Input Type")
675
- refer_input_state = gr.State({'region_placeholder_tokens':[],
676
- 'region_coordinates':[],
677
- 'region_masks':[],
678
- 'region_masks_in_prompts':[],
679
- 'masks':[],
680
- })
681
- refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
682
-
683
- imagebox_refer = gr.Image(type="pil", label="Parsed Referring Input")
684
- imagebox_output = gr.Image(type="pil", label='Output Vis')
685
-
686
- cur_dir = os.path.dirname(os.path.abspath(__file__))
687
- # gr.Examples(examples=[
688
- # # [f"{cur_dir}/examples/harry-potter-hogwarts.jpg", "What is in [region0]? And what do people use it for?"],
689
- # # [f"{cur_dir}/examples/ingredients.jpg", "What objects are in [region0] and [region1]?"],
690
- # # [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image? And tell me the coordinates of mentioned objects."],
691
- # [f"{cur_dir}/examples/ferret.jpg", "What's the relationship between object [region0] and object [region1]?"],
692
- # [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here? Tell me the coordinates in response."],
693
- # [f"{cur_dir}/examples/flickr_9472793441.jpg", "Describe the image in details."],
694
- # # [f"{cur_dir}/examples/coco_000000281759.jpg", "What are the locations of the woman wearing a blue dress, the woman in flowery top, the girl in purple dress, the girl wearing green shirt?"],
695
- # [f"{cur_dir}/examples/room_planning.jpg", "How to improve the design of the given room?"],
696
- # [f"{cur_dir}/examples/make_sandwitch.jpg", "How can I make a sandwich with available ingredients?"],
697
- # [f"{cur_dir}/examples/bathroom.jpg", "What is unusual about this image?"],
698
- # [f"{cur_dir}/examples/kitchen.png", "Is the object a man or a chicken? Explain the reason."],
699
- # ], inputs=[sketch_pad, textbox])
700
-
701
- with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
702
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
703
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
704
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
705
-
706
- with gr.Column(scale=5):
707
- chatbot = gr.Chatbot(elem_id="chatbot", label="FERRET", visible=False).style(height=750)
708
- with gr.Row():
709
- with gr.Column(scale=8):
710
- textbox.render()
711
- with gr.Column(scale=1, min_width=60):
712
- submit_btn = gr.Button(value="Submit", visible=False)
713
- with gr.Row(visible=False) as button_row:
714
- upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
715
- downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
716
- # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
717
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
718
- regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False)
719
- clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=False)
720
- location_btn = gr.Button(value="πŸͺ„ Show location", interactive=False)
721
-
722
- if not embed_mode:
723
- gr.Markdown(tos_markdown)
724
- gr.Markdown(learn_more_markdown)
725
- url_params = gr.JSON(visible=False)
726
-
727
- # Register listeners
728
- btn_list = [upvote_btn, downvote_btn, location_btn, regenerate_btn, clear_btn]
729
- upvote_btn.click(upvote_last_response,
730
- [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
731
- downvote_btn.click(downvote_last_response,
732
- [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
733
- # flag_btn.click(flag_last_response,
734
- # [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
735
- regenerate_btn.click(regenerate, [state, image_process_mode],
736
- [state, chatbot, textbox] + btn_list).then(
737
- http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
738
- [state, chatbot] + btn_list)
739
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox_output, original_image] + btn_list + \
740
- [sketch_pad, refer_input_state, refer_text_show, imagebox_refer])
741
- location_btn.click(show_location,
742
- [sketch_pad, chatbot], [imagebox_output, chatbot, location_btn])
743
-
744
-
745
- #TODO: fix bug text and image not adding when clicking submit
746
- textbox.submit(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
747
- ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
748
- [state, chatbot] + btn_list)
749
-
750
- submit_btn.click(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
751
- ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
752
- [state, chatbot] + btn_list)
753
-
754
-
755
-
756
- sketch_pad.edit(
757
- draw,
758
- inputs=[refer_input_mode, sketch_pad, refer_input_state, refer_text_show, imagebox_refer],
759
- outputs=[refer_input_state, refer_text_show, imagebox_refer],
760
- queue=True,
761
- )
762
-
763
- if args.model_list_mode == "once":
764
- demo.load(load_demo, [url_params], [state, model_selector,
765
- chatbot, textbox, submit_btn, button_row, parameter_row],
766
- _js=get_window_url_params)
767
- elif args.model_list_mode == "reload":
768
- demo.load(load_demo_refresh_model_list, None, [state, model_selector,
769
- chatbot, textbox, submit_btn, button_row, parameter_row])
770
- else:
771
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
772
-
773
- return demo
774
-
775
-
776
- if __name__ == "__main__":
777
- parser = argparse.ArgumentParser()
778
- parser.add_argument("--host", type=str, default="0.0.0.0")
779
- parser.add_argument("--port", type=int)
780
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
781
- parser.add_argument("--concurrency-count", type=int, default=8)
782
- parser.add_argument("--model-list-mode", type=str, default="once",
783
- choices=["once", "reload"])
784
- parser.add_argument("--share", action="store_true")
785
- parser.add_argument("--moderate", action="store_true")
786
- parser.add_argument("--embed", action="store_true")
787
- parser.add_argument("--add_region_feature", action="store_true")
788
- args = parser.parse_args()
789
- logger.info(f"args: {args}")
790
-
791
- models = get_model_list()
792
-
793
- logger.info(args)
794
- demo = build_demo(args.embed)
795
- demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
796
- api_open=False).launch(
797
- server_name=args.host, server_port=args.port, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cli.py CHANGED
@@ -1,130 +1,164 @@
1
- import argparse
2
  import torch
3
-
4
  from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
  from conversation import conv_templates, SeparatorStyle
6
  from builder import load_pretrained_model
7
  from utils import disable_torch_init
8
- from ferretui.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9
-
10
  from PIL import Image
11
-
12
  import requests
13
- from PIL import Image
14
  from io import BytesIO
15
  from transformers import TextStreamer
16
-
 
 
 
 
 
 
 
 
 
 
17
 
18
  def load_image(image_file):
19
- if image_file.startswith('http://') or image_file.startswith('https://'):
20
- response = requests.get(image_file)
21
- image = Image.open(BytesIO(response.content)).convert('RGB')
22
- else:
23
- image = Image.open(image_file).convert('RGB')
 
 
 
24
  return image
25
 
26
-
27
- def main(args):
28
- # Model
 
 
 
 
 
 
 
 
 
 
 
 
29
  disable_torch_init()
30
 
31
- model_name = get_model_name_from_path(args.model_path)
32
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
 
 
33
 
 
34
  if "llama-2" in model_name.lower():
35
- conv_mode = "llava_llama_2"
36
  elif "mistral" in model_name.lower():
37
- conv_mode = "mistral_instruct"
38
  elif "v1.6-34b" in model_name.lower():
39
- conv_mode = "chatml_direct"
40
  elif "v1" in model_name.lower():
41
- conv_mode = "llava_v1"
42
  elif "mpt" in model_name.lower():
43
- conv_mode = "mpt"
44
- if "gemma" in model_name.lower():
45
- conv_mode = "ferret_gemma_instruct"
46
- if "llama" in model_name.lower():
47
- conv_mode = "ferret_llama_3"
48
  else:
49
- conv_mode = "llava_v0"
50
 
51
- if args.conv_mode is not None and conv_mode != args.conv_mode:
52
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
53
- else:
54
- args.conv_mode = conv_mode
 
 
 
55
 
56
- conv = conv_templates[args.conv_mode].copy()
57
  if "mpt" in model_name.lower():
58
  roles = ('user', 'assistant')
59
  else:
60
  roles = conv.roles
61
 
62
- image = load_image(args.image_file)
 
 
 
 
63
  image_size = image.size
64
- # Similar operation in model_worker.py
65
- image_tensor = process_images([image], image_processor, model.config)
66
- if type(image_tensor) is list:
67
- image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  else:
69
- image_tensor = image_tensor.to(model.device, dtype=torch.float16)
70
-
71
- while True:
72
- try:
73
- inp = input(f"{roles[0]}: ")
74
- except EOFError:
75
- inp = ""
76
- if not inp:
77
- print("exit...")
78
- break
79
-
80
- print(f"{roles[1]}: ", end="")
81
-
82
- if image is not None:
83
- # first message
84
- if model.config.mm_use_im_start_end:
85
- inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
86
- else:
87
- inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
88
- image = None
89
-
90
- conv.append_message(conv.roles[0], inp)
91
- conv.append_message(conv.roles[1], None)
92
- prompt = conv.get_prompt()
93
-
94
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
95
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
96
- keywords = [stop_str]
97
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
98
-
99
- with torch.inference_mode():
100
- output_ids = model.generate(
101
- input_ids,
102
- images=image_tensor,
103
- image_sizes=[image_size],
104
- do_sample=True if args.temperature > 0 else False,
105
- temperature=args.temperature,
106
- max_new_tokens=args.max_new_tokens,
107
- streamer=streamer,
108
- use_cache=True)
109
-
110
- outputs = tokenizer.decode(output_ids[0]).strip()
111
- conv.messages[-1][-1] = outputs
112
-
113
- if args.debug:
114
- print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
115
-
116
-
117
- if __name__ == "__main__":
118
- parser = argparse.ArgumentParser()
119
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
120
- parser.add_argument("--model-base", type=str, default=None)
121
- parser.add_argument("--image-file", type=str, required=True)
122
- parser.add_argument("--device", type=str, default="cuda")
123
- parser.add_argument("--conv-mode", type=str, default=None)
124
- parser.add_argument("--temperature", type=float, default=0.2)
125
- parser.add_argument("--max-new-tokens", type=int, default=512)
126
- parser.add_argument("--load-8bit", action="store_true")
127
- parser.add_argument("--load-4bit", action="store_true")
128
- parser.add_argument("--debug", action="store_true")
129
- args = parser.parse_args()
130
- main(args)
 
 
1
  import torch
 
2
  from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
3
  from conversation import conv_templates, SeparatorStyle
4
  from builder import load_pretrained_model
5
  from utils import disable_torch_init
6
+ from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
 
7
  from PIL import Image
 
8
  import requests
 
9
  from io import BytesIO
10
  from transformers import TextStreamer
11
+ import spaces
12
+ from functools import partial
13
+ import traceback
14
+ import sys
15
+ # def load_image(image_file):
16
+ # if image_file.startswith('http://') or image_file.startswith('https://'):
17
+ # response = requests.get(image_file)
18
+ # image = Image.open(BytesIO(response.content)).convert('RGB')
19
+ # else:
20
+ # image = Image.open(image_file).convert('RGB')
21
+ # return image
22
 
23
  def load_image(image_file):
24
+ print("the image file : ", image_file)
25
+
26
+ image = Image.open(image_file).convert('RGB')
27
+
28
+ if image is None:
29
+ print("image is None")
30
+ sys.exit("Aborting program: Image is None.")
31
+
32
  return image
33
 
34
+ @spaces.GPU()
35
+ def run_inference(
36
+ model_path,
37
+ image_file,
38
+ prompt_text,
39
+ model_base=None,
40
+ device="cuda",
41
+ conv_mode=None,
42
+ temperature=0.2,
43
+ max_new_tokens=512,
44
+ load_8bit=False,
45
+ load_4bit=False,
46
+ debug=False
47
+ ):
48
+ # Model initialization
49
  disable_torch_init()
50
 
51
+ model_name = get_model_name_from_path(model_path)
52
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
53
+ model_path, model_base, model_name, load_8bit, load_4bit
54
+ )
55
 
56
+ # Determine conversation mode
57
  if "llama-2" in model_name.lower():
58
+ conv_mode_inferred = "llava_llama_2"
59
  elif "mistral" in model_name.lower():
60
+ conv_mode_inferred = "mistral_instruct"
61
  elif "v1.6-34b" in model_name.lower():
62
+ conv_mode_inferred = "chatml_direct"
63
  elif "v1" in model_name.lower():
64
+ conv_mode_inferred = "llava_v1"
65
  elif "mpt" in model_name.lower():
66
+ conv_mode_inferred = "mpt"
67
+ elif "gemma" in model_name.lower():
68
+ conv_mode_inferred = "ferret_gemma_instruct"
69
+ elif "llama" in model_name.lower():
70
+ conv_mode_inferred = "ferret_llama_3"
71
  else:
72
+ conv_mode_inferred = "llava_v0"
73
 
74
+ # Use user-specified conversation mode if provided
75
+ conv_mode = conv_mode or conv_mode_inferred
76
+
77
+ if conv_mode != conv_mode_inferred:
78
+ print(f'[WARNING] the auto inferred conversation mode is {conv_mode_inferred}, while `conv_mode` is {conv_mode}, using {conv_mode}')
79
+
80
+ conv = conv_templates[conv_mode].copy()
81
 
 
82
  if "mpt" in model_name.lower():
83
  roles = ('user', 'assistant')
84
  else:
85
  roles = conv.roles
86
 
87
+ # Load and process image
88
+ print("loading image", image_file)
89
+ image = load_image(image_file)
90
+ if image is None:
91
+ print("image is None")
92
  image_size = image.size
93
+ image_h = 336 # Height of the image
94
+ image_w = 336
95
+ #ERROR
96
+ # image_tensor = process_images([image], image_processor, model.config)
97
+ # if type(image_tensor) is list:
98
+ # image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
99
+ # else:
100
+ # image_tensor = image_tensor.to(model.device, dtype=torch.float16)
101
+ if model.config.image_aspect_ratio == "square_nocrop":
102
+ image_tensor = image_processor.preprocess(image, return_tensors='pt', do_resize=True,
103
+ do_center_crop=False, size=[image_h, image_w])['pixel_values'][0]
104
+ elif model.config.image_aspect_ratio == "anyres":
105
+ image_process_func = partial(image_processor.preprocess, return_tensors='pt', do_resize=True, do_center_crop=False, size=[image_h, image_w])
106
+ image_tensor = process_images([image], image_processor, model.config, image_process_func=image_process_func)[0]
107
+ else:
108
+ image_tensor = process_images([image], image_processor, model.config)[0]
109
+
110
+ if model.dtype == torch.float16:
111
+ image_tensor = image_tensor.half() # Convert image tensor to float16
112
+ data_type = torch.float16
113
+ else:
114
+ image_tensor = image_tensor.float() # Keep it in float32
115
+ data_type = torch.float32
116
+
117
+ # Now, add the batch dimension and move to GPU
118
+ images = image_tensor.unsqueeze(0).to(data_type).cuda()
119
+
120
+
121
+ # Process the first message with the image
122
+ if model.config.mm_use_im_start_end:
123
+ prompt_text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_text
124
  else:
125
+ prompt_text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
126
+
127
+ # Prepare conversation
128
+ conv.append_message(conv.roles[0], prompt_text)
129
+ conv.append_message(conv.roles[1], None)
130
+ prompt = conv.get_prompt()
131
+
132
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
133
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
134
+ keywords = [stop_str]
135
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
136
+ print("image size: ", image_size)
137
+ # Generate the model's response
138
+
139
+ with torch.inference_mode():
140
+ output_ids = model.generate(
141
+ input_ids,
142
+ images=images,
143
+ image_sizes=[image_size],
144
+ do_sample=True if temperature > 0 else False,
145
+ temperature=temperature,
146
+ max_new_tokens=max_new_tokens,
147
+ streamer=streamer,
148
+ num_beams=1,
149
+ use_cache=True
150
+ )
151
+
152
+ # Decode and return the output
153
+ outputs = tokenizer.decode(output_ids[0]).strip()
154
+ conv.messages[-1][-1] = outputs
155
+
156
+ if debug:
157
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
158
+
159
+ return outputs
160
+
161
+
162
+ # Example usage:
163
+ # response = run_inference("path_to_model", "path_to_image", "your_prompt")
164
+ # print(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval.json CHANGED
@@ -1 +1 @@
1
- [{"id": 0, "image": "8b23f327b90b6211049acd36e3f99975.jpg", "image_h": 433, "image_w": 400, "conversations": [{"from": "human", "value": "<image>\nA chat between a human and an AI that understands visuals. In images, [x, y] denotes points: top-left [0, 0], bottom-right [width-1, height-1]. Increasing x moves right; y moves down. Bounding box: [x1, y1, x2, y2]. Image size: 1000x1000. Follow instructions.<start_of_turn>user\n<image>\nexplain what you see<end_of_turn>\n<start_of_turn>model\n"}]}]
 
1
+ [{"id": 0, "image": "Screenshot 2024-10-13 at 12.01.05\u202fAM.png", "image_h": 76, "image_w": 90, "conversations": [{"from": "human", "value": "<image>\ndescribe what you see in details"}]}]
eval_output.jsonl/0_of_1.jsonl CHANGED
@@ -1 +1 @@
1
- {"id": 0, "image_path": "8b23f327b90b6211049acd36e3f99975.jpg", "prompt": "A chat between a human and an AI that understands visuals. In images, [x, y] denotes points: top-left [0, 0], bottom-right [width-1, height-1]. Increasing x moves right; y moves down. Bounding box: [x1, y1, x2, y2]. Image size: 1000x1000. Follow instructions.<start_of_turn>user", "text": "A chat between a human and an unknown entity. \n\nThe conversation starts with a message from Jackyline Herrera saying, \"Ask Jackie to borrow her truck\". The reply is, \"Get gravel for bow, walk, 10, 1, 1, Shopping List\". \n\nThe next message is from Get Gravel for the truck, and the reply is, \"Buy mulch, #shoppinglist\". \n\nThe third message is from Buy mulch for the garden, and the reply is, \"Pick up succulents\". \n\nThe fourth message is from Pick up succulents for the garden, and the reply is, \"Buy soil for succulents\". \n\nThe fifth message is from Buy soil for succulents, and the reply is, \"Pick up soil for succulents\". \n\nThe sixth message is from Pick up succulents for the garden, and the reply is, \"Pick up soil for succulents\". \n\nThe seventh message is from Pick up succulents for the garden, and the reply is, \"Pick up soil for succulents\". \n\nThe eighth message is from Pick up succulents for the garden, and the reply is, \"Pick up soil for succulents\". \n\nThe ninth message is from Pick up succulents for the garden, and the reply is, \"Look up native vegetables along the fence\". \n\nThe tenth message is from Shopping List, and the reply is, \"Shopping List\". \n\nThe message at the bottom is from Shopping List, and the reply is, \"Look up native vegetables along the fence\". \n\nThe message at the very bottom is from Shopping List, and the reply is, \"Looking: Fran\".", "label": null}
 
1
+ {"id": 0, "image_path": "Screenshot 2024-10-13 at 12.01.05\u202fAM.png", "prompt": "describe what you see in details", "text": "The screen contains a large picture that occupies most of the screen, extending from nearly the top to the bottom. In the lower portion of the screen, there is a button labeled \"menu\". The button is relatively large and positioned at the lower part of the screen.", "label": null}
gradio_web_server.log CHANGED
The diff for this file is too large to render. See raw diff
 
logo.svg ADDED
serve_images/2024-10-19/8b23f327b90b6211049acd36e3f99975.jpg DELETED
Binary file (24.4 kB)
 
serve_images/2024-10-20/8b23f327b90b6211049acd36e3f99975.jpg DELETED
Binary file (24.4 kB)
 
untitled DELETED
File without changes