Sabareeshr committed on
Commit
d18d2d2
•
1 Parent(s): 1b3f07c

Upload app(1).py

Files changed (1)
  1. app(1).py +375 -0
app(1).py ADDED
import argparse
import hashlib
import json
import os
import time
from threading import Thread

import gradio as gr
import torch
from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                             DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
from llava.conversation import (SeparatorStyle, conv_templates,
                                default_conversation)
from llava.mm_utils import (KeywordsStoppingCriteria, load_image_from_base64,
                            process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from transformers import TextIteratorStreamer

print(gr.__version__)

block_css = """

#buttons button {
    min-width: min(120px, 100%);
}
"""
title_markdown = ("""
# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions
### 🔊 Notice: The Share-Captioner demo will be supported soon. Stay tuned for updates!
[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V)] | 📚 [[Paper](https://arxiv.org/pdf/2311.12793.pdf)]
""")
tos_markdown = ("""
### Terms of use
By using this service, you agree to the following terms:
The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
For an optimal experience, please use a desktop computer; mobile devices may compromise the demo's quality.
""")
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
ack_markdown = ("""
### Acknowledgement
The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are grateful to the LLaVA team for their open-source contributions to the community!
""")


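# regenerate(): drop the last model reply and reuse the previous user turn
# (keeping its image, refreshing the image-process mode) so that http_bot()
# can produce a fresh answer in the chained event.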
def regenerate(state, image_process_mode):
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None)


def clear_history():
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None)


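# add_text(): validate and enqueue one user turn. Text is hard-truncated
# (1536 chars, 1200 when an image is attached), and an '<image>' placeholder
# is appended when the user has not placed one explicitly.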
def add_text(state, text, image, image_process_mode):
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None)

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            # text = '<Image><image></Image>' + text
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        # The demo handles a single image per conversation, so start a
        # fresh conversation when one is already present.
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None)


def load_demo():
    state = default_conversation.copy()
    return state


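# get_response(): generator that streams JSON-encoded chunks of the growing
# reply. It follows the LLaVA model-worker response format but runs
# generation locally on a background thread via TextIteratorStreamer.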
@torch.inference_mode()
def get_response(params):
    prompt = params["prompt"]
    ori_prompt = prompt
    images = params.get("images", None)
    num_image_tokens = 0
    if images is not None and len(images) > 0:
        if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
            raise ValueError(
                "Number of images does not match number of <image> tokens in prompt")

        images = [load_image_from_base64(image) for image in images]
        images = process_images(images, image_processor, model.config)

        if type(images) is list:
            images = [image.to(model.device, dtype=torch.float16)
                      for image in images]
        else:
            images = images.to(model.device, dtype=torch.float16)

        replace_token = DEFAULT_IMAGE_TOKEN
        if getattr(model.config, 'mm_use_im_start_end', False):
            replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        num_image_tokens = prompt.count(
            replace_token) * model.get_vision_tower().num_patches
        image_args = {"images": images}
    else:
        images = None
        image_args = {}

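    # Sampling and decoding parameters from the request payload.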
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_context_length = getattr(
        model.config, 'max_position_embeddings', 2048)
    max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
    stop_str = params.get("stop", None)
    do_sample = temperature > 0.001

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

    # Leave room in the context window for the prompt and image patches.
    max_new_tokens = min(max_new_tokens, max_context_length -
                         input_ids.shape[-1] - num_image_tokens)

    if max_new_tokens < 1:
        yield json.dumps({"text": ori_prompt + " Exceeds max token length. Please start a new conversation, thanks.",
                          "error_code": 0}).encode()
        return

    # Local inference: model.generate() blocks, so run it in a background
    # thread and stream tokens out as they are produced.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        stopping_criteria=[stopping_criteria],
        use_cache=True,
        **image_args
    ))
    thread.start()

    generated_text = ori_prompt
    for new_text in streamer:
        generated_text += new_text
        if generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        yield json.dumps({"text": generated_text, "error_code": 0}).encode()


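# http_bot(): Gradio streaming callback. On the first round it selects the
# conversation template from the model name, then forwards the prompt to
# get_response() and re-renders the chatbot after every streamed chunk.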
def http_bot(state, temperature, top_p, max_new_tokens):
    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot())
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: pick the template from the model name
        if "llava" in model_name.lower():
            if 'llama-2' in model_name.lower():
                template_name = "llava_llama_2"
            elif "v1" in model_name.lower():
                if 'mmtag' in model_name.lower():
                    template_name = "v1_mmtag"
                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
                    template_name = "v1_mmtag"
                else:
                    template_name = "llava_v1"
            elif "mpt" in model_name.lower():
                template_name = "mpt"
            else:
                if 'mmtag' in model_name.lower():
                    template_name = "v0_mmtag"
                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
                    template_name = "v0_mmtag"
                else:
                    template_name = "llava_v0"
        elif "mpt" in model_name.lower():
            template_name = "mpt_text"
        elif "llama-2" in model_name.lower():
            template_name = "llama_2"
        else:
            template_name = "vicuna_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()

    # MD5 hashes identify the attached images (useful for logging); they
    # are not used further in this local demo.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest()
                      for image in all_images]

    # Build the generation request consumed by get_response()
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": state.get_images(),
    }

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot())

    # Consume the stream, re-rendering the chatbot with a trailing cursor
    stream = get_response(pload)
    for chunk in stream:
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"][len(prompt):].strip()
                state.messages[-1][-1] = output + "▌"
                yield (state, state.to_gradio_chatbot())
            else:
                output = data["text"] + \
                    f" (error_code: {data['error_code']})"
                state.messages[-1][-1] = output
                yield (state, state.to_gradio_chatbot())
                return
            time.sleep(0.03)

    # Strip the trailing cursor character
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot())


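# build_demo(): assemble the Gradio Blocks UI and wire the event chains
# (textbox submit / Send / Regenerate -> add_text or regenerate -> http_bot).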
def build_demo():
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="ShareGPT4V", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row(elem_id="model_id"):
                    # Informational only: this demo serves a single model.
                    gr.Dropdown(
                        choices=['ShareGPT4V-7B'],
                        value='ShareGPT4V-7B',
                        interactive=True,
                        label='Model ID',
                        container=False)
                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/breaking_bad.png",
                     "What is the most common catchphrase of the character on the right?"],
                    [f"{cur_dir}/examples/photo.png",
                     "From a photography perspective, what makes this picture beautiful?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
                    max_output_tokens = gr.Slider(
                        minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")

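            # Right-hand column: chat history plus Send/Regenerate/Clear controls.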
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="ShareGPT4V Chatbot", height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", interactive=True)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=True)

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        gr.Markdown(ack_markdown)

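        # Event wiring: each chain first updates the conversation state,
        # then streams the model reply through http_bot.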
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox],
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot]
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox],
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox],
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot]
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox],
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot]
        )

        demo.load(
            load_demo,
            None,
            [state],
            queue=False
        )
    return demo


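# Command-line options; --share controls whether Gradio creates a public link.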
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    # BooleanOptionalAction (Python 3.9+) lets --no-share disable sharing;
    # a plain `default=True` argument could never be turned off.
    parser.add_argument("--share", action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument("--model-path", type=str,
                        default="Lin-Chen/ShareGPT4V-7B")
    parser.add_argument("--model-name", type=str,
                        default="llava-v1.5-7b")
    args = parser.parse_args()
    return args


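# The tokenizer, model, and image processor below are module-level globals
# that get_response() and http_bot() reference directly.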
if __name__ == '__main__':
    args = parse_args()
    model_name = args.model_name
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, None, args.model_name, False, False)
    demo = build_demo()
    demo.queue()
    demo.launch(server_name=args.host,
                server_port=args.port,
                share=args.share)
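
# Example invocation (assuming the default Lin-Chen/ShareGPT4V-7B weights are
# available locally or can be fetched from the Hugging Face Hub):
#   python "app(1).py" --host 0.0.0.0 --port 7860 --no-share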