Singularity666 commited on
Commit
8724709
1 Parent(s): b3900cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -510
app.py CHANGED
@@ -1,383 +1,4 @@
1
- # Install required packages
2
- !pip install sentencepiece
3
- !pip install git+https://github.com/huggingface/transformers.git@cae78c46
4
- !pip install diffusers
5
- !pip install tokenizers==0.12.1
6
- !pip install datasets
7
- !pip install accelerate
8
- !pip install evaluate
9
- !pip install gradio==4.12.0
10
- !pip install gradio_client==0.8.0
11
- !pip install -i https://download.pytorch.org/whl/cu118 torch==2.0 torchvision==0.15 torchaudio==2.0
12
-
13
- # conversation.py
14
- import dataclasses
15
- from enum import auto, Enum
16
- from typing import List, Tuple
17
-
18
- class SeparatorStyle(Enum):
19
- """Different separator style."""
20
- SINGLE = auto()
21
- TWO = auto()
22
- MPT = auto()
23
-
24
- @dataclasses.dataclass
25
- class Conversation:
26
- """A class that keeps all conversation history."""
27
- system: str
28
- roles: List[str]
29
- messages: List[List[str]]
30
- offset: int
31
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32
- sep: str = "###"
33
- sep2: str = None
34
- version: str = "Unknown"
35
-
36
- skip_next: bool = False
37
-
38
- def get_prompt(self):
39
- if self.sep_style == SeparatorStyle.SINGLE:
40
- ret = self.system + self.sep
41
- for role, message in self.messages:
42
- if message:
43
- if type(message) is tuple:
44
- message, _, _ = message
45
- ret += role + ": " + message + self.sep
46
- else:
47
- ret += role + ":"
48
- return ret
49
- elif self.sep_style == SeparatorStyle.TWO:
50
- seps = [self.sep, self.sep2]
51
- ret = self.system + seps[0]
52
- for i, (role, message) in enumerate(self.messages):
53
- if message:
54
- if type(message) is tuple:
55
- message, _, _ = message
56
- ret += role + ": " + message + seps[i % 2]
57
- else:
58
- ret += role + ":"
59
- return ret
60
- if self.sep_style == SeparatorStyle.MPT:
61
- ret = self.system + self.sep
62
- for role, message in self.messages:
63
- if message:
64
- if type(message) is tuple:
65
- message, _, _ = message
66
- ret += role + message + self.sep
67
- else:
68
- ret += role
69
- return ret
70
- else:
71
- raise ValueError(f"Invalid style: {self.sep_style}")
72
-
73
- def append_message(self, role, message):
74
- self.messages.append([role, message])
75
-
76
- def get_images(self, return_pil=False):
77
- images = []
78
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
79
- if i % 2 == 0:
80
- if type(msg) is tuple:
81
- import base64
82
- from io import BytesIO
83
- from PIL import Image
84
- msg, image, image_process_mode = msg
85
- if image_process_mode == "Pad":
86
- def expand2square(pil_img, background_color=(122, 116, 104)):
87
- width, height = pil_img.size
88
- if width == height:
89
- return pil_img
90
- elif width > height:
91
- result = Image.new(pil_img.mode, (width, width), background_color)
92
- result.paste(pil_img, (0, (width - height) // 2))
93
- return result
94
- else:
95
- result = Image.new(pil_img.mode, (height, height), background_color)
96
- result.paste(pil_img, ((height - width) // 2, 0))
97
- return result
98
- image = expand2square(image)
99
- elif image_process_mode == "Crop":
100
- pass
101
- elif image_process_mode == "Resize":
102
- image = image.resize((224, 224))
103
- else:
104
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
105
- max_hw, min_hw = max(image.size), min(image.size)
106
- aspect_ratio = max_hw / min_hw
107
- max_len, min_len = 800, 400
108
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
109
- longest_edge = int(shortest_edge * aspect_ratio)
110
- W, H = image.size
111
- if H > W:
112
- H, W = longest_edge, shortest_edge
113
- else:
114
- H, W = shortest_edge, longest_edge
115
- image = image.resize((W, H))
116
- if return_pil:
117
- images.append(image)
118
- else:
119
- buffered = BytesIO()
120
- image.save(buffered, format="JPEG")
121
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
122
- images.append(img_b64_str)
123
- return images
124
-
125
- def to_gradio_chatbot(self):
126
- ret = []
127
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
128
- if i % 2 == 0:
129
- if type(msg) is tuple:
130
- import base64
131
- from io import BytesIO
132
- msg, image, image_process_mode = msg
133
- max_hw, min_hw = max(image.size), min(image.size)
134
- aspect_ratio = max_hw / min_hw
135
- max_len, min_len = 800, 400
136
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
137
- longest_edge = int(shortest_edge * aspect_ratio)
138
- W, H = image.size
139
- if H > W:
140
- H, W = longest_edge, shortest_edge
141
- else:
142
- H, W = shortest_edge, longest_edge
143
- image = image.resize((W, H))
144
- # image = image.resize((224, 224))
145
- buffered = BytesIO()
146
- image.save(buffered, format="JPEG")
147
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
148
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
149
- msg = msg.replace('<image>', img_str)
150
- ret.append([msg, None])
151
- else:
152
- ret[-1][-1] = msg
153
- return ret
154
-
155
- def copy(self):
156
- return Conversation(
157
- system=self.system,
158
- roles=self.roles,
159
- messages=[[x, y] for x, y in self.messages],
160
- offset=self.offset,
161
- sep_style=self.sep_style,
162
- sep=self.sep,
163
- sep2=self.sep2)
164
-
165
- def dict(self):
166
- if len(self.get_images()) > 0:
167
- return {
168
- "system": self.system,
169
- "roles": self.roles,
170
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
171
- "offset": self.offset,
172
- "sep": self.sep,
173
- "sep2": self.sep2,
174
- }
175
- return {
176
- "system": self.system,
177
- "roles": self.roles,
178
- "messages": self.messages,
179
- "offset": self.offset,
180
- "sep": self.sep,
181
- "sep2": self.sep2,
182
- }
183
-
184
-
185
- conv_v1 = Conversation(
186
- system="A chat between a curious human and an artificial intelligence assistant. "
187
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
188
- roles=("Human", "Assistant"),
189
- messages=(
190
- ("Human", "Give three tips for staying healthy."),
191
- ("Assistant",
192
- "Sure, here are three tips for staying healthy:\n"
193
- "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
194
- "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
195
- "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
196
- "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
197
- "activities at least two days per week.\n"
198
- "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
199
- "vegetables, whole grains, lean proteins, and healthy fats can help support "
200
- "your overall health. Try to limit your intake of processed and high-sugar foods, "
201
- "and aim to drink plenty of water throughout the day.\n"
202
- "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
203
- "and mental health. Adults should aim for seven to nine hours of sleep per night. "
204
- "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
205
- "help improve the quality of your sleep.")
206
- ),
207
- offset=2,
208
- sep_style=SeparatorStyle.SINGLE,
209
- sep="###",
210
- )
211
-
212
- conv_v1_2 = Conversation(
213
- system="A chat between a curious human and an artificial intelligence assistant. "
214
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
215
- roles=("Human", "Assistant"),
216
- messages=(
217
- ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
218
- ("Assistant",
219
- "Renewable energy sources are those that can be replenished naturally in a relatively "
220
- "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
221
- "Non-renewable energy sources, on the other hand, are finite and will eventually be "
222
- "depleted, such as coal, oil, and natural gas. Here are some key differences between "
223
- "renewable and non-renewable energy sources:\n"
224
- "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
225
- "energy sources are finite and will eventually run out.\n"
226
- "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
227
- "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
228
- "and other negative effects.\n"
229
- "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
230
- "have lower operational costs than non-renewable sources.\n"
231
- "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
232
- "locations than non-renewable sources.\n"
233
- "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
234
- "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
235
- "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
236
- "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
237
- ),
238
- offset=2,
239
- sep_style=SeparatorStyle.SINGLE,
240
- sep="###",
241
- )
242
-
243
- conv_vicuna_v1_1 = Conversation(
244
- system="A chat between a curious user and an artificial intelligence assistant. "
245
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
246
- roles=("USER", "ASSISTANT"),
247
- version="v1",
248
- messages=(),
249
- offset=0,
250
- sep_style=SeparatorStyle.TWO,
251
- sep=" ",
252
- sep2="</s>",
253
- )
254
-
255
- conv_mpt = Conversation(
256
- system="""system
257
- - You are a helpful language and vision assistant.
258
- - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
259
- - You should follow the instructions carefully and explain your answers in detail.""",
260
- roles=("user\n", "assistant\n"),
261
- version="mpt",
262
- messages=(),
263
- offset=0,
264
- sep_style=SeparatorStyle.MPT,
265
- sep="",
266
- )
267
-
268
- conv_mpt_text = Conversation(
269
- system="""system
270
- - You are a helpful assistant chatbot trained by MosaicML.
271
- - You answer questions.
272
- - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
273
- - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
274
- roles=("user\n", "assistant\n"),
275
- version="mpt",
276
- messages=(),
277
- offset=0,
278
- sep_style=SeparatorStyle.MPT,
279
- sep="",
280
- )
281
-
282
- conv_bair_v1 = Conversation(
283
- system="BEGINNING OF CONVERSATION:",
284
- roles=("USER", "GPT"),
285
- messages=(),
286
- offset=0,
287
- sep_style=SeparatorStyle.TWO,
288
- sep=" ",
289
- sep2="</s>",
290
- )
291
-
292
- simple_conv = Conversation(
293
- system="A chat between a curious human and an artificial intelligence assistant. "
294
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
295
- roles=("Human", "Assistant"),
296
- messages=(
297
- ("Human", "Hi!"),
298
- ("Assistant", "Hi there! How can I help you today?")
299
- ),
300
- offset=2,
301
- sep_style=SeparatorStyle.SINGLE,
302
- sep="###",
303
- )
304
-
305
- simple_conv_multimodal = Conversation(
306
- system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
307
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
308
- "Follow the instructions carefully and explain your answers in detail.",
309
- roles=("Human", "Assistant"),
310
- messages=(
311
- ("Human", "Hi!"),
312
- ("Assistant", "Hi there! How can I help you today?\n")
313
- ),
314
- offset=2,
315
- sep_style=SeparatorStyle.SINGLE,
316
- sep="###",
317
- )
318
-
319
- simple_conv_mpt_multimodal = Conversation(
320
- system="""system
321
- - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
322
- - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
323
- - You should follow the instructions carefully and explain your answers in detail.""",
324
- roles=("user\n", "assistant\n"),
325
- version="mpt",
326
- messages=(),
327
- offset=0,
328
- sep_style=SeparatorStyle.MPT,
329
- sep="",
330
- )
331
-
332
- simple_conv_legacy = Conversation(
333
- system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
334
- "You are designed to assist human with a variety of tasks using natural language."
335
- "Follow the instructions carefully.",
336
- roles=("Human", "Assistant"),
337
- messages=(
338
- ("Human", "Hi!\n\n### Response:"),
339
- ("Assistant", "Hi there! How can I help you today?\n")
340
- ),
341
- offset=2,
342
- sep_style=SeparatorStyle.SINGLE,
343
- sep="###",
344
- )
345
-
346
- conv_llava_v1 = Conversation(
347
- system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
348
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
349
- "Follow the instructions carefully and explain your answers in detail.",
350
- roles=("USER", "ASSISTANT"),
351
- version="v1",
352
- messages=(),
353
- offset=0,
354
- sep_style=SeparatorStyle.TWO,
355
- sep=" ",
356
- sep2="</s>",
357
- )
358
-
359
- default_conversation = conv_v1_2
360
- conv_templates = {
361
- "default": conv_v1_2,
362
- "simple": simple_conv,
363
- "simple_legacy": simple_conv_legacy,
364
- "multimodal": simple_conv_multimodal,
365
- "mpt_multimodal": simple_conv_mpt_multimodal,
366
- "llava_v1": conv_llava_v1,
367
-
368
- # fastchat
369
- "v1": conv_v1_2,
370
- "bair_v1": conv_bair_v1,
371
- "vicuna_v1_1": conv_vicuna_v1_1,
372
- "mpt": conv_mpt,
373
- "mpt_text": conv_mpt_text,
374
- }
375
-
376
-
377
- if __name__ == "__main__":
378
- print(default_conversation.get_prompt())
379
-
380
- # mgie_llava.py
381
  from typing import List, Optional, Tuple, Union
382
 
383
  import torch
@@ -398,9 +19,11 @@ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
398
  DEFAULT_IM_START_TOKEN = "<im_start>"
399
  DEFAULT_IM_END_TOKEN = "<im_end>"
400
 
 
401
  class LlavaConfig(LlamaConfig):
402
  model_type = "llava"
403
 
 
404
  class LlavaLlamaModel(LlamaModel):
405
  config_class = LlavaConfig
406
 
@@ -776,133 +399,4 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
776
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
777
 
778
  AutoConfig.register("llava", LlavaConfig)
779
- AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
780
-
781
- # main.py
782
- from google.colab import drive
783
- drive.mount('/content/drive')
784
-
785
- import os
786
- from PIL import Image
787
- import numpy as np
788
- import torch as T
789
- import transformers
790
- import diffusers
791
- import gradio as gr
792
- import huggingface_hub
793
-
794
- CKPT_DIR = '/content/drive/My Drive/_ckpt'
795
-
796
- def crop_resize(f, sz=512):
797
- w, h = f.size
798
- if w > h:
799
- p = (w - h) // 2
800
- f = f.crop([p, 0, p + h, h])
801
- elif h > w:
802
- p = (h - w) // 2
803
- f = f.crop([0, p, w, p + w])
804
- f = f.resize([sz, sz])
805
- return f
806
-
807
- def remove_alter(s):
808
- if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:') + 10:].strip()
809
- if '</s>' in s: s = s[:s.index('</s>')].strip()
810
- if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]
811
- if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
812
- s = '.'.join([s.strip() for s in s.split('.')[:2]])
813
- if s[-1] != '.': s += '.'
814
- return s.strip()
815
-
816
- DEFAULT_IMAGE_TOKEN = '<image>'
817
- DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
818
- DEFAULT_IM_START_TOKEN = '<im_start>'
819
- DEFAULT_IM_END_TOKEN = '<im_end>'
820
- PATH_LLAVA = f'{CKPT_DIR}/LLaVA-7B-v1'
821
-
822
- tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
823
- model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
824
- image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)
825
-
826
- tokenizer.padding_side = 'left'
827
- tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
828
- model.resize_token_embeddings(len(tokenizer))
829
- ckpt = T.load(f'{CKPT_DIR}/mgie_7b/mllm.pt', map_location='cpu')
830
- model.load_state_dict(ckpt, strict=False)
831
-
832
- mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
833
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
834
- if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
835
-
836
- vision_tower = model.get_model().vision_tower[0]
837
- vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
838
- model.get_model().vision_tower[0] = vision_tower
839
- vision_config = vision_tower.config
840
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
841
- vision_config.use_im_start_end = mm_use_im_start_end
842
- if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
843
- image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
844
-
845
- _ = model.eval()
846
-
847
- pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
848
- pipe.set_progress_bar_config(disable=True)
849
- pipe.unet.load_state_dict(T.load(f'{CKPT_DIR}/mgie_7b/unet.pt', map_location='cpu'))
850
- print('--init MGIE--')
851
-
852
- def go_mgie(img, txt, seed, cfg_txt, cfg_img):
853
- EMB = ckpt['emb'].cuda()
854
- with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
855
-
856
- img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
857
- inp = img
858
-
859
- img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
860
- txt = "what will this image be like if '%s'" % (txt)
861
- txt = txt + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
862
- conv = conv_templates['vicuna_v1_1'].copy()
863
- conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
864
- txt = conv.get_prompt()
865
- txt = tokenizer(txt)
866
- txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])
867
-
868
- with T.inference_mode():
869
- _ = model.cuda()
870
- out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
871
- do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
872
- return_dict_in_generate=True, output_hidden_states=True)
873
- out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]
874
-
875
- if 32003 in out: p = out.index(32003) - 1
876
- else: p = len(hid) - 9
877
- p = min(p, len(hid) - 9)
878
- hid = hid[p:p + 8]
879
-
880
- out = remove_alter(tokenizer.decode(out))
881
- _ = model.cuda()
882
- emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
883
- res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
884
- generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
885
-
886
- return res, out
887
-
888
- with gr.Blocks() as app:
889
- gr.Markdown(
890
- """
891
- # MagiX: Edit Personalized Images using Gen AI by Ateeb Taser
892
- """
893
- )
894
- with gr.Row():
895
- inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
896
- gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
897
- with gr.Row():
898
- txt, out = [gr.Textbox(label='Instruction', interactive=True),
899
- gr.Textbox(label='Expressive Instruction', interactive=False)]
900
- with gr.Row():
901
- seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
902
- gr.Number(value=7.5, label='Text CFG', interactive=True),
903
- gr.Number(value=1.5, label='Image CFG', interactive=True)]
904
- with gr.Row():
905
- btn_sub = gr.Button('Submit')
906
- btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
907
-
908
- app.launch()
 
1
+ #mgie_llava.py:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
 
19
  DEFAULT_IM_START_TOKEN = "<im_start>"
20
  DEFAULT_IM_END_TOKEN = "<im_end>"
21
 
22
+
23
  class LlavaConfig(LlamaConfig):
24
  model_type = "llava"
25
 
26
+
27
  class LlavaLlamaModel(LlamaModel):
28
  config_class = LlavaConfig
29
 
 
399
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
400
 
401
  AutoConfig.register("llava", LlavaConfig)
402
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)