Vision-CAIR committed on
Commit
85efb5b
1 Parent(s): 679922c

Upload 39 files

Files changed (40)
  1. .gitattributes +2 -0
  2. app.py +317 -0
  3. examples/video1.mp4 +3 -0
  4. examples/video2.mp4 +3 -0
  5. inference.py +94 -0
  6. longvu/.DS_Store +0 -0
  7. longvu/__init__.py +3 -0
  8. longvu/apply_delta.py +59 -0
  9. longvu/builder.py +249 -0
  10. longvu/cambrian_arch.py +1705 -0
  11. longvu/consolidate.py +33 -0
  12. longvu/constants.py +13 -0
  13. longvu/conversation.py +606 -0
  14. longvu/file_io.py +11 -0
  15. longvu/language_model/__pycache__/cambrian_llama.cpython-310.pyc +0 -0
  16. longvu/language_model/__pycache__/cambrian_qwen.cpython-310.pyc +0 -0
  17. longvu/language_model/cambrian_llama.py +546 -0
  18. longvu/language_model/cambrian_qwen.py +471 -0
  19. longvu/make_delta.py +66 -0
  20. longvu/mm_datautils.py +1688 -0
  21. longvu/mm_utils.py +327 -0
  22. longvu/multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc +0 -0
  23. longvu/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  24. longvu/multimodal_encoder/__pycache__/dino_encoder.cpython-310.pyc +0 -0
  25. longvu/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  26. longvu/multimodal_encoder/base_encoder.py +135 -0
  27. longvu/multimodal_encoder/builder.py +37 -0
  28. longvu/multimodal_encoder/dino_encoder.py +131 -0
  29. longvu/multimodal_encoder/drop.py +41 -0
  30. longvu/multimodal_encoder/image.py +80 -0
  31. longvu/multimodal_encoder/logging.py +131 -0
  32. longvu/multimodal_encoder/loss.py +96 -0
  33. longvu/multimodal_encoder/registry.py +56 -0
  34. longvu/multimodal_encoder/siglip_encoder.py +78 -0
  35. longvu/multimodal_encoder/utils.py +66 -0
  36. longvu/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  37. longvu/multimodal_projector/builder.py +52 -0
  38. longvu/utils.py +25 -0
  39. longvu/vision_sampler.py +566 -0
  40. requirements.txt +28 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/video1.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/video2.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,317 @@
1
+ import spaces
2
+
3
+ import os
4
+ import re
5
+ import traceback
6
+
7
+ import torch
8
+ import gradio as gr
9
+
10
+ import sys
11
+
12
+ import numpy as np
13
+
14
+ from longvu.builder import load_pretrained_model
15
+ from longvu.constants import (
16
+ DEFAULT_IMAGE_TOKEN,
17
+ IMAGE_TOKEN_INDEX,
18
+ )
19
+ from longvu.conversation import conv_templates, SeparatorStyle
20
+ from longvu.mm_datautils import (
21
+ KeywordsStoppingCriteria,
22
+ process_images,
23
+ tokenizer_image_token,
24
+ )
25
+ from decord import cpu, VideoReader
26
+
27
+
28
+ title_markdown = ("""
29
+ LongVU
30
+ """)
31
+
32
+ block_css = """
33
+ #buttons button {
34
+ min-width: min(120px,100%);
35
+ color: #9C276A
36
+ }
37
+ """
38
+
39
+ plum_color = gr.themes.colors.Color(
40
+ name='plum',
41
+ c50='#F8E4EF',
42
+ c100='#E9D0DE',
43
+ c200='#DABCCD',
44
+ c300='#CBA8BC',
45
+ c400='#BC94AB',
46
+ c500='#AD809A',
47
+ c600='#9E6C89',
48
+ c700='#8F5878',
49
+ c800='#804467',
50
+ c900='#713056',
51
+ c950='#662647',
52
+ )
53
+
54
+
55
+ class Chat:
56
+
57
+ def __init__(self):
58
+ self.version = "qwen"
59
+ model_name = "cambrian_qwen"
60
+ model_path = "./checkpoints/longvu_qwen"
61
+ device = "cuda:7"
62
+
63
+ self.tokenizer, self.model, self.processor, _ = load_pretrained_model(model_path, None, model_name, device=device)
64
+ self.model.eval()
65
+
66
+ def remove_after_last_dot(self, s):
67
+ last_dot_index = s.rfind('.')
68
+ if last_dot_index == -1:
69
+ return s
70
+ return s[:last_dot_index + 1]
71
+
72
+ @spaces.GPU(duration=120)
73
+ @torch.inference_mode()
74
+ def generate(self, data: list, message, temperature, top_p, max_output_tokens):
75
+ # TODO: support multiple turns of conversation.
76
+ assert len(data) == 1
77
+
78
+ tensor, image_sizes, modal = data[0]
79
+
80
+ conv = conv_templates[self.version].copy()
81
+
82
+ if isinstance(message, str):
83
+ conv.append_message("user", DEFAULT_IMAGE_TOKEN + '\n' + message)
84
+ elif isinstance(message, list):
85
+ if DEFAULT_IMAGE_TOKEN not in message[0]['content']:
86
+ message[0]['content'] = DEFAULT_IMAGE_TOKEN + '\n' + message[0]['content']
87
+ for mes in message:
88
+ conv.append_message(mes["role"], mes["content"])
89
+
90
+ conv.append_message("assistant", None)
91
+
92
+ prompt = conv.get_prompt()
93
+
94
+ input_ids = (
95
+ tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
96
+ .unsqueeze(0)
97
+ .to(self.model.device)
98
+ )
99
+
100
+ if "llama3" in self.version:
101
+ input_ids = input_ids[0][1:].unsqueeze(0) # remove bos
102
+
103
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
104
+ keywords = [stop_str]
105
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
106
+ with torch.inference_mode():
107
+ output_ids = self.model.generate(
108
+ input_ids,
109
+ images=tensor,
110
+ image_sizes=image_sizes,
111
+ do_sample=True,
112
+ temperature=temperature,
113
+ max_new_tokens=max_output_tokens,
114
+ use_cache=True,
115
+ top_p=top_p,
116
+ stopping_criteria=[stopping_criteria],
117
+ )
118
+
119
+ pred = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
120
+
121
+ return self.remove_after_last_dot(pred)
122
+
123
+
124
+ @spaces.GPU(duration=120)
125
+ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
126
+ data = []
127
+
128
+ processor = handler.processor
129
+ try:
130
+ if image is not None:
131
+ data.append((processor['image'](image).to(handler.model.device, dtype=dtype), None, '<image>'))
132
+ elif video is not None:
133
+ vr = VideoReader(video, ctx=cpu(0), num_threads=1)
134
+ fps = float(vr.get_avg_fps())
135
+ frame_indices = np.array(
136
+ [
137
+ i
138
+ for i in range(
139
+ 0,
140
+ len(vr),
141
+ round(fps),
142
+ )
143
+ ]
144
+ )
145
+ video_tensor = []
146
+ for frame_index in frame_indices:
147
+ img = vr[frame_index].asnumpy()
148
+ video_tensor.append(img)
149
+ video_tensor = np.stack(video_tensor)
150
+ image_sizes = [video_tensor[0].shape[:2]]
151
+ video_tensor = process_images(video_tensor, processor, handler.model.config)
152
+ video_tensor = [item.unsqueeze(0).to(handler.model.device, dtype=dtype) for item in video_tensor]
153
+ data.append((video_tensor, image_sizes, '<video>'))
154
+ elif image is None and video is None:
155
+ data.append((None, None, '<text>'))
156
+ else:
157
+ raise NotImplementedError("Image and video inputs are not supported at the same time")
158
+ except Exception as e:
159
+ traceback.print_exc()
160
+ return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot
161
+
162
+ assert len(message) % 2 == 0, "The message history should consist of user/assistant message pairs."
163
+
164
+ show_images = ""
165
+ if image is not None:
166
+ show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
167
+ if video is not None:
168
+ show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'
169
+
170
+ one_turn_chat = [textbox_in, None]
171
+
172
+ # 1. first run case
173
+ if len(chatbot) == 0:
174
+ one_turn_chat[0] += "\n" + show_images
175
+ # 2. not first run case
176
+ else:
177
+ # scanning the last image or video
178
+ length = len(chatbot)
179
+ for i in range(length - 1, -1, -1):
180
+ previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
181
+ previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
182
+
183
+ if len(previous_image) > 0:
184
+ previous_image = previous_image[-1]
185
+ # 2.1 new image append or pure text input will start a new conversation
186
+ if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
187
+ message.clear()
188
+ one_turn_chat[0] += "\n" + show_images
189
+ break
190
+ elif len(previous_video) > 0:
191
+ previous_video = previous_video[-1]
192
+ # 2.2 new video append or pure text input will start a new conversation
193
+ if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
194
+ message.clear()
195
+ one_turn_chat[0] += "\n" + show_images
196
+ break
197
+
198
+ message.append({'role': 'user', 'content': textbox_in})
199
+ text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
200
+ message.append({'role': 'assistant', 'content': text_en_out})
201
+
202
+ one_turn_chat[1] = text_en_out
203
+ chatbot.append(one_turn_chat)
204
+
205
+ return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot
206
+
207
+
208
+ def regenerate(message, chatbot):
209
+ message.pop(-1), message.pop(-1)
210
+ chatbot.pop(-1)
211
+ return message, chatbot
212
+
213
+
214
+ def clear_history(message, chatbot):
215
+ message.clear(), chatbot.clear()
216
+ return (gr.update(value=None, interactive=True),
217
+ gr.update(value=None, interactive=True),
218
+ message, chatbot,
219
+ gr.update(value=None, interactive=True))
220
+
221
+ handler = Chat()
222
+
223
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
224
+
225
+ theme = gr.themes.Default(primary_hue=plum_color)
226
+ # theme.update_color("primary", plum_color.c500)
227
+ theme.set(slider_color="#9C276A")
228
+ theme.set(block_title_text_color="#9C276A")
229
+ theme.set(block_label_text_color="#9C276A")
230
+ theme.set(button_primary_text_color="#9C276A")
231
+
232
+ with gr.Blocks(title='LongVU', theme=theme, css=block_css) as demo:
233
+ gr.Markdown(title_markdown)
234
+ message = gr.State([])
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=3):
238
+ image = gr.State(None)
239
+ video = gr.Video(label="Input Video")
240
+
241
+ with gr.Accordion("Parameters", open=True) as parameter_row:
242
+
243
+ temperature = gr.Slider(
244
+ minimum=0.1,
245
+ maximum=1.0,
246
+ value=0.2,
247
+ step=0.1,
248
+ interactive=True,
249
+ label="Temperature",
250
+ )
251
+
252
+ top_p = gr.Slider(
253
+ minimum=0.0,
254
+ maximum=1.0,
255
+ value=0.7,
256
+ step=0.1,
257
+ interactive=True,
258
+ label="Top P",
259
+ )
260
+
261
+ max_output_tokens = gr.Slider(
262
+ minimum=64,
263
+ maximum=512,
264
+ value=128,
265
+ step=64,
266
+ interactive=True,
267
+ label="Max output tokens",
268
+ )
269
+
270
+ with gr.Column(scale=7):
271
+ chatbot = gr.Chatbot(label="LongVU", bubble_full_width=True, height=420)
272
+ with gr.Row():
273
+ with gr.Column(scale=8):
274
+ textbox.render()
275
+ with gr.Column(scale=1, min_width=50):
276
+ submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
277
+ with gr.Row(elem_id="buttons") as button_row:
278
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
279
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
280
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
281
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
282
+
283
+ with gr.Row():
284
+ with gr.Column():
285
+ gr.Examples(
286
+ examples=[
287
+ [
288
+ f"./examples/video1.mp4",
289
+ "Describe this video in detail.",
290
+ ],
291
+ [
292
+ f"./examples/video2.mp4",
293
+ "Which country does the boy in the video probably come from?",
294
+ ]
295
+ ],
296
+ inputs=[video, textbox],
297
+ )
298
+
299
+ submit_btn.click(
300
+ generate,
301
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
302
+ [image, video, message, chatbot])
303
+
304
+ regenerate_btn.click(
305
+ regenerate,
306
+ [message, chatbot],
307
+ [message, chatbot]).then(
308
+ generate,
309
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
310
+ [image, video, message, chatbot])
311
+
312
+ clear_btn.click(
313
+ clear_history,
314
+ [message, chatbot],
315
+ [image, video, message, chatbot, textbox])
316
+
317
+ demo.launch()
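
Note: the `generate` wrapper above samples each uploaded clip at roughly one frame per second with decord before handing the frames to `process_images`. A minimal standalone sketch of just that sampling step is shown below; it assumes only a local MP4 path and reproduces the indexing logic used in app.py.

import numpy as np
from decord import cpu, VideoReader

def sample_frames_1fps(video_path: str) -> np.ndarray:
    # Open the clip on CPU and take one frame per (rounded) second of video,
    # mirroring the frame_indices construction in app.py.
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    fps = float(vr.get_avg_fps())
    frame_indices = list(range(0, len(vr), round(fps)))
    return np.stack([vr[i].asnumpy() for i in frame_indices])  # (T, H, W, 3)

# Example: frames = sample_frames_1fps("./examples/video1.mp4")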
examples/video1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0431514403beba3c269a318b4e5eb6c08cd6940edb10d5014b4745c5fee31ac0
3
+ size 1171735
examples/video2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3069129741c1ed79524a3eeb138ce4f99a9b7fedf36652afd8ad7d83b1d6008b
3
+ size 1606730
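
Note: the two entries above are Git LFS pointer files (version, oid, size), not the videos themselves. A hedged sketch for fetching the actual media through huggingface_hub is below; the repo_id and repo_type are assumptions and may need to be adjusted to wherever this repository is hosted.

from huggingface_hub import hf_hub_download

# repo_id and repo_type below are assumed, not taken from this commit.
video_path = hf_hub_download(
    repo_id="Vision-CAIR/LongVU",
    filename="examples/video1.mp4",
    repo_type="space",
)
print(video_path)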
inference.py ADDED
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+
5
+ from longvu.builder import load_pretrained_model
6
+ from longvu.constants import (
7
+ DEFAULT_IMAGE_TOKEN,
8
+ IMAGE_TOKEN_INDEX,
9
+ )
10
+ from longvu.conversation import conv_templates, SeparatorStyle
11
+ from longvu.mm_datautils import (
12
+ KeywordsStoppingCriteria,
13
+ process_images,
14
+ tokenizer_image_token,
15
+ )
16
+ from decord import cpu, VideoReader
17
+
18
+ version = "qwen"
19
+ model_name = "cambrian_qwen"
20
+ input_model_local_path = "./checkpoints/longvu_qwen"
21
+
22
+ device = "cuda:7"
23
+
24
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
25
+ input_model_local_path, None, model_name, device=device
26
+ )
27
+ model.get_model().config.tokenizer_model_max_length = 8192
28
+ model.get_model().config.inference_max_length = 128
29
+ model.config.use_cache = True
30
+ print(model.device)
31
+
32
+ model.eval()
33
+
34
+ video_path = "./examples/video1.mp4"
35
+ qs = "Describe this video in detail"
36
+
37
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
38
+ fps = float(vr.get_avg_fps())
39
+ frame_indices = np.array(
40
+ [
41
+ i
42
+ for i in range(
43
+ 0,
44
+ len(vr),
45
+ round(fps),
46
+ )
47
+ ]
48
+ )
49
+ video = []
50
+ for frame_index in frame_indices:
51
+ img = vr[frame_index].asnumpy()
52
+ video.append(img)
53
+ video = np.stack(video)
54
+ image_sizes = [video[0].shape[:2]]
55
+ video = process_images(video, image_processor, model.config)
56
+ video = [item.unsqueeze(0) for item in video]
57
+
58
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
59
+
60
+ conv = conv_templates[version].copy()
61
+ conv.append_message(conv.roles[0], qs)
62
+ conv.append_message(conv.roles[1], None)
63
+ prompt = conv.get_prompt()
64
+
65
+ input_ids = (
66
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
67
+ .unsqueeze(0)
68
+ .to(model.device)
69
+ )
70
+
71
+ if "llama3" in version:
72
+ input_ids = input_ids[0][1:].unsqueeze(0) # remove bos
73
+
74
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
75
+ keywords = [stop_str]
76
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
77
+ with torch.inference_mode():
78
+ output_ids = model.generate(
79
+ input_ids,
80
+ images=video,
81
+ image_sizes=image_sizes,
82
+ do_sample=False,
83
+ temperature=0.2,
84
+ max_new_tokens=128,
85
+ use_cache=True,
86
+ stopping_criteria=[stopping_criteria],
87
+ )
88
+
89
+ if isinstance(output_ids, tuple):
90
+ output_ids = output_ids[0]
91
+
92
+ pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
93
+
94
+ print("pred: ", pred, flush=True)
longvu/.DS_Store ADDED
Binary file (8.2 kB).
longvu/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # pyre-unsafe
2
+ from .language_model.cambrian_qwen import CambrianQwenModel
3
+ from .language_model.cambrian_llama import CambrianLlamaModel
longvu/apply_delta.py ADDED
@@ -0,0 +1,59 @@
1
+ # pyre-unsafe
2
+ """
3
+ Usage:
4
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
5
+ """
6
+
7
+ import argparse
8
+
9
+ import torch
10
+ from tqdm import tqdm
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ from . import LlavaLlamaForCausalLM
14
+
15
+
16
+ def apply_delta(base_model_path, target_model_path, delta_path):
17
+ print("Loading base model")
18
+ base = AutoModelForCausalLM.from_pretrained(
19
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
20
+ )
21
+
22
+ print("Loading delta")
23
+ delta = LlavaLlamaForCausalLM.from_pretrained(
24
+ delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
25
+ )
26
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
27
+
28
+ print("Applying delta")
29
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
30
+ if name not in base.state_dict():
31
+ assert name in [
32
+ "model.mm_projector.weight",
33
+ "model.mm_projector.bias",
34
+ ], f"{name} not in base model"
35
+ continue
36
+ if param.data.shape == base.state_dict()[name].shape:
37
+ param.data += base.state_dict()[name]
38
+ else:
39
+ assert name in [
40
+ "model.embed_tokens.weight",
41
+ "lm_head.weight",
42
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
43
+ bparam = base.state_dict()[name]
44
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
45
+
46
+ print("Saving target model")
47
+ delta.save_pretrained(target_model_path)
48
+ delta_tokenizer.save_pretrained(target_model_path)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument("--base-model-path", type=str, required=True)
54
+ parser.add_argument("--target-model-path", type=str, required=True)
55
+ parser.add_argument("--delta-path", type=str, required=True)
56
+
57
+ args = parser.parse_args()
58
+
59
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
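
Note: apply_delta reconstructs full weights by adding the base model's parameters onto the published delta and saving the merged result. A hedged sketch of calling it directly is below; the three paths are hypothetical placeholders, and it assumes the module's relative import of LlavaLlamaForCausalLM resolves in the installed package.

from longvu.apply_delta import apply_delta

apply_delta(
    base_model_path="./checkpoints/base-llm",      # hypothetical base checkpoint
    target_model_path="./checkpoints/merged-llm",  # hypothetical output directory
    delta_path="./checkpoints/delta-weights",      # hypothetical delta checkpoint
)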
longvu/builder.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pyre-unsafe
16
+
17
+
18
+ import os
19
+ import shutil
20
+ import warnings
21
+
22
+ import torch
23
+ from longvu.constants import (
24
+ DEFAULT_IM_END_TOKEN,
25
+ DEFAULT_IM_START_TOKEN,
26
+ DEFAULT_IMAGE_PATCH_TOKEN,
27
+ )
28
+
29
+ from longvu.language_model.cambrian_llama import CambrianLlamaForCausalLM
30
+ from longvu.language_model.cambrian_qwen import CambrianQwenForCausalLM
31
+
32
+ from transformers import (
33
+ AutoConfig,
34
+ AutoModelForCausalLM,
35
+ AutoTokenizer,
36
+ BitsAndBytesConfig,
37
+ )
38
+
39
+
40
+ def load_pretrained_model(
41
+ model_path,
42
+ model_base,
43
+ model_name,
44
+ load_8bit=False,
45
+ load_4bit=False,
46
+ device_map="auto",
47
+ device="cuda",
48
+ use_flash_attn=False,
49
+ model_args=None,
50
+ **kwargs,
51
+ ):
52
+ kwargs = {"device_map": device_map, **kwargs}
53
+
54
+ if device != "cuda":
55
+ kwargs["device_map"] = {"": device}
56
+
57
+ if load_8bit:
58
+ kwargs["load_in_8bit"] = True
59
+ elif load_4bit:
60
+ kwargs["load_in_4bit"] = True
61
+ kwargs["quantization_config"] = BitsAndBytesConfig(
62
+ load_in_4bit=True,
63
+ bnb_4bit_compute_dtype=torch.float16,
64
+ bnb_4bit_use_double_quant=True,
65
+ bnb_4bit_quant_type="nf4",
66
+ )
67
+ else:
68
+ kwargs["torch_dtype"] = torch.float16
69
+
70
+ if use_flash_attn:
71
+ kwargs["attn_implementation"] = "flash_attention_2"
72
+
73
+ if "cambrian" in model_name.lower():
74
+ # Load Cambrian model
75
+ if "lora" in model_name.lower() and model_base is None:
76
+ warnings.warn(
77
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
78
+ )
79
+ if "lora" in model_name.lower() and model_base is not None:
80
+ # pyre-fixme[21]: Could not find module
81
+ # `core_ai.llava.language_model.cambrian_llama`.
82
+ from core_ai.llava.language_model.cambrian_llama import CambrianConfig
83
+
84
+ lora_cfg_pretrained = CambrianConfig.from_pretrained(model_path)
85
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
86
+ print("Loading Cambrian from base model...")
87
+ model = CambrianLlamaForCausalLM.from_pretrained(
88
+ model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
89
+ )
90
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
91
+ if model.lm_head.weight.shape[0] != token_num:
92
+ model.lm_head.weight = torch.nn.Parameter(
93
+ torch.empty(
94
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
95
+ )
96
+ )
97
+ model.model.embed_tokens.weight = torch.nn.Parameter(
98
+ torch.empty(
99
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
100
+ )
101
+ )
102
+
103
+ print("Loading additional Cambrian weights...")
104
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
105
+ non_lora_trainables = torch.load(
106
+ os.path.join(model_path, "non_lora_trainables.bin"),
107
+ map_location="cpu",
108
+ )
109
+ else:
110
+ # this is probably from HF Hub
111
+ from huggingface_hub import hf_hub_download
112
+
113
+ def load_from_hf(repo_id, filename, subfolder=None):
114
+ cache_file = hf_hub_download(
115
+ repo_id=repo_id, filename=filename, subfolder=subfolder
116
+ )
117
+ return torch.load(cache_file, map_location="cpu")
118
+
119
+ non_lora_trainables = load_from_hf(
120
+ model_path, "non_lora_trainables.bin"
121
+ )
122
+ non_lora_trainables = {
123
+ (k[11:] if k.startswith("base_model.") else k): v
124
+ for k, v in non_lora_trainables.items()
125
+ }
126
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
127
+ non_lora_trainables = {
128
+ (k[6:] if k.startswith("model.") else k): v
129
+ for k, v in non_lora_trainables.items()
130
+ }
131
+ model.load_state_dict(non_lora_trainables, strict=False)
132
+
133
+ from peft import PeftModel
134
+
135
+ print("Loading LoRA weights...")
136
+ model = PeftModel.from_pretrained(model, model_path)
137
+ print("Merging LoRA weights...")
138
+ model = model.merge_and_unload()
139
+ print("Model is loaded...")
140
+ elif model_base is not None:
141
+ # this may be mm projector only
142
+ print(f"Loading Cambrian-1 from base model... {model_base}")
143
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
144
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
145
+ model = CambrianLlamaForCausalLM.from_pretrained(
146
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
147
+ )
148
+
149
+ mm_projector_weights = torch.load(
150
+ os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
151
+ )
152
+ mm_projector_weights = {
153
+ k: v.to(torch.float16) for k, v in mm_projector_weights.items()
154
+ }
155
+ model.load_state_dict(mm_projector_weights, strict=False)
156
+ else:
157
+ if "qwen" in model_name.lower():
158
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
159
+ model = CambrianQwenForCausalLM.from_pretrained(
160
+ model_path, low_cpu_mem_usage=True, **kwargs
161
+ )
162
+ else:
163
+ print(f"Loading Cambrian from {model_path}")
164
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
165
+ model = CambrianLlamaForCausalLM.from_pretrained(
166
+ model_path, low_cpu_mem_usage=True, **kwargs
167
+ )
168
+ else:
169
+ # Load language model
170
+ if model_base is not None:
171
+ # PEFT model
172
+ from peft import PeftModel
173
+
174
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
175
+ model = AutoModelForCausalLM.from_pretrained(
176
+ model_base, low_cpu_mem_usage=True, **kwargs
177
+ )
178
+ print(f"Loading LoRA weights from {model_path}")
179
+ model = PeftModel.from_pretrained(model, model_path)
180
+ print(f"Merging weights")
181
+ model = model.merge_and_unload()
182
+ print("Convert to FP16...")
183
+ model.to(torch.float16)
184
+ else:
185
+ use_fast = False
186
+ if "mpt" in model_name.lower():
187
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
188
+ model = AutoModelForCausalLM.from_pretrained(
189
+ model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
190
+ )
191
+ else:
192
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
193
+ model = AutoModelForCausalLM.from_pretrained(
194
+ model_path, low_cpu_mem_usage=True, **kwargs
195
+ )
196
+
197
+ image_processor = None
198
+
199
+ if "llava" in model_name.lower():
200
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
201
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
202
+ if mm_use_im_patch_token:
203
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
204
+ if mm_use_im_start_end:
205
+ tokenizer.add_tokens(
206
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
207
+ )
208
+ model.resize_token_embeddings(len(tokenizer))
209
+
210
+ vision_tower = model.get_vision_tower()
211
+ if not vision_tower.is_loaded:
212
+ try:
213
+ vision_tower.load_model(device_map=device_map)
214
+ except ValueError:
215
+ # ClipVisionTower doesn't support loading with device_map 'auto'
216
+ vision_tower.load_model()
217
+ vision_tower.to(device="cuda", dtype=torch.float16)
218
+ if device_map != "auto":
219
+ vision_tower.to(device=device_map, dtype=torch.float16)
220
+ image_processor = vision_tower.image_processor
221
+ elif "cambrian" in model_name.lower():
222
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
223
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
224
+ if mm_use_im_patch_token:
225
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
226
+ if mm_use_im_start_end:
227
+ tokenizer.add_tokens(
228
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
229
+ )
230
+ model.resize_token_embeddings(len(tokenizer))
231
+
232
+ vision_tower_aux_list = model.get_vision_tower_aux_list()
233
+
234
+ for vision_tower_aux in vision_tower_aux_list:
235
+ if not vision_tower_aux.is_loaded:
236
+ vision_tower_aux.load_model(device_map=device_map)
237
+ vision_tower_aux.to(device=device, dtype=torch.float16)
238
+
239
+ image_processor = [
240
+ vision_tower_aux.image_processor
241
+ for vision_tower_aux in vision_tower_aux_list
242
+ ]
243
+
244
+ if hasattr(model.config, "max_sequence_length"):
245
+ context_len = model.config.max_sequence_length
246
+ else:
247
+ context_len = 2048
248
+
249
+ return tokenizer, model, image_processor, context_len
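
Note: load_pretrained_model returns (tokenizer, model, image_processor, context_len); with a merged Qwen-based checkpoint it takes the "cambrian" branch with model_base=None. A minimal sketch mirroring the call in inference.py is below, assuming the checkpoint has been downloaded to the path used elsewhere in this upload.

from longvu.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    "./checkpoints/longvu_qwen",  # model_path, assumed local checkpoint directory
    None,                         # model_base: not needed for a merged checkpoint
    "cambrian_qwen",              # model_name selects the Cambrian/Qwen loading path
    device="cuda",
)
model.eval()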
longvu/cambrian_arch.py ADDED
@@ -0,0 +1,1705 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import random
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from longvu.constants import (
25
+ DEFAULT_IM_END_TOKEN,
26
+ DEFAULT_IM_START_TOKEN,
27
+ DEFAULT_IMAGE_PATCH_TOKEN,
28
+ IGNORE_INDEX,
29
+ IMAGE_TOKEN_INDEX,
30
+ )
31
+
32
+ from .multimodal_encoder.builder import build_vision_tower_aux_list
33
+ from .multimodal_projector.builder import build_vision_projector
34
+ from .vision_sampler import VisionTokenSampler
35
+
36
+ IS_XLA_AVAILABLE = False
37
+
38
+
39
+ class CambrianMetaModel:
40
+
41
+ def __init__(self, config):
42
+ super(CambrianMetaModel, self).__init__(config)
43
+
44
+ if hasattr(config, "mm_vision_tower_aux_list"):
45
+
46
+ projector_type = getattr(config, "mm_projector_type", "linear")
47
+ if projector_type == "sva":
48
+
49
+ vision_hidden_size = config.vision_hidden_size
50
+ num_query_group = config.num_query_group
51
+ query_num_list = config.query_num_list
52
+ connector_only = config.connector_only
53
+ connector_depth = config.connector_depth
54
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
55
+ config, delay_load=True
56
+ )
57
+ self.mm_projector = nn.Sequential(
58
+ nn.Linear(vision_hidden_size * num_query_group, config.hidden_size),
59
+ nn.GELU(),
60
+ nn.Linear(config.hidden_size, config.hidden_size),
61
+ )
62
+
63
+ image_token_len = config.image_token_len
64
+ vision_tower_aux_token_len_list = (
65
+ self.config.mm_vision_tower_aux_token_len_list
66
+ )
67
+ cross_att_token_len_list = [
68
+ int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5)
69
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
70
+ ]
71
+
72
+ for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list):
73
+ setattr(
74
+ self,
75
+ "mm_projector_aux_{}".format(aux_i),
76
+ nn.Sequential(
77
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
78
+ nn.GELU(),
79
+ nn.Linear(vision_hidden_size, vision_hidden_size),
80
+ nn.LayerNorm(vision_hidden_size),
81
+ ),
82
+ )
83
+
84
+ for query_group_i in range(num_query_group):
85
+ cross_att_token_len_list = [
86
+ int(vision_tower_aux_token_len**0.5)
87
+ // int(query_num_list[query_group_i] ** 0.5)
88
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
89
+ ]
90
+ setattr(
91
+ self,
92
+ "vision_sampler_{}".format(query_group_i),
93
+ VisionTokenSampler(
94
+ vision_hidden_size,
95
+ vision_hidden_size,
96
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
97
+ cross_att_token_len_list,
98
+ vision_hidden_size,
99
+ connector_depth,
100
+ ),
101
+ )
102
+
103
+ if not connector_only:
104
+ num_of_vision_sampler_layers = (
105
+ config.num_of_vision_sampler_layers
106
+ ) = config.num_of_vision_sampler_layers
107
+ config.start_of_vision_sampler_layers = (
108
+ config.start_of_vision_sampler_layers
109
+ )
110
+ config.stride_of_vision_sampler_layers = (
111
+ config.stride_of_vision_sampler_layers
112
+ )
113
+ cross_att_token_len_list = [
114
+ int(vision_tower_aux_token_len**0.5)
115
+ // int(image_token_len**0.5)
116
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
117
+ ]
118
+ self.vision_sampler_layers = nn.ModuleList(
119
+ [
120
+ VisionTokenSampler(
121
+ config.hidden_size,
122
+ vision_hidden_size,
123
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
124
+ cross_att_token_len_list,
125
+ vision_hidden_size,
126
+ 1,
127
+ )
128
+ for layer_idx in range(0, num_of_vision_sampler_layers)
129
+ ]
130
+ )
131
+
132
+ self.vision_query = nn.Parameter(
133
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
134
+ )
135
+
136
+ self.image_newline = nn.Parameter(
137
+ torch.empty(config.hidden_size, dtype=self.dtype)
138
+ )
139
+
140
+ self.frame_pos = torch.stack(
141
+ [
142
+ 1
143
+ / torch.pow(
144
+ torch.tensor(10000),
145
+ torch.tensor(2 * (hid_j // 2) / config.hidden_size),
146
+ )
147
+ for hid_j in range(config.hidden_size)
148
+ ]
149
+ )
150
+
151
+ else:
152
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
153
+ config, delay_load=True
154
+ )
155
+ config.mm_hidden_size = sum(
156
+ [
157
+ vision_tower_aux.hidden_size
158
+ for vision_tower_aux in self.vision_tower_aux_list
159
+ ]
160
+ )
161
+ self.mm_projector = build_vision_projector(config)
162
+ self.image_newline = nn.Parameter(
163
+ torch.empty(config.hidden_size, dtype=self.dtype)
164
+ )
165
+
166
+ def get_frame_pos(self, time_range):
167
+ frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
168
+ self.frame_pos.device
169
+ )
170
+ frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
171
+ frame_pos[:, 1::2] = torch.cos(frame_pos[:, 0::2])
172
+ frame_pos = frame_pos.unsqueeze(1)
173
+ return frame_pos
174
+
175
+ # def get_vision_tower(self):
176
+ # vision_tower = getattr(self, 'vision_tower', None)
177
+ # if type(vision_tower) is list:
178
+ # vision_tower = vision_tower[0]
179
+ # return vision_tower
180
+
181
+ def get_vision_tower_aux_list(self):
182
+ vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
183
+ return vision_tower_aux_list
184
+
185
+ def initialize_vision_modules(self, model_args, fsdp=None):
186
+ # vision_tower = model_args.vision_tower
187
+ num_query_group = model_args.num_query_group
188
+ query_num_list = model_args.query_num_list
189
+ vision_hidden_size = model_args.vision_hidden_size
190
+ vision_tower_aux_list = model_args.vision_tower_aux_list
191
+ vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
192
+ image_token_len = model_args.image_token_len
193
+ mm_vision_select_layer = model_args.mm_vision_select_layer
194
+ mm_vision_select_feature = model_args.mm_vision_select_feature
195
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
196
+ connector_only = model_args.connector_only
197
+ connector_depth = model_args.connector_depth
198
+
199
+ # self.config.mm_vision_tower = vision_tower
200
+ self.config.image_token_len = image_token_len
201
+ self.config.num_query_group = num_query_group
202
+ self.config.query_num_list = query_num_list
203
+ assert num_query_group == len(query_num_list)
204
+ self.config.connector_depth = connector_depth
205
+ self.config.mm_vision_tower_aux_list = vision_tower_aux_list
206
+ self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
207
+ self.config.connector_only = connector_only
208
+ self.config.highres_connect = model_args.highres_connect
209
+ self.config.highres = model_args.highres
210
+ self.config.frame_pos = model_args.frame_pos
211
+ self.config.lowres_token = model_args.lowres_token
212
+ self.config.connect_layer = model_args.connect_layer
213
+ self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83)
214
+ self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6)
215
+ self.config.is_image_newline = getattr(model_args, "is_image_newline", True)
216
+
217
+ if self.get_vision_tower_aux_list() is None:
218
+ vision_tower_aux_list = build_vision_tower_aux_list(model_args)
219
+ if model_args.unfreeze_mm_vision_tower:
220
+ self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
221
+ else:
222
+ self.vision_tower_aux_list = vision_tower_aux_list
223
+ else:
224
+ vision_tower_aux_list = self.vision_tower_aux_list
225
+ for vision_tower_aux in vision_tower_aux_list:
226
+ vision_tower_aux.load_model()
227
+
228
+ self.config.use_mm_proj = True
229
+ self.config.mm_projector_type = getattr(
230
+ model_args, "mm_projector_type", "linear"
231
+ )
232
+ self.config.vision_hidden_size = vision_hidden_size
233
+ self.config.mm_vision_select_layer = mm_vision_select_layer
234
+ self.config.mm_vision_select_feature = mm_vision_select_feature
235
+
236
+ if getattr(self, "mm_projector", None) is None:
237
+
238
+ if self.config.mm_projector_type == "sva":
239
+ self.mm_projector = nn.Sequential(
240
+ nn.Linear(
241
+ vision_hidden_size * num_query_group, self.config.hidden_size
242
+ ),
243
+ nn.GELU(),
244
+ nn.Linear(self.config.hidden_size, self.config.hidden_size),
245
+ )
246
+ for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list):
247
+ setattr(
248
+ self,
249
+ "mm_projector_aux_{}".format(aux_i),
250
+ nn.Sequential(
251
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
252
+ nn.GELU(),
253
+ nn.Linear(vision_hidden_size, vision_hidden_size),
254
+ nn.LayerNorm(vision_hidden_size),
255
+ ),
256
+ )
257
+
258
+ # vision sampler for each group of query as the connector before the LLM
259
+ for query_group_i in range(num_query_group):
260
+ cross_att_token_len_list = [
261
+ int(vision_tower_aux_token_len**0.5)
262
+ // int(query_num_list[query_group_i] ** 0.5)
263
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
264
+ ]
265
+ setattr(
266
+ self,
267
+ "vision_sampler_{}".format(query_group_i),
268
+ VisionTokenSampler(
269
+ vision_hidden_size,
270
+ vision_hidden_size,
271
+ [vision_hidden_size] * len(vision_tower_aux_list),
272
+ cross_att_token_len_list,
273
+ vision_hidden_size,
274
+ connector_depth,
275
+ ),
276
+ )
277
+
278
+ # sampler layers within LLM
279
+ if not connector_only:
280
+ num_of_vision_sampler_layers = (
281
+ self.config.num_of_vision_sampler_layers
282
+ ) = model_args.num_of_vision_sampler_layers
283
+ self.config.start_of_vision_sampler_layers = (
284
+ model_args.start_of_vision_sampler_layers
285
+ )
286
+ self.config.stride_of_vision_sampler_layers = (
287
+ model_args.stride_of_vision_sampler_layers
288
+ )
289
+ cross_att_token_len_list = [
290
+ int(vision_tower_aux_token_len**0.5)
291
+ // int(image_token_len**0.5)
292
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
293
+ ]
294
+ self.vision_sampler_layers = nn.ModuleList(
295
+ [
296
+ VisionTokenSampler(
297
+ self.config.hidden_size,
298
+ vision_hidden_size,
299
+ [vision_hidden_size] * len(vision_tower_aux_list),
300
+ cross_att_token_len_list,
301
+ vision_hidden_size,
302
+ 1,
303
+ )
304
+ for layer_idx in range(0, num_of_vision_sampler_layers)
305
+ ]
306
+ )
307
+ vision_embed_std = 1 / torch.sqrt(
308
+ torch.tensor(vision_hidden_size, dtype=self.dtype)
309
+ )
310
+ self.vision_query = nn.Parameter(
311
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
312
+ * vision_embed_std
313
+ )
314
+
315
+ embed_std = 1 / torch.sqrt(
316
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
317
+ )
318
+ self.image_newline = nn.Parameter(
319
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
320
+ )
321
+
322
+ else:
323
+ self.config.mm_hidden_size = sum(
324
+ [
325
+ vision_tower_aux.hidden_size
326
+ for vision_tower_aux in vision_tower_aux_list
327
+ ]
328
+ )
329
+ self.mm_projector = build_vision_projector(self.config)
330
+ embed_std = 1 / torch.sqrt(
331
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
332
+ )
333
+ self.image_newline = nn.Parameter(
334
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
335
+ )
336
+ else:
337
+ # In case it is frozen by LoRA
338
+ for p in self.mm_projector.parameters():
339
+ p.requires_grad = True
340
+
341
+ if pretrain_mm_mlp_adapter is not None:
342
+ mm_projector_weights = torch.load(
343
+ pretrain_mm_mlp_adapter, map_location="cpu"
344
+ )
345
+
346
+ def get_w(weights, keyword):
347
+ return {
348
+ k.split(keyword + ".")[1]: v
349
+ for k, v in weights.items()
350
+ if keyword + "." in k
351
+ }
352
+
353
+ self.mm_projector.load_state_dict(
354
+ get_w(mm_projector_weights, "mm_projector"), strict=True
355
+ )
356
+
357
+ if self.config.mm_projector_type == "sva":
358
+ for aux_i in range(len(vision_tower_aux_list)):
359
+ getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict(
360
+ get_w(
361
+ mm_projector_weights, "mm_projector_aux_{}".format(aux_i)
362
+ ),
363
+ strict=True,
364
+ )
365
+
366
+ for query_group_i in range(num_query_group):
367
+ getattr(
368
+ self, "vision_sampler_{}".format(query_group_i)
369
+ ).load_state_dict(
370
+ get_w(
371
+ mm_projector_weights,
372
+ "vision_sampler_{}".format(query_group_i),
373
+ ),
374
+ strict=True,
375
+ )
376
+
377
+ if not connector_only:
378
+ self.vision_sampler_layers.load_state_dict(
379
+ get_w(mm_projector_weights, "vision_sampler_layers"),
380
+ strict=True,
381
+ )
382
+ self.vision_query.data = mm_projector_weights["model.vision_query"]
383
+ self.image_newline.data = mm_projector_weights["model.image_newline"]
384
+
385
+
386
+ def unmask_attention_mask(mask, original_size):
387
+ original_w, original_h = original_size
388
+ cur_h, cur_w = mask.shape[1:3]
389
+
390
+ original_aspect_ratio = original_w / original_h
391
+ current_aspect_ratio = cur_w / cur_h
392
+
393
+ if original_aspect_ratio > current_aspect_ratio:
394
+ scale_factor = cur_w / original_w
395
+ new_height = int(original_h * scale_factor)
396
+ padding = (cur_h - new_height) // 2
397
+ if padding > 0:
398
+ mask[:, :padding, :] = 0
399
+ mask[:, -padding:, :] = 0
400
+ return mask
401
+ else:
402
+ scale_factor = cur_h / original_h
403
+ new_width = int(original_w * scale_factor)
404
+ padding = (cur_w - new_width) // 2
405
+ if padding > 0:
406
+ mask[:, :, :padding] = 0
407
+ mask[:, :, -padding:] = 0
408
+ return mask
409
+
410
+
411
+ def unpad_image(tensor, original_size):
412
+ """
413
+ Unpads a PyTorch tensor of a padded and resized image.
414
+
415
+ Args:
416
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
417
+ original_size (tuple): The original size of the image (height, width).
418
+
419
+ Returns:
420
+ torch.Tensor: The unpadded image tensor.
421
+ """
422
+ original_width, original_height = original_size
423
+ current_height, current_width = tensor.shape[1:3]
424
+
425
+ original_aspect_ratio = original_width / original_height
426
+ current_aspect_ratio = current_width / current_height
427
+
428
+ if original_aspect_ratio > current_aspect_ratio:
429
+ scale_factor = current_width / original_width
430
+ new_height = int(original_height * scale_factor)
431
+ padding = (current_height - new_height) // 2
432
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
433
+ # if 0 in unpadded_tensor.shape:
434
+ # print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
435
+ else:
436
+ scale_factor = current_height / original_height
437
+ new_width = int(original_width * scale_factor)
438
+ padding = (current_width - new_width) // 2
439
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
440
+ # if 0 in unpadded_tensor.shape:
441
+ # print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
442
+
443
+ return unpadded_tensor
444
+
445
+
446
+ class CambrianMetaForCausalLM(ABC):
447
+
448
+ @abstractmethod
449
+ def get_model(self):
450
+ pass
451
+
452
+ # def get_vision_tower(self):
453
+ # return self.get_model().get_vision_tower()
454
+
455
+ def get_vision_tower_aux_list(self):
456
+ return self.get_model().get_vision_tower_aux_list()
457
+
458
+ def rearrange_vision_tower_features_train(
459
+ self,
460
+ vision_tower_aux_feature_list,
461
+ vision_tower_aux_attention_masks_list,
462
+ query_side_len,
463
+ ):
464
+ vision_tower_aux_feature_rearranged_list = []
465
+ vision_tower_aux_attention_masks_rearranged_list = []
466
+ bs = vision_tower_aux_feature_list[0].shape[0]
467
+ for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip(
468
+ vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list
469
+ ):
470
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
471
+ assert (aux_height // query_side_len) * query_side_len == aux_height
472
+
473
+ reduce_factor = aux_height // query_side_len
474
+ vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view(
475
+ bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1
476
+ )
477
+ vision_tower_aux_feature_rearranged = (
478
+ vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5)
479
+ .contiguous()
480
+ .flatten(0, 2)
481
+ .flatten(1, 2)
482
+ )
483
+
484
+ vision_tower_aux_attention_masks_rearranged = (
485
+ vision_tower_aux_attention_masks.view(
486
+ bs * query_side_len * query_side_len, reduce_factor * reduce_factor
487
+ )
488
+ )
489
+
490
+ vision_tower_aux_feature_rearranged_list.append(
491
+ vision_tower_aux_feature_rearranged
492
+ )
493
+ vision_tower_aux_attention_masks_rearranged_list.append(
494
+ vision_tower_aux_attention_masks_rearranged
495
+ )
496
+ return (
497
+ vision_tower_aux_feature_rearranged_list,
498
+ vision_tower_aux_attention_masks_rearranged_list,
499
+ )
500
+
501
+ def rearrange_vision_tower_features_inference(
502
+ self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False
503
+ ):
504
+ vision_tower_aux_feature_rearranged_list = []
505
+ vision_tower_aux_attention_masks_rearranged_list = []
506
+ bs = vision_tower_aux_feature_list[0].shape[0]
507
+ for vision_tower_aux_feature in vision_tower_aux_feature_list:
508
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
509
+ assert (aux_height // query_side_len) * query_side_len == aux_height
510
+
511
+ reduce_factor = aux_height // query_side_len
512
+
513
+ vision_tower_aux_feature_rearranged = []
514
+ vision_tower_aux_attention_masks_rearranged = []
515
+ for batch_i in range(bs):
516
+ image_size = image_sizes[batch_i]
517
+ cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i]
518
+
519
+ cur_vision_tower_aux_attention_masks_rearranged = torch.ones(
520
+ (1, aux_height, aux_width),
521
+ dtype=torch.bool,
522
+ device=cur_vision_tower_aux_feature.device,
523
+ )
524
+ cur_vision_tower_aux_feature_rearranged = (
525
+ cur_vision_tower_aux_feature.view(
526
+ 1,
527
+ query_side_len,
528
+ reduce_factor,
529
+ query_side_len,
530
+ reduce_factor,
531
+ -1,
532
+ )
533
+ )
534
+ cur_vision_tower_aux_feature_rearranged = (
535
+ cur_vision_tower_aux_feature_rearranged.permute(
536
+ 0, 1, 3, 2, 4, 5
537
+ ).contiguous()
538
+ )
539
+ if unpad:
540
+ cur_vision_tower_aux_feature_rearranged = unpad_image(
541
+ cur_vision_tower_aux_feature_rearranged, image_size
542
+ )
543
+ cur_vision_tower_aux_feature_rearranged = (
544
+ cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2)
545
+ ) # query_side_len*query_side_len X reduce_factor*reduce_factor X C
546
+
547
+ cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask(
548
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
549
+ )
550
+ cur_vision_tower_aux_attention_masks_rearranged = (
551
+ cur_vision_tower_aux_attention_masks_rearranged.view(
552
+ 1, query_side_len, reduce_factor, query_side_len, reduce_factor
553
+ )
554
+ .permute(0, 1, 3, 2, 4)
555
+ .contiguous()
556
+ )
557
+ if unpad:
558
+ cur_vision_tower_aux_attention_masks_rearranged = unpad_image(
559
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
560
+ )
561
+ cur_vision_tower_aux_attention_masks_rearranged = (
562
+ cur_vision_tower_aux_attention_masks_rearranged.flatten(
563
+ 0, 2
564
+ ).flatten(1, 2)
565
+ )
566
+
567
+ cur_vision_tower_aux_attention_masks_rearranged[
568
+ cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0
569
+ ] = True
570
+
571
+ vision_tower_aux_feature_rearranged.append(
572
+ cur_vision_tower_aux_feature_rearranged
573
+ )
574
+ vision_tower_aux_attention_masks_rearranged.append(
575
+ cur_vision_tower_aux_attention_masks_rearranged
576
+ )
577
+
578
+ vision_tower_aux_feature_rearranged = torch.cat(
579
+ vision_tower_aux_feature_rearranged, 0
580
+ )
581
+ vision_tower_aux_attention_masks_rearranged = torch.cat(
582
+ vision_tower_aux_attention_masks_rearranged, 0
583
+ )
584
+
585
+ vision_tower_aux_feature_rearranged_list.append(
586
+ vision_tower_aux_feature_rearranged
587
+ )
588
+ vision_tower_aux_attention_masks_rearranged_list.append(
589
+ vision_tower_aux_attention_masks_rearranged
590
+ )
591
+
592
+ return (
593
+ vision_tower_aux_feature_rearranged_list,
594
+ vision_tower_aux_attention_masks_rearranged_list,
595
+ )
596
+
597
+ def encode_images(self, image_aux_list, encode_type=None):
598
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
599
+ image_aux_features_list = []
600
+ chunk_size = 64
601
+ if encode_type == "dino":
602
+ image_aux = image_aux_list[-1]
603
+ vision_tower_aux = vision_tower_aux_list[-1]
604
+ if image_aux.shape[0] > chunk_size:
605
+ image_aux_features_chunks = []
606
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
607
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
608
+ chunk = image_aux[start_idx:end_idx]
609
+ image_aux_features_chunk = vision_tower_aux(chunk)
610
+ image_aux_features_chunks.append(image_aux_features_chunk)
611
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
612
+ else:
613
+ image_aux_features = vision_tower_aux(image_aux)
614
+ return image_aux_features
615
+ elif encode_type == "siglip":
616
+ image_aux = image_aux_list[0]
617
+ vision_tower_aux = vision_tower_aux_list[0]
618
+ if image_aux.shape[0] > chunk_size:
619
+ image_aux_features_chunks = []
620
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
621
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
622
+ chunk = image_aux[start_idx:end_idx]
623
+ image_aux_features_chunk = vision_tower_aux(chunk)
624
+ image_aux_features_chunks.append(image_aux_features_chunk)
625
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
626
+ else:
627
+ image_aux_features = vision_tower_aux(image_aux)
628
+ return image_aux_features
629
+ else:
630
+ for image_aux, vision_tower_aux in zip(
631
+ image_aux_list, vision_tower_aux_list
632
+ ):
633
+ if image_aux.shape[0] > chunk_size:
634
+ image_aux_features_chunks = []
635
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
636
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
637
+ chunk = image_aux[start_idx:end_idx]
638
+ image_aux_features_chunk = vision_tower_aux(chunk)
639
+ image_aux_features_chunks.append(image_aux_features_chunk)
640
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
641
+ else:
642
+ image_aux_features = vision_tower_aux(image_aux)
643
+ image_aux_features_list.append(image_aux_features)
644
+ return image_aux_features_list
645
+
646
+ def select_frame(
647
+ self,
648
+ feature_list,
649
+ split_sizes,
650
+ input_ids,
651
+ new_image_aux_list,
652
+ image_sizes,
653
+ window_size=16,
654
+ threshold=0.83,
655
+ ):
656
+ dino_features_batch = torch.split(feature_list, split_sizes, dim=0)
657
+ new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0)
658
+ new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0)
659
+ new_split_sizes = []
660
+ selected_frames_all_0 = []
661
+ selected_frames_all_1 = []
662
+ selected_frames_feature_all = []
663
+ selected_frame_indices_all = []
664
+ for i_batch, frame_features in enumerate(dino_features_batch):
665
+ try:
666
+ if "llama" in self.get_model().config.model_type:
667
+ text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
668
+ else:
669
+ text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
670
+ except:
671
+ text_len = len(input_ids[i_batch])
672
+ original_width, original_height = image_sizes[i_batch]
673
+ if getattr(self.get_model().config, "highres", False):
674
+ token_per_frame = self.get_model().config.lowres_token ** 2
675
+ else:
676
+ token_per_frame = self.get_model().config.image_token_len
677
+ # current_height, current_width = token_per_side, token_per_side
678
+ # original_aspect_ratio = original_width / original_height
679
+ # current_aspect_ratio = current_width / current_height
680
+ # if original_aspect_ratio > current_aspect_ratio:
681
+ # scale_factor = current_width / original_width
682
+ # new_height = int(original_height * scale_factor)
683
+ # padding = math.ceil((current_height - new_height) / 2.0)
684
+ # token_per_frame = (
685
+ # current_height - padding * 2
686
+ # ) * token_per_side + token_per_side
687
+ # else:
688
+ # scale_factor = current_height / original_height
689
+ # new_width = int(original_width * scale_factor)
690
+ # padding = math.ceil((current_width - new_width) / 2.0)
691
+ # token_per_frame = (current_width - padding * 2) * token_per_side + (
692
+ # current_width - padding * 2
693
+ # )
694
+ # token_per_frame = (
695
+ # token_per_side**2 if token_per_frame < 1 else token_per_frame
696
+ # )
697
+ max_num_frames = max(
698
+ 1,
699
+ (
700
+ self.get_model().config.tokenizer_model_max_length
701
+ - text_len
702
+ - getattr(self.get_model().config, "inference_max_length", 16)
703
+ )
704
+ // token_per_frame,
705
+ )
706
+ if len(frame_features) < max_num_frames:
707
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch])
708
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch])
709
+ selected_frames_feature_all.append(frame_features)
710
+ new_split_sizes.append(len(frame_features))
711
+ selected_frame_indices_all.append(torch.arange(len(frame_features)))
712
+ continue
713
+
714
+ num_segments = len(frame_features) // window_size
715
+ if num_segments == 0:
716
+ query_feature = frame_features.flatten(1, 2)
717
+ query_feature = query_feature / torch.norm(
718
+ (query_feature), dim=1, keepdim=True
719
+ )
720
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
721
+ similarities[len(frame_features) // 2] = 0
722
+ indices = torch.where(similarities < threshold)[0]
723
+ selected_frame_indices_all.append(indices)
724
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices])
725
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices])
726
+ selected_frames_feature_all.append(frame_features[indices])
727
+ new_split_sizes.append(len(indices))
728
+ continue
729
+ segments_frames_0 = []
730
+ segments_frames_1 = []
731
+ segments_features = []
732
+ for start_idx in range(0, len(frame_features), window_size):
733
+ end_idx = min(start_idx + window_size, len(frame_features))
734
+ segments_frames_0.append(
735
+ new_image_aux_batch_0[i_batch][start_idx:end_idx]
736
+ )
737
+ segments_frames_1.append(
738
+ new_image_aux_batch_1[i_batch][start_idx:end_idx]
739
+ )
740
+ segments_features.append(frame_features[start_idx:end_idx])
741
+ selected_frames_0 = []
742
+ selected_frames_1 = []
743
+ selected_features = []
744
+ selected_frame_indices = []
745
+ for i, segment in enumerate(segments_features):
746
+ query_feature = segment.flatten(1, 2)
747
+ query_feature = query_feature / torch.norm(
748
+ (query_feature), dim=1, keepdim=True
749
+ )
750
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
751
+ similarities[len(segment) // 2] = 0
752
+ indices = torch.where(similarities < threshold)[0]
753
+ selected_frames_0.append(segments_frames_0[i][indices])
754
+ selected_frames_1.append(segments_frames_1[i][indices])
755
+ selected_features.append(segment[indices])
756
+ selected_frame_indices.extend(indices + i * window_size)
757
+ selected_frames_0 = torch.cat(selected_frames_0, dim=0)
758
+ selected_frames_1 = torch.cat(selected_frames_1, dim=0)
759
+ selected_features = torch.cat(selected_features, dim=0)
760
+ selected_frame_indices = torch.tensor(selected_frame_indices)
761
+ # ablation
762
+ max_num_frames = 400 # hard cap on retained frames to avoid running out of memory
763
+ if len(selected_frames_0) > max_num_frames:
764
+ interval = len(selected_frames_0) / float(max_num_frames)
765
+ indices = [int(interval * i) for i in range(max_num_frames)]
766
+ new_split_sizes.append(len(indices))
767
+ selected_frames_all_0.append(selected_frames_0[indices])
768
+ selected_frames_all_1.append(selected_frames_1[indices])
769
+ selected_frames_feature_all.append(selected_features[indices])
770
+ selected_frame_indices = selected_frame_indices[indices]
771
+ else:
772
+ new_split_sizes.append(len(selected_frames_0))
773
+ selected_frames_all_0.append(selected_frames_0)
774
+ selected_frames_all_1.append(selected_frames_1)
775
+ selected_frames_feature_all.append(selected_features)
776
+ selected_frame_indices_all.append(selected_frame_indices)
777
+ selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0)
778
+ selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0)
779
+ selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0)
780
+ return (
781
+ selected_frames_feature_all,
782
+ new_split_sizes,
783
+ [selected_frames_all_0, selected_frames_all_1],
784
+ selected_frame_indices_all,
785
+ )
786
+
787
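+ # Convert raw image/video tensors into visual token embeddings, splice them into the text
+ # embeddings at each image placeholder, and rebuild labels, attention_mask and position_ids
+ # for the resulting multimodal sequence.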
+ def prepare_inputs_labels_for_multimodal(
788
+ self,
789
+ input_ids,
790
+ position_ids,
791
+ attention_mask,
792
+ past_key_values,
793
+ labels,
794
+ images,
795
+ image_aux_attention_masks_list=None,
796
+ image_sizes=None,
797
+ ):
798
+ # vision_tower = self.get_vision_tower()
799
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
800
+ if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1:
801
+ return (
802
+ input_ids,
803
+ position_ids,
804
+ attention_mask,
805
+ past_key_values,
806
+ None,
807
+ labels,
808
+ None,
809
+ None,
810
+ None,
811
+ None,
812
+ )
813
+
814
+ image_aux_list = images
815
+
816
+ split_sizes = None
817
+
818
+ if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5:
819
+ split_sizes_ori = [
820
+ 1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0]
821
+ ]
822
+ new_image_aux_list = []
823
+ for image_aux in image_aux_list:
824
+ if type(image_aux) is list:
825
+ image_aux = [
826
+ x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux
827
+ ]
828
+ concat_image_aux = torch.cat([image for image in image_aux], dim=0)
829
+ new_image_aux_list.append(concat_image_aux)
830
+ image_aux_features_dino = self.encode_images(
831
+ new_image_aux_list, encode_type="dino"
832
+ )
833
+
834
+ (
835
+ image_aux_features_dino,
836
+ split_sizes,
837
+ new_image_aux_list,
838
+ selected_frame_indices_all,
839
+ ) = self.select_frame(
840
+ image_aux_features_dino,
841
+ split_sizes_ori,
842
+ input_ids,
843
+ new_image_aux_list,
844
+ image_sizes,
845
+ threshold=getattr(self.get_model().config, "dino_threshold", 0.83),
846
+ )
847
+
848
+ image_aux_features_siglip = self.encode_images(
849
+ new_image_aux_list, encode_type="siglip"
850
+ )
851
+ image_aux_features_list = [
852
+ image_aux_features_siglip,
853
+ image_aux_features_dino,
854
+ ]
855
+
856
+ bs = image_aux_features_list[0].shape[0]
857
+ dtype = new_image_aux_list[0].dtype
858
+
859
+ frame_sizes = []
860
+ for i in range(len(image_sizes)):
861
+ for j in range(split_sizes[i]):
862
+ frame_sizes.append(image_sizes[i])
863
+ image_sizes = frame_sizes
864
+ else:
865
+ image_aux_features_list = self.encode_images(image_aux_list)
866
+ bs = image_aux_list[0].shape[0]
867
+ dtype = image_aux_list[0].dtype
868
+
869
+ image_token_len = self.get_model().config.image_token_len
870
+ query_num_list = self.get_model().config.query_num_list
871
+
872
+ final_height = final_width = int(image_token_len**0.5)
873
+
874
+ final_image_features_list = []
875
+ final_image_features_down_list = []
876
+
877
+ # only needed for sva
878
+ vision_tower_aux_feature_list_final = None
879
+ vision_tower_aux_attention_masks_list_final = None
880
+ global_context_feature_final = None
881
+
882
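+ # "sva" projector path: project the features of every auxiliary vision tower, then let
+ # learnable query tokens cross-attend to them (one vision sampler per query group).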
+ if self.get_model().config.mm_projector_type == "sva":
883
+ vision_tower_aux_feature_list = []
884
+ vision_tower_aux_attention_masks_list = []
885
+ # get vision tokens from each vision tower
886
+ for aux_i in range(len(vision_tower_aux_list)):
887
+ image_aux_features = image_aux_features_list[aux_i]
888
+
889
+ image_aux_features = getattr(
890
+ self.get_model(), "mm_projector_aux_{}".format(aux_i)
891
+ )(image_aux_features).to(dtype)
892
+ if aux_i == 0:
893
+ global_context_feature = image_aux_features.mean(1).view(
894
+ bs, 1, 1, -1
895
+ )
896
+
897
+ vision_tower_aux_feature_list.append(image_aux_features)
898
+ input_mix_res = True
899
+ input_high_res = True
900
+ # perform vision sampling for each query group
901
+ for query_group_i, query_num in enumerate(query_num_list):
902
+ query_features_i = (
903
+ self.get_model()
904
+ .vision_query[query_group_i, :]
905
+ .view(1, 1, 1, -1)
906
+ .expand(bs, query_num, -1, -1)
907
+ )
908
+ global_context_feature_i = global_context_feature.expand(
909
+ -1, query_num, 1, -1
910
+ ).flatten(0, 1)
911
+ query_side_len = int(query_num**0.5)
912
+ if IS_XLA_AVAILABLE:
913
+ (
914
+ vision_tower_aux_feature_list_i,
915
+ vision_tower_aux_attention_masks_list_i,
916
+ ) = self.rearrange_vision_tower_features_train(
917
+ vision_tower_aux_feature_list,
918
+ image_aux_attention_masks_list,
919
+ query_side_len,
920
+ )
921
+ else:
922
+ (
923
+ vision_tower_aux_feature_list_i,
924
+ vision_tower_aux_attention_masks_list_i,
925
+ ) = self.rearrange_vision_tower_features_inference(
926
+ vision_tower_aux_feature_list, query_side_len, image_sizes
927
+ )
928
+
929
+ query_features_i = getattr(
930
+ self.get_model(), "vision_sampler_{}".format(query_group_i)
931
+ )(
932
+ query_features_i.flatten(0, 1),
933
+ global_context_feature_i,
934
+ *vision_tower_aux_feature_list_i,
935
+ *vision_tower_aux_attention_masks_list_i,
936
+ )
937
+ query_features_i = query_features_i.view(bs, query_num, -1)
938
+
939
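+ # Choose a resolution mode from the token budget: if every frame fits at full resolution,
+ # skip mixed resolution; if even the low-resolution tokens do not fit, fall back to the
+ # coarse 8x8 interpolation below.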
+ if split_sizes is not None:
940
+ try:
941
+ if "llama" in self.get_model().config.model_type:
942
+ text_len = torch.where(input_ids[0] == 128002)[-1][0]
943
+ else:
944
+ text_len = torch.where(input_ids[0] == 151643)[-1][0]
945
+ except Exception:
946
+ text_len = len(input_ids[0])
947
+ max_visual_len = (
948
+ self.get_model().config.tokenizer_model_max_length
949
+ - text_len
950
+ - getattr(self.get_model().config, "inference_max_length", 16)
951
+ )
952
+ max_num_frames = max(
953
+ 1,
954
+ math.floor(max_visual_len // (final_height * final_width)),
955
+ )
956
+ max_num_frames_low = max(
957
+ 1,
958
+ math.floor(
959
+ max_visual_len
960
+ // (self.get_model().config.lowres_token ** 2)
961
+ ),
962
+ )
963
+ if split_sizes[0] < max_num_frames:
964
+ input_mix_res = False
965
+ elif split_sizes[0] > max_num_frames_low:
966
+ input_mix_res = False
967
+ input_high_res = False
968
+
969
+ # input_mix_res = False # ablation
970
+
971
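+ # In mixed-resolution mode, also keep a low-resolution copy of the query features
+ # (lowres_token x lowres_token) so that individual frames can later be swapped between
+ # low- and high-resolution token counts.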
+ if (getattr(self.config, "highres", False)) and input_mix_res:
972
+ _query_features_i = (
973
+ query_features_i.permute(0, 2, 1)
974
+ .contiguous()
975
+ .view(bs, -1, query_side_len, query_side_len)
976
+ )
977
+ _query_features_i = F.interpolate(
978
+ _query_features_i.float(),
979
+ size=(
980
+ self.get_model().config.lowres_token,
981
+ self.get_model().config.lowres_token,
982
+ ),
983
+ mode="bilinear",
984
+ align_corners=False,
985
+ ).to(dtype=query_features_i.dtype)
986
+ _query_features_i = (
987
+ _query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
988
+ )
989
+ final_image_features_down_list.append(_query_features_i)
990
+
991
+ # interpolate to the final target size
992
+ if query_side_len != final_height:
993
+ query_features_i = (
994
+ query_features_i.permute(0, 2, 1)
995
+ .contiguous()
996
+ .view(bs, -1, query_side_len, query_side_len)
997
+ )
998
+ if input_high_res:
999
+ query_features_i = F.interpolate(
1000
+ query_features_i.float(),
1001
+ size=(final_height, final_width),
1002
+ mode="bilinear",
1003
+ align_corners=False,
1004
+ ).to(dtype=query_features_i.dtype)
1005
+ else:
1006
+ query_features_i = F.interpolate(
1007
+ query_features_i.float(),
1008
+ size=(8, 8),
1009
+ mode="bilinear",
1010
+ align_corners=False,
1011
+ ).to(dtype=query_features_i.dtype)
1012
+ query_features_i = (
1013
+ query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
1014
+ )
1015
+ final_image_features_list.append(query_features_i)
1016
+
1017
+ if IS_XLA_AVAILABLE:
1018
+ (
1019
+ vision_tower_aux_feature_list_final,
1020
+ vision_tower_aux_attention_masks_list_final,
1021
+ ) = self.rearrange_vision_tower_features_train(
1022
+ vision_tower_aux_feature_list,
1023
+ image_aux_attention_masks_list,
1024
+ final_height,
1025
+ )
1026
+ global_context_feature_final = global_context_feature.expand(
1027
+ -1, final_height * final_width, 1, -1
1028
+ ).flatten(0, 1)
1029
+ else:
1030
+ final_image_features_list = image_aux_features_list
1031
+
1032
+ image_features = torch.cat(final_image_features_list, -1)
1033
+ image_features = self.get_model().mm_projector(image_features).to(dtype)
1034
+
1035
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1036
+ image_features_down = torch.cat(final_image_features_down_list, -1)
1037
+ image_features_down = (
1038
+ self.get_model().mm_projector(image_features_down).to(dtype)
1039
+ )
1040
+
1041
+ if IS_XLA_AVAILABLE:
1042
+ image_features = image_features.view(
1043
+ image_features.shape[0], final_height, final_width, -1
1044
+ )
1045
+ image_features = torch.cat(
1046
+ (
1047
+ image_features,
1048
+ self.model.image_newline[None, None, None, :].expand(
1049
+ image_features.shape[0], final_height, 1, -1
1050
+ ),
1051
+ ),
1052
+ dim=2,
1053
+ )
1054
+ image_features = image_features.flatten(1, 2)
1055
+ final_size = [(final_height, final_width)] * bs
1056
+
1057
+ else:
1058
+ image_features = image_features.view(bs, final_height, final_width, -1)
1059
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1060
+ image_features_down = image_features_down.view(
1061
+ bs,
1062
+ self.get_model().config.lowres_token,
1063
+ self.get_model().config.lowres_token,
1064
+ -1,
1065
+ )
1066
+ image_features_unpadded = []
1067
+ image_features_downsample = []
1068
+ final_size = []
1069
+ if self.get_model().config.mm_projector_type == "sva":
1070
+ (
1071
+ vision_tower_aux_feature_list_final,
1072
+ vision_tower_aux_attention_masks_list_final,
1073
+ ) = self.rearrange_vision_tower_features_inference(
1074
+ vision_tower_aux_feature_list, final_height, image_sizes, unpad=True
1075
+ )
1076
+ global_context_feature_final = []
1077
+ for batch_i in range(bs):
1078
+ cur_image_feature = image_features[batch_i]
1079
+ image_size = image_sizes[batch_i]
1080
+
1081
+ cur_image_feature = unpad_image(
1082
+ cur_image_feature.unsqueeze(0), image_size
1083
+ )
1084
+
1085
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1086
+ try: # fix bug for some invalid image
1087
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1088
+ final_size.append((cur_h, cur_w))
1089
+ except Exception:  # unpadding produced an invalid shape; fall back to the padded features
1090
+ # print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True)
1091
+ cur_image_feature = image_features[batch_i].unsqueeze(0)
1092
+ image_size = image_sizes[batch_i]
1093
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1094
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1095
+ final_size.append((cur_h, cur_w))
1096
+
1097
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1098
+ cur_image_feature_down = unpad_image(
1099
+ image_features_down[batch_i].unsqueeze(0),
1100
+ (
1101
+ int(
1102
+ image_size[0]
1103
+ / (
1104
+ image_token_len**0.5
1105
+ / self.get_model().config.lowres_token
1106
+ )
1107
+ ),
1108
+ int(
1109
+ image_size[1]
1110
+ / (
1111
+ image_token_len**0.5
1112
+ / self.get_model().config.lowres_token
1113
+ )
1114
+ ),
1115
+ ),
1116
+ )
1117
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1118
+
1119
+ try: # fix bug for some invalid image
1120
+ cur_image_feature_down = cur_image_feature_down.view(
1121
+ 1, _cur_h, _cur_w, -1
1122
+ )
1123
+ except Exception:
1124
+ print("invalid after unpad", flush=True)
1125
+ cur_image_feature_down = image_features_down[batch_i].unsqueeze(
1126
+ 0
1127
+ )
1128
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1129
+ cur_image_feature_down = cur_image_feature_down.view(
1130
+ 1, _cur_h, _cur_w, -1
1131
+ )
1132
+
1133
+ cur_image_feature_down = torch.cat(
1134
+ (
1135
+ cur_image_feature_down,
1136
+ self.model.image_newline.view(1, 1, 1, -1)
1137
+ .expand(1, _cur_h, 1, -1)
1138
+ .to(cur_image_feature_down.device),
1139
+ ),
1140
+ dim=2,
1141
+ ).flatten(1, 2)
1142
+
1143
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1144
+ frame_pos = (
1145
+ self.get_model()
1146
+ .get_frame_pos(torch.arange(1))
1147
+ .to(cur_image_feature_down.device)
1148
+ .to(cur_image_feature_down.dtype)
1149
+ )
1150
+ cur_image_feature_down += frame_pos
1151
+
1152
+ image_features_downsample.append(cur_image_feature_down.squeeze(0))
1153
+
1154
+ cur_image_feature = torch.cat(
1155
+ (
1156
+ cur_image_feature,
1157
+ self.model.image_newline.view(1, 1, 1, -1)
1158
+ .expand(1, cur_h, 1, -1)
1159
+ .to(cur_image_feature.device),
1160
+ ),
1161
+ dim=2,
1162
+ )
1163
+
1164
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1165
+ frame_pos = (
1166
+ self.get_model()
1167
+ .get_frame_pos(torch.arange(1))
1168
+ .to(cur_image_feature.device)
1169
+ .to(cur_image_feature.dtype)
1170
+ )
1171
+ cur_image_feature += frame_pos
1172
+
1173
+ cur_image_feature = cur_image_feature.flatten(1, 2)
1174
+ image_features_unpadded.append(cur_image_feature.squeeze(0))
1175
+
1176
+ if self.get_model().config.mm_projector_type == "sva":
1177
+ cur_global_context_feature = global_context_feature[batch_i].expand(
1178
+ cur_h * cur_w, 1, -1
1179
+ )
1180
+ global_context_feature_final.append(cur_global_context_feature)
1181
+ if self.get_model().config.mm_projector_type == "sva":
1182
+ global_context_feature_final = torch.cat(
1183
+ global_context_feature_final, 0
1184
+ )
1185
+
1186
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1187
+ image_features = image_features_downsample
1188
+ else:
1189
+ image_features = image_features_unpadded
1190
+
1191
+ # TODO: image start / end is not implemented here to support pretraining.
1192
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
1193
+ self.config, "mm_use_im_start_end", False
1194
+ ):
1195
+ raise NotImplementedError
1196
+
1197
+ split_image_features_unpadded = None
1198
+ frame_split_sizes = None
1199
+
1200
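+ # For video inputs, regroup the per-frame features into one feature sequence per sample
+ # and, when enabled, add frame-index positional embeddings (frame_pos).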
+ if split_sizes is not None:
1201
+ split_image_features = []
1202
+ split_image_features_unpadded = (
1203
+ []
1204
+ if (getattr(self.config, "highres", False)) and input_mix_res
1205
+ else None
1206
+ )
1207
+ start_idx = 0
1208
+ for split_batch_idx, split_size in enumerate(split_sizes):
1209
+ if isinstance(image_features[start_idx : start_idx + split_size], list):
1210
+ if getattr(self.config, "frame_pos", False):
1211
+ frame_feature = torch.cat(
1212
+ image_features[start_idx : start_idx + split_size], dim=0
1213
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1214
+ frame_pos = (
1215
+ self.get_model()
1216
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1217
+ .to(frame_feature.device)
1218
+ .to(frame_feature.dtype)
1219
+ )
1220
+ frame_feature += frame_pos
1221
+ split_image_features.append(
1222
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1223
+ )
1224
+ else:
1225
+ split_image_features.append(
1226
+ torch.cat(
1227
+ image_features[start_idx : start_idx + split_size],
1228
+ dim=0,
1229
+ )
1230
+ )
1231
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1232
+ if getattr(self.config, "frame_pos", False):
1233
+ frame_feature = torch.cat(
1234
+ image_features_unpadded[
1235
+ start_idx : start_idx + split_size
1236
+ ],
1237
+ dim=0,
1238
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1239
+ frame_pos = (
1240
+ self.get_model()
1241
+ .get_frame_pos(
1242
+ selected_frame_indices_all[split_batch_idx]
1243
+ )
1244
+ .to(frame_feature.device)
1245
+ .to(frame_feature.dtype)
1246
+ )
1247
+ frame_feature += frame_pos
1248
+ split_image_features_unpadded.append(
1249
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1250
+ )
1251
+ else:
1252
+ split_image_features_unpadded.append(
1253
+ torch.cat(
1254
+ image_features_unpadded[
1255
+ start_idx : start_idx + split_size
1256
+ ],
1257
+ dim=0,
1258
+ )
1259
+ )
1260
+ else:
1261
+ if getattr(self.config, "frame_pos", False):
1262
+ frame_feature = image_features[
1263
+ start_idx : start_idx + split_size
1264
+ ].reshape(split_size, -1, image_features[0].shape[-1])
1265
+ frame_pos = (
1266
+ self.get_model()
1267
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1268
+ .to(frame_feature.device)
1269
+ .to(frame_feature.dtype)
1270
+ )
1271
+ frame_feature += frame_pos
1272
+ split_image_features.append(
1273
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1274
+ )
1275
+ else:
1276
+ split_image_features.append(
1277
+ image_features[start_idx : start_idx + split_size]
1278
+ )
1279
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1280
+ if getattr(self.config, "frame_pos", False):
1281
+ frame_feature = image_features_unpadded[
1282
+ start_idx : start_idx + split_size
1283
+ ]
1284
+ frame_pos = (
1285
+ self.get_model()
1286
+ .get_frame_pos(
1287
+ selected_frame_indices_all[split_batch_idx]
1288
+ )
1289
+ .to(frame_feature.device)
1290
+ .to(frame_feature.dtype)
1291
+ )
1292
+ frame_feature += frame_pos
1293
+ split_image_features_unpadded.append(
1294
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1295
+ )
1296
+ else:
1297
+ split_image_features_unpadded.append(
1298
+ image_features_unpadded[
1299
+ start_idx : start_idx + split_size
1300
+ ]
1301
+ )
1302
+ start_idx += split_size
1303
+ image_features = split_image_features
1304
+ frame_split_sizes = split_sizes
1305
+
1306
+ _labels = labels
1307
+ _position_ids = position_ids
1308
+ _attention_mask = attention_mask
1309
+ if attention_mask is None:
1310
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1311
+ else:
1312
+ attention_mask = attention_mask.bool()
1313
+ if position_ids is None:
1314
+ position_ids = torch.arange(
1315
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
1316
+ )
1317
+ if labels is None:
1318
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1319
+
1320
+ # remove the padding using attention_mask -- FIXME
1321
+ _input_ids = input_ids
1322
+
1323
+ attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
1324
+
1325
+ input_ids = [
1326
+ cur_input_ids[cur_attention_mask]
1327
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
1328
+ ]
1329
+ labels = [
1330
+ cur_labels[cur_attention_mask]
1331
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
1332
+ ]
1333
+
1334
+ new_input_embeds = []
1335
+ new_labels = []
1336
+ image_token_indices_batch = []
1337
+ cur_image_idx = 0
1338
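+ # Splice the visual features into the text embeddings at every IMAGE_TOKEN_INDEX
+ # placeholder and mask the visual positions out of the labels with IGNORE_INDEX.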
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1339
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1340
+ if num_images == 0:
1341
+ cur_image_features = image_features[cur_image_idx]
1342
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1343
+ cur_input_embeds = torch.cat(
1344
+ [cur_input_embeds_1, cur_image_features[0:0]], dim=0
1345
+ )
1346
+ new_input_embeds.append(cur_input_embeds)
1347
+ new_labels.append(labels[batch_idx])
1348
+ cur_image_idx += 1
1349
+ continue
1350
+
1351
+ image_token_indices = (
1352
+ [-1]
1353
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
1354
+ + [cur_input_ids.shape[0]]
1355
+ )
1356
+ image_token_indices_batch.append(
1357
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0]
1358
+ )
1359
+ cur_input_ids_noim = []
1360
+ cur_labels = labels[batch_idx]
1361
+ cur_labels_noim = []
1362
+ for i in range(len(image_token_indices) - 1):
1363
+ cur_input_ids_noim.append(
1364
+ cur_input_ids[
1365
+ image_token_indices[i] + 1 : image_token_indices[i + 1]
1366
+ ]
1367
+ )
1368
+ cur_labels_noim.append(
1369
+ cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
1370
+ )
1371
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1372
+ cur_input_embeds = self.get_model().embed_tokens(
1373
+ torch.cat(cur_input_ids_noim)
1374
+ )
1375
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1376
+ cur_new_input_embeds = []
1377
+ cur_new_labels = []
1378
+
1379
+ text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
1380
+ visual_len = len(image_features[cur_image_idx])
1381
+ max_visual_len = (
1382
+ self.get_model().config.tokenizer_model_max_length
1383
+ - getattr(self.get_model().config, "inference_max_length", 16)
1384
+ - text_len
1385
+ )
1386
+ mix_token = False
1387
+
1388
+ # ablation mix
1389
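+ # If there is spare token budget, upgrade the frames most similar to the trailing text
+ # embeddings to their high-resolution (unpadded) features while keeping the rest low-res.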
+ if (
1390
+ input_mix_res
1391
+ and (
1392
+ self.get_model().config.image_token_len
1393
+ > getattr(self.get_model().config, "lowres_token", 8) ** 2
1394
+ )
1395
+ and frame_split_sizes is not None
1396
+ and getattr(self.config, "highres", False)
1397
+ ):
1398
+ if max_visual_len > visual_len:
1399
+ visual_emb = image_features[cur_image_idx]
1400
+ text_emb = cur_input_embeds_no_im[-1]
1401
+ highres_num = math.floor(
1402
+ (max_visual_len - visual_len)
1403
+ / (
1404
+ split_image_features_unpadded[cur_image_idx].shape[0]
1405
+ // frame_split_sizes[cur_image_idx]
1406
+ - visual_emb.shape[0] // frame_split_sizes[cur_image_idx]
1407
+ )
1408
+ )
1409
+ if highres_num >= 1:
1410
+ mix_token = True
1411
+ sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean(
1412
+ dim=-1
1413
+ )
1414
+ sim_frame = sim.reshape(
1415
+ frame_split_sizes[cur_image_idx], -1
1416
+ ).mean(dim=-1)
1417
+ highres_num = min(highres_num, sim_frame.shape[0])
1418
+ top_values, top_indices = torch.topk(sim_frame, highres_num)
1419
+ if len(top_indices) > 0:
1420
+ sorted_indices = torch.sort(top_indices)[1]
1421
+ top_indices = top_indices[sorted_indices]
1422
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1423
+ frame_split_sizes[cur_image_idx],
1424
+ -1,
1425
+ image_features[cur_image_idx].shape[-1],
1426
+ )
1427
+ visual_emb_frame_highres = split_image_features_unpadded[
1428
+ cur_image_idx
1429
+ ].reshape(
1430
+ frame_split_sizes[cur_image_idx],
1431
+ -1,
1432
+ split_image_features_unpadded[cur_image_idx].shape[-1],
1433
+ )
1434
+ current_point = 0
1435
+ mix_visual_emb_frame = []
1436
+ for frame_i in range(len(visual_emb_frame)):
1437
+ if current_point > len(top_indices) - 1:
1438
+ mix_visual_emb_frame.append(
1439
+ visual_emb_frame[frame_i]
1440
+ )
1441
+ continue
1442
+ if frame_i == top_indices[current_point]:
1443
+ mix_visual_emb_frame.append(
1444
+ visual_emb_frame_highres[frame_i]
1445
+ )
1446
+ current_point += 1
1447
+ else:
1448
+ mix_visual_emb_frame.append(
1449
+ visual_emb_frame[frame_i]
1450
+ )
1451
+ image_features[cur_image_idx] = torch.cat(
1452
+ mix_visual_emb_frame, dim=0
1453
+ )
1454
+ # ablation drop
1455
+
1456
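+ # If the visual tokens still exceed the budget, prune spatial tokens: within each chunk of
+ # 8 frames, keep frame 0 in full and keep only tokens of later frames whose cosine
+ # similarity to frame 0 is below drop_threshold, then truncate evenly if still too long.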
+ if (
1457
+ max_visual_len < visual_len
1458
+ and frame_split_sizes is not None
1459
+ and not mix_token
1460
+ ):
1461
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1462
+ frame_split_sizes[cur_image_idx],
1463
+ -1,
1464
+ image_features[cur_image_idx].shape[-1],
1465
+ )
1466
+
1467
+ sim = F.cosine_similarity(
1468
+ visual_emb_frame[:-1],
1469
+ visual_emb_frame[1:],
1470
+ dim=-1,
1471
+ )
1472
+
1473
+ new_visual_emb_frames = []
1474
+ for start_idx in range(0, len(visual_emb_frame), 8):
1475
+ end_idx = min(start_idx + 8, len(visual_emb_frame))
1476
+ chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C
1477
+ if len(chunk_feature) == 1:
1478
+ new_visual_emb_frames.append(chunk_feature[0])
1479
+ continue
1480
+ sim = F.cosine_similarity(
1481
+ chunk_feature[0]
1482
+ .unsqueeze(0)
1483
+ .repeat_interleave(len(chunk_feature[1:]), dim=0),
1484
+ chunk_feature[1:],
1485
+ dim=-1,
1486
+ )
1487
+ new_visual_emb_frame = torch.cat(
1488
+ [
1489
+ chunk_feature[0],
1490
+ chunk_feature[1:].flatten(0, 1)[
1491
+ sim.flatten(0, 1)
1492
+ < getattr(
1493
+ self.get_model().config, "drop_threshold", 0.7
1494
+ )
1495
+ ],
1496
+ ],
1497
+ dim=0,
1498
+ )
1499
+ new_visual_emb_frames.append(new_visual_emb_frame)
1500
+
1501
+ reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames])
1502
+
1503
+ if reduced_visual_len > max_visual_len:
1504
+ force_remove = math.ceil(
1505
+ (reduced_visual_len - max_visual_len)
1506
+ / len(new_visual_emb_frames)
1507
+ )
1508
+ for chunk_i in range(len(new_visual_emb_frames)):
1509
+ new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][
1510
+ :-force_remove
1511
+ ]
1512
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1513
+ else:
1514
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1515
+
1516
+ image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len]
1517
+
1518
+ for i in range(num_images + 1):
1519
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1520
+ cur_new_labels.append(cur_labels_noim[i])
1521
+ if i < num_images:
1522
+ cur_image_features = image_features[cur_image_idx]
1523
+ cur_image_idx += 1
1524
+ cur_new_input_embeds.append(cur_image_features)
1525
+ cur_new_labels.append(
1526
+ torch.full(
1527
+ (cur_image_features.shape[0],),
1528
+ IGNORE_INDEX,
1529
+ device=cur_labels.device,
1530
+ dtype=cur_labels.dtype,
1531
+ )
1532
+ )
1533
+
1534
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1535
+
1536
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1537
+ cur_new_labels = torch.cat(cur_new_labels)
1538
+
1539
+ new_input_embeds.append(cur_new_input_embeds)
1540
+ new_labels.append(cur_new_labels)
1541
+
1542
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1543
+ tokenizer_model_max_length = getattr(
1544
+ self.config, "tokenizer_model_max_length", None
1545
+ )
1546
+ if tokenizer_model_max_length is not None:
1547
+ new_input_embeds = [
1548
+ x[:tokenizer_model_max_length] for x in new_input_embeds
1549
+ ]
1550
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1551
+
1552
+ # Combine them
1553
+ max_len = max(x.shape[0] for x in new_input_embeds)
1554
+ batch_size = len(new_input_embeds)
1555
+
1556
+ new_input_embeds_padded = []
1557
+ new_labels_padded = torch.full(
1558
+ (batch_size, max_len),
1559
+ IGNORE_INDEX,
1560
+ dtype=new_labels[0].dtype,
1561
+ device=new_labels[0].device,
1562
+ )
1563
+ attention_mask = torch.zeros(
1564
+ (batch_size, max_len),
1565
+ dtype=attention_mask.dtype,
1566
+ device=attention_mask.device,
1567
+ )
1568
+ position_ids = torch.zeros(
1569
+ (batch_size, max_len),
1570
+ dtype=position_ids.dtype,
1571
+ device=position_ids.device,
1572
+ )
1573
+
1574
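+ # Pad every sample to the longest sequence in the batch (on the configured padding side)
+ # and rebuild attention_mask and position_ids to match.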
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
1575
+ zip(new_input_embeds, new_labels)
1576
+ ):
1577
+ cur_len = cur_new_embed.shape[0]
1578
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
1579
+ new_input_embeds_padded.append(
1580
+ torch.cat(
1581
+ (
1582
+ torch.zeros(
1583
+ (max_len - cur_len, cur_new_embed.shape[1]),
1584
+ dtype=cur_new_embed.dtype,
1585
+ device=cur_new_embed.device,
1586
+ ),
1587
+ cur_new_embed,
1588
+ ),
1589
+ dim=0,
1590
+ )
1591
+ )
1592
+ if cur_len > 0:
1593
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1594
+ attention_mask[i, -cur_len:] = True
1595
+ position_ids[i, -cur_len:] = torch.arange(
1596
+ 0,
1597
+ cur_len,
1598
+ dtype=position_ids.dtype,
1599
+ device=position_ids.device,
1600
+ )
1601
+ else:
1602
+ new_input_embeds_padded.append(
1603
+ torch.cat(
1604
+ (
1605
+ cur_new_embed,
1606
+ torch.zeros(
1607
+ (max_len - cur_len, cur_new_embed.shape[1]),
1608
+ dtype=cur_new_embed.dtype,
1609
+ device=cur_new_embed.device,
1610
+ ),
1611
+ ),
1612
+ dim=0,
1613
+ )
1614
+ )
1615
+ if cur_len > 0:
1616
+ new_labels_padded[i, :cur_len] = cur_new_labels
1617
+ attention_mask[i, :cur_len] = True
1618
+ position_ids[i, :cur_len] = torch.arange(
1619
+ 0,
1620
+ cur_len,
1621
+ dtype=position_ids.dtype,
1622
+ device=position_ids.device,
1623
+ )
1624
+
1625
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1626
+
1627
+ if _labels is None:
1628
+ new_labels = None
1629
+ else:
1630
+ new_labels = new_labels_padded
1631
+
1632
+ if _attention_mask is None:
1633
+ attention_mask = None
1634
+ else:
1635
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1636
+
1637
+ if _position_ids is None:
1638
+ position_ids = None
1639
+
1640
+ return (
1641
+ None,
1642
+ position_ids,
1643
+ attention_mask,
1644
+ past_key_values,
1645
+ new_input_embeds,
1646
+ new_labels,
1647
+ vision_tower_aux_feature_list_final,
1648
+ vision_tower_aux_attention_masks_list_final,
1649
+ final_size,
1650
+ global_context_feature_final,
1651
+ )
1652
+
1653
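+ # Add the image special tokens to the tokenizer, resize the embedding matrices, and
+ # initialize the new rows to the mean of the existing embeddings (optionally loading them
+ # from a pretrained mm projector checkpoint).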
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
1654
+ if model_args.mm_use_im_patch_token:
1655
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
1656
+ self.resize_token_embeddings(len(tokenizer))
1657
+
1658
+ if model_args.mm_use_im_start_end:
1659
+ num_new_tokens = tokenizer.add_tokens(
1660
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
1661
+ )
1662
+ self.resize_token_embeddings(len(tokenizer))
1663
+
1664
+ if num_new_tokens > 0:
1665
+ input_embeddings = self.get_input_embeddings().weight.data
1666
+ output_embeddings = self.get_output_embeddings().weight.data
1667
+
1668
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
1669
+ dim=0, keepdim=True
1670
+ )
1671
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
1672
+ dim=0, keepdim=True
1673
+ )
1674
+
1675
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
1676
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
1677
+
1678
+ if model_args.tune_mm_mlp_adapter:
1679
+ for p in self.get_input_embeddings().parameters():
1680
+ p.requires_grad = True
1681
+ for p in self.get_output_embeddings().parameters():
1682
+ p.requires_grad = False
1683
+
1684
+ if model_args.pretrain_mm_mlp_adapter:
1685
+ mm_projector_weights = torch.load(
1686
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
1687
+ )
1688
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
1689
+ assert num_new_tokens == 2
1690
+ if input_embeddings.shape == embed_tokens_weight.shape:
1691
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
1692
+ -num_new_tokens:
1693
+ ]
1694
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
1695
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
1696
+ else:
1697
+ raise ValueError(
1698
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
1699
+ )
1700
+ elif model_args.mm_use_im_patch_token:
1701
+ if model_args.tune_mm_mlp_adapter:
1702
+ for p in self.get_input_embeddings().parameters():
1703
+ p.requires_grad = False
1704
+ for p in self.get_output_embeddings().parameters():
1705
+ p.requires_grad = False
longvu/consolidate.py ADDED
@@ -0,0 +1,33 @@
1
+ # pyre-unsafe
2
+ """
3
+ Usage:
4
+ python3 -m longvu.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
5
+ """
6
+
7
+ import argparse
8
+
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from longvu import * # noqa
12
+ from .utils import auto_upgrade
13
+
14
+
15
+ def consolidate_ckpt(src_path, dst_path):
16
+ print("Loading model")
17
+ auto_upgrade(src_path)
18
+ src_model = AutoModelForCausalLM.from_pretrained(
19
+ src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
20
+ )
21
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
22
+ src_model.save_pretrained(dst_path)
23
+ src_tokenizer.save_pretrained(dst_path)
24
+
25
+
26
+ if __name__ == "__main__":
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--src", type=str, required=True)
29
+ parser.add_argument("--dst", type=str, required=True)
30
+
31
+ args = parser.parse_args()
32
+
33
+ consolidate_ckpt(args.src, args.dst)
longvu/constants.py ADDED
@@ -0,0 +1,13 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
longvu/conversation.py ADDED
@@ -0,0 +1,606 @@
1
+ import base64
2
+ import dataclasses
3
+ from enum import auto, Enum
4
+ from io import BytesIO
5
+ from typing import Any, Dict, List, Tuple, Union
6
+
7
+ from longvu.file_io import PathManager
8
+
9
+ from PIL import Image
10
+ from transformers import AutoTokenizer
11
+
12
+
13
+ class SeparatorStyle(Enum):
14
+ """Different separator style."""
15
+
16
+ SINGLE = auto()
17
+ TWO = auto()
18
+ MPT = auto()
19
+ PLAIN = auto()
20
+ LLAMA_2 = auto()
21
+ LLAMA_3 = auto()
22
+ LLAMA_3_1 = auto()
23
+ LLAMA_3_2 = auto()
24
+ QWEN = auto()
25
+ CHATML = auto()
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class Conversation:
30
+ """A class that keeps all conversation history."""
31
+
32
+ system: str
33
+ roles: List[str]
34
+ messages: List[List[str]]
35
+ offset: int
36
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
37
+ sep: str = "###"
38
+ # pyre-fixme[8]: Attribute has type `str`; used as `None`.
39
+ sep2: str = None
40
+ version: str = "Unknown"
41
+
42
+ tokenizer: Any = None
43
+ # Stop criteria (the default one is EOS token)
44
+ # pyre-fixme[8]: Attribute has type `Union[List[str], str]`; used as `None`.
45
+ stop_str: Union[str, List[str]] = None
46
+ # Stops generation if meeting any token in this list
47
+ # pyre-fixme[8]: Attribute has type `List[int]`; used as `None`.
48
+ stop_token_ids: List[int] = None
49
+
50
+ skip_next: bool = False
51
+
52
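+ # Render the conversation into a single prompt string according to sep_style (or via the
+ # tokenizer's chat template for the LLaMA-3 styles).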
+ def get_prompt(self):
53
+ messages = self.messages
54
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
55
+ messages = self.messages.copy()
56
+ init_role, init_msg = messages[0].copy()
57
+ init_msg = init_msg[0].replace("<image>", "").strip()
58
+ if "mmtag" in self.version:
59
+ messages[0] = (init_role, init_msg)
60
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
61
+ messages.insert(1, (self.roles[1], "Received."))
62
+ else:
63
+ messages[0] = (init_role, "<image>\n" + init_msg)
64
+
65
+ if self.sep_style == SeparatorStyle.SINGLE:
66
+ ret = self.system + self.sep
67
+ for role, message in messages:
68
+ if message:
69
+ if type(message) is tuple:
70
+ message, _, _ = message
71
+ ret += role + ": " + message + self.sep
72
+ else:
73
+ ret += role + ":"
74
+ elif self.sep_style == SeparatorStyle.TWO:
75
+ seps = [self.sep, self.sep2]
76
+ ret = self.system + seps[0]
77
+ for i, (role, message) in enumerate(messages):
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + ": " + message + seps[i % 2]
82
+ else:
83
+ ret += role + ":"
84
+
85
+ elif self.sep_style == SeparatorStyle.CHATML:
86
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
87
+ for role, message in messages:
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, images, _ = message
91
+ message = "<image>" * len(images) + message
92
+ ret += role + "\n" + message + self.sep + "\n"
93
+ else:
94
+ ret += role + "\n"
95
+ return ret
96
+
97
+ elif self.sep_style == SeparatorStyle.MPT:
98
+ ret = self.system + self.sep
99
+ for role, message in messages:
100
+ if message:
101
+ if type(message) is tuple:
102
+ message, _, _ = message
103
+ ret += role + message + self.sep
104
+ else:
105
+ ret += role
106
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
107
+ wrap_sys = lambda msg: (
108
+ f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
109
+ )
110
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
111
+ ret = ""
112
+
113
+ for i, (role, message) in enumerate(messages):
114
+ if i == 0:
115
+ assert message, "first message should not be none"
116
+ assert role == self.roles[0], "first message should come from user"
117
+ if message:
118
+ if type(message) is tuple:
119
+ message, _, _ = message
120
+ if i == 0:
121
+ message = wrap_sys(self.system) + message
122
+ if i % 2 == 0:
123
+ message = wrap_inst(message)
124
+ ret += self.sep + message
125
+ else:
126
+ ret += " " + message + " " + self.sep2
127
+ else:
128
+ ret += ""
129
+ ret = ret.lstrip(self.sep)
130
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
131
+ if self.tokenizer is None:
132
+ self.tokenizer = AutoTokenizer.from_pretrained(
133
+ PathManager.get_local_path(
134
+ "manifold://xr_core_ai_asl_llm/tree/users/shenx/models/Cambrian-Llama3_1-8b-t576/"
135
+ )
136
+ )
137
+ chat_template_messages = [{"role": "system", "content": self.system}]
138
+ for role, message in messages:
139
+ if message:
140
+ if type(message) is tuple:
141
+ message, images = message
142
+ message = "<image>" * len(images) + message
143
+ chat_template_messages.append({"role": role, "content": message})
144
+
145
+ # print("chat", chat_template_messages, flush=True)
146
+ return self.tokenizer.apply_chat_template(
147
+ chat_template_messages, tokenize=False, add_generation_prompt=True
148
+ )
149
+ elif self.sep_style == SeparatorStyle.LLAMA_3_1:
150
+ if self.tokenizer is None:
151
+ self.tokenizer = AutoTokenizer.from_pretrained(
152
+ PathManager.get_local_path(
153
+ "manifold://xr_core_ai_asl_llm/tree/users/shenx/models/Cambrian-Llama3_1-8b-t576/"
154
+ )
155
+ )
156
+ chat_template_messages = [{"role": "system", "content": self.system}]
157
+ for role, message in messages:
158
+ if message:
159
+ if type(message) is tuple:
160
+ message, images = message
161
+ message = "<image>" * len(images) + message
162
+ chat_template_messages.append({"role": role, "content": message})
163
+
164
+ return self.tokenizer.apply_chat_template(
165
+ chat_template_messages, tokenize=False, add_generation_prompt=False
166
+ )
167
+ elif (
168
+ # self.sep_style == SeparatorStyle.LLAMA_3 or
169
+ self.sep_style
170
+ == SeparatorStyle.LLAMA_3_2
171
+ ):
172
+ wrap_sys = lambda msg: (
173
+ f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>"
174
+ if len(msg) > 0
175
+ else msg
176
+ )
177
+ wrap_inst_user = (
178
+ lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
179
+ )
180
+ wrap_inst_assistant = (
181
+ lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
182
+ )
183
+ ret = ""
184
+
185
+ for i, (role, message) in enumerate(messages):
186
+ if i == 0:
187
+ assert message, "first message should not be none"
188
+ assert role == self.roles[0], "first message should come from user"
189
+ if message:
190
+ if type(message) is tuple:
191
+ message, _, _ = message
192
+ if i == 0:
193
+ ret += wrap_sys(self.system)
194
+
195
+ if i % 2 == 0:
196
+ message = wrap_inst_user(message)
197
+ ret += message
198
+ else:
199
+ message = wrap_inst_assistant(message)
200
+ ret += message
201
+ else:
202
+ ret += ""
203
+ ret += "<|start_header_id|>assistant<|end_header_id|>"
204
+ elif self.sep_style == SeparatorStyle.PLAIN:
205
+ seps = [self.sep, self.sep2]
206
+ ret = self.system
207
+ for i, (role, message) in enumerate(messages):
208
+ if message:
209
+ if type(message) is tuple:
210
+ message, _, _ = message
211
+ ret += message + seps[i % 2]
212
+ else:
213
+ ret += ""
214
+ else:
215
+ raise ValueError(f"Invalid style: {self.sep_style}")
216
+
217
+ return ret
218
+
219
+ def append_message(self, role, message):
220
+ self.messages.append([role, message])
221
+
222
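+ # Pad or resize the image according to image_process_mode, clamp its longer side to
+ # max_len, and return either a PIL image or a base64-encoded string.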
+ def process_image(
223
+ self,
224
+ image,
225
+ image_process_mode,
226
+ return_pil=False,
227
+ image_format="PNG",
228
+ max_len=1344,
229
+ min_len=672,
230
+ ):
231
+ if image_process_mode == "Pad":
232
+
233
+ def expand2square(pil_img, background_color=(122, 116, 104)):
234
+ width, height = pil_img.size
235
+ if width == height:
236
+ return pil_img
237
+ elif width > height:
238
+ result = Image.new(pil_img.mode, (width, width), background_color)
239
+ result.paste(pil_img, (0, (width - height) // 2))
240
+ return result
241
+ else:
242
+ result = Image.new(pil_img.mode, (height, height), background_color)
243
+ result.paste(pil_img, ((height - width) // 2, 0))
244
+ return result
245
+
246
+ image = expand2square(image)
247
+ elif image_process_mode in ["Default", "Crop"]:
248
+ pass
249
+ elif image_process_mode == "Resize":
250
+ image = image.resize((336, 336))
251
+ else:
252
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
253
+ if max(image.size) > max_len:
254
+ max_hw, min_hw = max(image.size), min(image.size)
255
+ aspect_ratio = max_hw / min_hw
256
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
257
+ longest_edge = int(shortest_edge * aspect_ratio)
258
+ W, H = image.size
259
+ if H > W:
260
+ H, W = longest_edge, shortest_edge
261
+ else:
262
+ H, W = shortest_edge, longest_edge
263
+ image = image.resize((W, H))
264
+ if return_pil:
265
+ return image
266
+ else:
267
+ buffered = BytesIO()
268
+ image.save(buffered, format=image_format)
269
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
270
+ return img_b64_str
271
+
272
+ def get_images(self, return_pil=False):
273
+ images = []
274
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
275
+ if i % 2 == 0:
276
+ if type(msg) is tuple:
277
+ msg, image, image_process_mode = msg
278
+ image = self.process_image(
279
+ image, image_process_mode, return_pil=return_pil
280
+ )
281
+ images.append(image)
282
+ return images
283
+
284
+ def to_gradio_chatbot(self):
285
+ ret = []
286
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
287
+ if i % 2 == 0:
288
+ if type(msg) is tuple:
289
+ msg, image, image_process_mode = msg
290
+ img_b64_str = self.process_image(
291
+ image, "Default", return_pil=False, image_format="JPEG"
292
+ )
293
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
294
+ msg = img_str + msg.replace("<image>", "").strip()
295
+ ret.append([msg, None])
296
+ else:
297
+ ret.append([msg, None])
298
+ else:
299
+ ret[-1][-1] = msg
300
+ return ret
301
+
302
+ def copy(self):
303
+ return Conversation(
304
+ system=self.system,
305
+ roles=self.roles,
306
+ messages=[[x, y] for x, y in self.messages],
307
+ offset=self.offset,
308
+ sep_style=self.sep_style,
309
+ sep=self.sep,
310
+ sep2=self.sep2,
311
+ version=self.version,
312
+ )
313
+
314
+ def dict(self):
315
+ if len(self.get_images()) > 0:
316
+ return {
317
+ "system": self.system,
318
+ "roles": self.roles,
319
+ "messages": [
320
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
321
+ ],
322
+ "offset": self.offset,
323
+ "sep": self.sep,
324
+ "sep2": self.sep2,
325
+ }
326
+ return {
327
+ "system": self.system,
328
+ "roles": self.roles,
329
+ "messages": self.messages,
330
+ "offset": self.offset,
331
+ "sep": self.sep,
332
+ "sep2": self.sep2,
333
+ }
334
+
335
+
336
+ conv_vicuna_v0 = Conversation(
337
+ system="A chat between a curious human and an artificial intelligence assistant. "
338
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
339
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
340
+ roles=("Human", "Assistant"),
341
+ # pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got
342
+ # `Tuple[Tuple[str, str], Tuple[str, str]]`.
343
+ messages=(
344
+ (
345
+ "Human",
346
+ "What are the key differences between renewable and non-renewable energy sources?",
347
+ ),
348
+ (
349
+ "Assistant",
350
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
351
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
352
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
353
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
354
+ "renewable and non-renewable energy sources:\n"
355
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
356
+ "energy sources are finite and will eventually run out.\n"
357
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
358
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
359
+ "and other negative effects.\n"
360
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
361
+ "have lower operational costs than non-renewable sources.\n"
362
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
363
+ "locations than non-renewable sources.\n"
364
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
365
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
366
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
367
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
368
+ ),
369
+ ),
370
+ offset=2,
371
+ sep_style=SeparatorStyle.SINGLE,
372
+ sep="###",
373
+ )
374
+
375
+ conv_vicuna_v1 = Conversation(
376
+ system="A chat between a curious user and an artificial intelligence assistant. "
377
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
378
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
379
+ roles=("USER", "ASSISTANT"),
380
+ version="v1",
381
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
382
+ messages=(),
383
+ offset=0,
384
+ sep_style=SeparatorStyle.TWO,
385
+ sep=" ",
386
+ sep2="</s>",
387
+ )
388
+
389
+ conv_llama_2 = Conversation(
390
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
391
+
392
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
393
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
394
+ roles=("USER", "ASSISTANT"),
395
+ version="llama_v2",
396
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
397
+ messages=(),
398
+ offset=0,
399
+ sep_style=SeparatorStyle.LLAMA_2,
400
+ sep="<s>",
401
+ sep2="</s>",
402
+ )
403
+
404
+ conv_llava_llama_2 = Conversation(
405
+ system="You are a helpful language and vision assistant. "
406
+ "You are able to understand the visual content that the user provides, "
407
+ "and assist the user with a variety of tasks using natural language.",
408
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
409
+ roles=("USER", "ASSISTANT"),
410
+ version="llama_v2",
411
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
412
+ messages=(),
413
+ offset=0,
414
+ sep_style=SeparatorStyle.LLAMA_2,
415
+ sep="<s>",
416
+ sep2="</s>",
417
+ )
418
+
419
+ conv_mpt = Conversation(
420
+ system="""<|im_start|>system
421
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
422
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
423
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
424
+ version="mpt",
425
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
426
+ messages=(),
427
+ offset=0,
428
+ sep_style=SeparatorStyle.MPT,
429
+ sep="<|im_end|>",
430
+ )
431
+
432
+ conv_llava_plain = Conversation(
433
+ system="",
434
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
435
+ roles=("", ""),
436
+ # pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
437
+ messages=(),
438
+ offset=0,
439
+ sep_style=SeparatorStyle.PLAIN,
440
+ sep="\n",
441
+ version="plain",
442
+ )
443
+
444
+ conv_llava_v0 = Conversation(
445
+ system="A chat between a curious human and an artificial intelligence assistant. "
446
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
447
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
448
+ roles=("Human", "Assistant"),
449
+ # pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
450
+ messages=(),
451
+ offset=0,
452
+ sep_style=SeparatorStyle.SINGLE,
453
+ sep="###",
454
+ )
455
+
456
+ conv_llava_v0_mmtag = Conversation(
457
+ system="A chat between a curious user and an artificial intelligence assistant. "
458
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
459
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
460
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
461
+ roles=("Human", "Assistant"),
462
+ # pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
463
+ messages=(),
464
+ offset=0,
465
+ sep_style=SeparatorStyle.SINGLE,
466
+ sep="###",
467
+ version="v0_mmtag",
468
+ )
469
+
470
+ conv_llava_v1 = Conversation(
471
+ system="A chat between a curious human and an artificial intelligence assistant. "
472
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
473
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
474
+ roles=("USER", "ASSISTANT"),
475
+ version="v1",
476
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
477
+ messages=(),
478
+ offset=0,
479
+ sep_style=SeparatorStyle.TWO,
480
+ sep=" ",
481
+ sep2="</s>",
482
+ )
483
+
484
+ conv_llava_v1_mmtag = Conversation(
485
+ system="A chat between a curious user and an artificial intelligence assistant. "
486
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
487
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
488
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
489
+ roles=("USER", "ASSISTANT"),
490
+ # pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
491
+ messages=(),
492
+ offset=0,
493
+ sep_style=SeparatorStyle.TWO,
494
+ sep=" ",
495
+ sep2="</s>",
496
+ version="v1_mmtag",
497
+ )
498
+
499
+ conv_mistral_instruct = Conversation(
500
+ system="",
501
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
502
+ roles=("USER", "ASSISTANT"),
503
+ version="llama_v2",
504
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
505
+ messages=(),
506
+ offset=0,
507
+ sep_style=SeparatorStyle.LLAMA_2,
508
+ sep="",
509
+ sep2="</s>",
510
+ )
511
+
512
+ conv_chatml_direct = Conversation(
513
+ system="""<|im_start|>system
514
+ Answer the questions.""",
515
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
516
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
517
+ version="mpt",
518
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
519
+ messages=(),
520
+ offset=0,
521
+ sep_style=SeparatorStyle.MPT,
522
+ sep="<|im_end|>",
523
+ )
524
+
525
+ # llama3_tokenizer = AutoTokenizer.from_pretrained(
526
+ # PathManager.get_local_path(
527
+ # "./checkpoint/"
528
+ # )
529
+ # )
530
+
531
+ conv_llama3 = Conversation(
532
+ system="""As a multimodal AI, you have the ability to process and analyze images. Whenever an image is present in the conversation, very carefully examine it and consider its content when formulating your response. You should give concise responses to very simple questions, but provide thorough responses to more complex and open-ended questions.""",
533
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
534
+ roles=("user", "assistant"),
535
+ version="llama3",
536
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
537
+ messages=(),
538
+ offset=0,
539
+ sep_style=SeparatorStyle.LLAMA_3,
540
+ # tokenizer=llama3_tokenizer,
541
+ sep="<|eot_id|>",
542
+ )
543
+
544
+ conv_llama3_2 = Conversation(
545
+ system="""You are a helpful assistant.""",
546
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
547
+ roles=("user", "assistant"),
548
+ version="llama3_2",
549
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
550
+ messages=(),
551
+ offset=0,
552
+ sep_style=SeparatorStyle.LLAMA_3_2,
553
+ sep="<|eot_id|>",
554
+ )
555
+
556
+ conv_phi3_instruct = Conversation(
557
+ system="""<|system|>\nYou are a helpful AI assistant.""",
558
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
559
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
560
+ version="phi3",
561
+ # pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
562
+ messages=(),
563
+ offset=0,
564
+ sep_style=SeparatorStyle.MPT,
565
+ sep="<|end|>",
566
+ )
567
+
568
+ conv_qwen = Conversation(
569
+ system="""<|im_start|>system
570
+ You are a helpful assistant.""",
571
+ # pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
572
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
573
+ version="qwen",
574
+ messages=[],
575
+ offset=0,
576
+ sep_style=SeparatorStyle.CHATML,
577
+ sep="<|im_end|>",
578
+ )
579
+
580
+ default_conversation = conv_vicuna_v1
581
+ conv_templates = {
582
+ "default": conv_vicuna_v0,
583
+ "v0": conv_vicuna_v0,
584
+ "v1": conv_vicuna_v1,
585
+ "vicuna_v1": conv_vicuna_v1,
586
+ "llama_2": conv_llama_2,
587
+ "mistral_instruct": conv_mistral_instruct,
588
+ "chatml_direct": conv_chatml_direct,
589
+ "mistral_direct": conv_chatml_direct,
590
+ "plain": conv_llava_plain,
591
+ "v0_plain": conv_llava_plain,
592
+ "llava_v0": conv_llava_v0,
593
+ "v0_mmtag": conv_llava_v0_mmtag,
594
+ "llava_v1": conv_llava_v1,
595
+ "v1_mmtag": conv_llava_v1_mmtag,
596
+ "llava_llama_2": conv_llava_llama_2,
597
+ "mpt": conv_mpt,
598
+ "llama3": conv_llama3,
599
+ "llama3_2": conv_llama3_2,
600
+ "phi3": conv_phi3_instruct,
601
+ "qwen": conv_qwen,
602
+ }
603
+
604
+
605
+ if __name__ == "__main__":
606
+ print(default_conversation.get_prompt())
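A minimal usage sketch for the template registry defined above. It assumes the Conversation class exposes the LLaVA-style append_message helper (only roles and get_prompt() appear in this excerpt); the template key and question text are placeholders.

import copy

from longvu.conversation import conv_templates

# Copy a registered template so the module-level object stays pristine.
conv = copy.deepcopy(conv_templates["qwen"])
# Assumed LLaVA-style API (not shown in this diff): queue a user turn and an empty assistant slot.
conv.append_message(conv.roles[0], "<image>\nDescribe the video.")
conv.append_message(conv.roles[1], None)
# get_prompt() is the method exercised at the bottom of this file.
prompt = conv.get_prompt()
print(prompt)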
longvu/file_io.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and its affiliates.
2
+
3
+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
4
+
5
+ from iopath.common.file_io import HTTPURLHandler, PathManager as PathManagerBase
6
+
7
+ __all__ = ["PathManager"]
8
+
9
+
10
+ PathManager = PathManagerBase()
11
+ PathManager.register_handler(HTTPURLHandler())
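A short sketch of how this PathManager instance can be used; the URL below is a placeholder.

from longvu.file_io import PathManager

# The registered HTTPURLHandler resolves http(s) URLs to a locally cached copy.
local_path = PathManager.get_local_path("https://example.com/checkpoint.bin")  # placeholder URL
with PathManager.open(local_path, "rb") as f:
    blob = f.read()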
longvu/language_model/__pycache__/cambrian_llama.cpython-310.pyc ADDED
Binary file (8.51 kB).
 
longvu/language_model/__pycache__/cambrian_qwen.cpython-310.pyc ADDED
Binary file (7.98 kB).
 
longvu/language_model/cambrian_llama.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import CrossEntropyLoss
22
+
23
+ from transformers import (
24
+ AutoConfig,
25
+ AutoModelForCausalLM,
26
+ LlamaConfig,
27
+ LlamaForCausalLM,
28
+ LlamaModel,
29
+ )
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation.utils import GenerateOutput
32
+
33
+ from transformers.modeling_attn_mask_utils import (
34
+ _prepare_4d_causal_attention_mask,
35
+ _prepare_4d_causal_attention_mask_for_sdpa,
36
+ )
37
+
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ CausalLMOutputWithPast,
41
+ )
42
+ from transformers.utils import logging
43
+
44
+ from ..cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
45
+
46
+ IS_XLA_AVAILABLE = False
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class CambrianConfig(LlamaConfig):
52
+ model_type = "cambrian_llama"
53
+
54
+ debug = "debug"
55
+
56
+
57
+ class CambrianLlamaModel(CambrianMetaModel, LlamaModel):
58
+ config_class = CambrianConfig
59
+
60
+ def __init__(self, config: LlamaConfig):
61
+ super(CambrianLlamaModel, self).__init__(config)
62
+
63
+ def forward(
64
+ self,
65
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
66
+ input_ids: torch.LongTensor = None,
67
+ attention_mask: Optional[torch.Tensor] = None,
68
+ position_ids: Optional[torch.LongTensor] = None,
69
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
70
+ inputs_embeds: Optional[torch.FloatTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ return_dict: Optional[bool] = None,
75
+ vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
76
+ vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
77
+ final_vision_feature_size: Optional[List[tuple]] = None,
78
+ global_context_feature: Optional[torch.Tensor] = None,
79
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
80
+
81
+ output_attentions = (
82
+ output_attentions
83
+ if output_attentions is not None
84
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `config`.
85
+ else self.config.output_attentions
86
+ )
87
+
88
+ output_hidden_states = (
89
+ output_hidden_states
90
+ if output_hidden_states is not None
91
+ else self.config.output_hidden_states
92
+ )
93
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
94
+
95
+ return_dict = (
96
+ return_dict if return_dict is not None else self.config.use_return_dict
97
+ )
98
+
99
+ # retrieve input_ids and inputs_embeds
100
+ if input_ids is not None and inputs_embeds is not None:
101
+ raise ValueError(
102
+ "You cannot specify both input_ids and inputs_embeds at the same time"
103
+ )
104
+ elif input_ids is not None:
105
+ batch_size, seq_length = input_ids.shape[:2]
106
+ elif inputs_embeds is not None:
107
+ batch_size, seq_length = inputs_embeds.shape[:2]
108
+ else:
109
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
110
+
111
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
112
+ # `gradient_checkpointing`.
113
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `training`.
114
+ if self.gradient_checkpointing and self.training:
115
+ if use_cache:
116
+ logger.warning_once(
117
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
118
+ )
119
+ use_cache = False
120
+
121
+ past_key_values_length = 0
122
+ if use_cache:
123
+ use_legacy_cache = not isinstance(past_key_values, Cache)
124
+ if use_legacy_cache:
125
+ # pyre-fixme[9]: past_key_values has type
126
+ # `Optional[List[FloatTensor]]`; used as `DynamicCache`.
127
+ # pyre-fixme[6]: For 1st argument expected
128
+ # `Optional[Tuple[Tuple[FloatTensor]]]` but got
129
+ # `Optional[List[FloatTensor]]`.
130
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
131
+ # pyre-fixme[16]: `Optional` has no attribute `get_usable_length`.
132
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
133
+
134
+ if position_ids is None:
135
+ # pyre-fixme[16]: `Optional` has no attribute `device`.
136
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
137
+ position_ids = torch.arange(
138
+ past_key_values_length,
139
+ seq_length + past_key_values_length,
140
+ dtype=torch.long,
141
+ device=device,
142
+ )
143
+ position_ids = position_ids.unsqueeze(0)
144
+
145
+ if inputs_embeds is None:
146
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `embed_tokens`.
147
+ inputs_embeds = self.embed_tokens(input_ids)
148
+
149
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
150
+ # `_use_flash_attention_2`.
151
+ self._use_flash_attention_2 = getattr(self, "_use_flash_attention_2", False)
152
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `_use_sdpa`.
153
+ self._use_sdpa = getattr(self, "_use_sdpa", True)
154
+ if self._use_flash_attention_2:
155
+ # 2d mask is passed through the layers
156
+ attention_mask = (
157
+ attention_mask
158
+ if (attention_mask is not None and 0 in attention_mask)
159
+ else None
160
+ )
161
+ elif self._use_sdpa and not output_attentions:
162
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
163
+ # the manual implementation that requires a 4D causal mask in all cases.
164
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
165
+ attention_mask,
166
+ (batch_size, seq_length),
167
+ inputs_embeds,
168
+ past_key_values_length,
169
+ )
170
+ else:
171
+ # 4d mask is passed through the layers
172
+ attention_mask = _prepare_4d_causal_attention_mask(
173
+ attention_mask,
174
+ (batch_size, seq_length),
175
+ inputs_embeds,
176
+ past_key_values_length,
177
+ )
178
+
179
+ # embed positions
180
+ hidden_states = inputs_embeds
181
+ # decoder layers
182
+ all_hidden_states = () if output_hidden_states else None
183
+ all_self_attns = () if output_attentions else None
184
+ next_decoder_cache = None
185
+
186
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `layers`.
187
+ for i, decoder_layer in enumerate(self.layers):
188
+ if output_hidden_states:
189
+ all_hidden_states += (hidden_states,)
190
+
191
+ if self.gradient_checkpointing and self.training:
192
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
193
+ # `_gradient_checkpointing_func`.
194
+ layer_outputs = self._gradient_checkpointing_func(
195
+ decoder_layer.__call__,
196
+ hidden_states,
197
+ attention_mask,
198
+ position_ids,
199
+ past_key_values,
200
+ output_attentions,
201
+ use_cache,
202
+ )
203
+ else:
204
+ layer_outputs = decoder_layer(
205
+ hidden_states,
206
+ attention_mask=attention_mask,
207
+ position_ids=position_ids,
208
+ past_key_value=past_key_values,
209
+ output_attentions=output_attentions,
210
+ use_cache=use_cache,
211
+ )
212
+
213
+ hidden_states = layer_outputs[0]
214
+
215
+ if use_cache:
216
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
217
+
218
+ if output_attentions:
219
+ all_self_attns += (layer_outputs[1],)
220
+
221
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `norm`.
222
+ hidden_states = self.norm(hidden_states)
223
+
224
+ # add hidden states from the last decoder layer
225
+ if output_hidden_states:
226
+ all_hidden_states += (hidden_states,)
227
+
228
+ next_cache = None
229
+ if use_cache:
230
+ next_cache = (
231
+ next_decoder_cache.to_legacy_cache()
232
+ # pyre-fixme[61]: `use_legacy_cache` is undefined, or not always
233
+ # defined.
234
+ if use_legacy_cache
235
+ else next_decoder_cache
236
+ )
237
+ if not return_dict:
238
+ return tuple(
239
+ v
240
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
241
+ if v is not None
242
+ )
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=next_cache,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_self_attns,
248
+ )
249
+
250
+
251
+ class CambrianLlamaForCausalLM(LlamaForCausalLM, CambrianMetaForCausalLM):
252
+ config_class = CambrianConfig
253
+
254
+ def __init__(self, config):
255
+ super(LlamaForCausalLM, self).__init__(config)
256
+
257
+ self.model = CambrianLlamaModel(config)
258
+ self.pretraining_tp = config.pretraining_tp
259
+ self.vocab_size = config.vocab_size
260
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_model(self):
266
+ return self.model
267
+
268
+ def forward(
269
+ self,
270
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
271
+ input_ids: torch.LongTensor = None,
272
+ attention_mask: Optional[torch.Tensor] = None,
273
+ position_ids: Optional[torch.LongTensor] = None,
274
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
276
+ labels: Optional[torch.LongTensor] = None,
277
+ use_cache: Optional[bool] = None,
278
+ output_attentions: Optional[bool] = None,
279
+ output_hidden_states: Optional[bool] = None,
280
+ images: Optional[torch.FloatTensor] = None,
281
+ image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
282
+ image_sizes: Optional[List[List[int]]] = None,
283
+ return_dict: Optional[bool] = None,
284
+ cache_position=None,
285
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
286
+
287
+ final_vision_feature_size = None
288
+
289
+ if inputs_embeds is None:
290
+ (
291
+ input_ids,
292
+ position_ids,
293
+ attention_mask,
294
+ past_key_values,
295
+ inputs_embeds,
296
+ labels,
297
+ vision_tower_aux_feature_list,
298
+ vision_tower_aux_attention_masks_list,
299
+ final_vision_feature_size,
300
+ global_context_feature,
301
+ ) = self.prepare_inputs_labels_for_multimodal(
302
+ input_ids,
303
+ position_ids,
304
+ attention_mask,
305
+ past_key_values,
306
+ labels,
307
+ images,
308
+ image_aux_attention_masks_list,
309
+ image_sizes,
310
+ )
311
+ if IS_XLA_AVAILABLE:
312
+ # Very Important for TorchXLA
313
+ # self.model.gradient_checkpointing = False
314
+
315
+ # pyre-fixme[21]: Could not find module `torch_xla.utils.checkpoint`.
316
+ from torch_xla.utils.checkpoint import checkpoint
317
+
318
+ # self.model.gradient_checkpointing = True
319
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
320
+ # `_gradient_checkpointing_func`.
321
+ self.model._gradient_checkpointing_func = checkpoint
322
+
323
+ output_attentions = (
324
+ output_attentions
325
+ if output_attentions is not None
326
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute `config`.
327
+ else self.config.output_attentions
328
+ )
329
+ output_hidden_states = (
330
+ output_hidden_states
331
+ if output_hidden_states is not None
332
+ else self.config.output_hidden_states
333
+ )
334
+ return_dict = (
335
+ return_dict if return_dict is not None else self.config.use_return_dict
336
+ )
337
+
338
+ # training
339
+ if IS_XLA_AVAILABLE:
340
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
341
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
342
+ outputs = self.model(
343
+ input_ids=input_ids,
344
+ attention_mask=attention_mask,
345
+ position_ids=position_ids,
346
+ past_key_values=past_key_values,
347
+ inputs_embeds=inputs_embeds,
348
+ use_cache=use_cache,
349
+ output_attentions=output_attentions,
350
+ output_hidden_states=output_hidden_states,
351
+ return_dict=return_dict,
352
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is undefined, or
353
+ # not always defined.
354
+ vision_tower_aux_feature_list=vision_tower_aux_feature_list,
355
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
356
+ # undefined, or not always defined.
357
+ vision_tower_aux_attention_masks_list=vision_tower_aux_attention_masks_list,
358
+ final_vision_feature_size=final_vision_feature_size,
359
+ # pyre-fixme[61]: `global_context_feature` is undefined, or not
360
+ # always defined.
361
+ global_context_feature=global_context_feature,
362
+ )
363
+
364
+ # inference
365
+ else:
366
+ if hasattr(self, "vision_tower_aux_feature_list"):
367
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
368
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
369
+ outputs = self.model(
370
+ input_ids=input_ids,
371
+ attention_mask=attention_mask,
372
+ position_ids=position_ids,
373
+ past_key_values=past_key_values,
374
+ inputs_embeds=inputs_embeds,
375
+ use_cache=use_cache,
376
+ output_attentions=output_attentions,
377
+ output_hidden_states=output_hidden_states,
378
+ return_dict=return_dict,
379
+ vision_tower_aux_feature_list=(
380
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is
381
+ # undefined, or not always defined.
382
+ vision_tower_aux_feature_list
383
+ if inputs_embeds is None
384
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
385
+ # attribute `vision_tower_aux_feature_list`.
386
+ else self.vision_tower_aux_feature_list
387
+ ),
388
+ vision_tower_aux_attention_masks_list=(
389
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
390
+ # undefined, or not always defined.
391
+ vision_tower_aux_attention_masks_list
392
+ if inputs_embeds is None
393
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
394
+ # attribute `vision_tower_aux_attention_masks_list`.
395
+ else self.vision_tower_aux_attention_masks_list
396
+ ),
397
+ final_vision_feature_size=(
398
+ final_vision_feature_size
399
+ if inputs_embeds is None
400
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
401
+ # attribute `final_vision_feature_size`.
402
+ else self.final_vision_feature_size
403
+ ),
404
+ global_context_feature=(
405
+ # pyre-fixme[61]: `global_context_feature` is undefined, or
406
+ # not always defined.
407
+ global_context_feature
408
+ if inputs_embeds is None
409
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
410
+ # attribute `global_context_feature`.
411
+ else self.global_context_feature
412
+ ),
413
+ )
414
+ else:
415
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
416
+ outputs = self.model(
417
+ input_ids=input_ids,
418
+ attention_mask=attention_mask,
419
+ position_ids=position_ids,
420
+ past_key_values=past_key_values,
421
+ inputs_embeds=inputs_embeds,
422
+ use_cache=use_cache,
423
+ output_attentions=output_attentions,
424
+ output_hidden_states=output_hidden_states,
425
+ return_dict=return_dict,
426
+ # final_vision_feature_size=final_vision_feature_size,
427
+ )
428
+
429
+ hidden_states = outputs[0]
430
+ if self.config.pretraining_tp > 1:
431
+ lm_head_slices = self.lm_head.weight.split(
432
+ self.vocab_size // self.config.pretraining_tp, dim=0
433
+ )
434
+ logits = [
435
+ F.linear(hidden_states, lm_head_slices[i])
436
+ for i in range(self.config.pretraining_tp)
437
+ ]
438
+ logits = torch.cat(logits, dim=-1)
439
+ else:
440
+ logits = self.lm_head(hidden_states)
441
+ logits = logits.float()
442
+
443
+ loss = None
444
+ if labels is not None:
445
+ # Shift so that tokens < n predict n
446
+ shift_logits = logits[..., :-1, :].contiguous()
447
+ shift_labels = labels[..., 1:].contiguous()
448
+ # Flatten the tokens
449
+ loss_fct = CrossEntropyLoss()
450
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
451
+ shift_labels = shift_labels.view(-1)
452
+ # Enable model parallelism
453
+ shift_labels = shift_labels.to(shift_logits.device)
454
+ loss = loss_fct(shift_logits, shift_labels)
455
+
456
+ if not return_dict:
457
+ output = (logits,) + outputs[1:]
458
+ return (loss,) + output if loss is not None else output
459
+
460
+ return CausalLMOutputWithPast(
461
+ loss=loss,
462
+ logits=logits,
463
+ past_key_values=outputs.past_key_values,
464
+ hidden_states=outputs.hidden_states,
465
+ attentions=outputs.attentions,
466
+ )
467
+
468
+ @torch.no_grad()
469
+ def generate(
470
+ self,
471
+ inputs: Optional[torch.Tensor] = None,
472
+ images: Optional[torch.Tensor] = None,
473
+ image_sizes: Optional[torch.Tensor] = None,
474
+ **kwargs,
475
+ ) -> Union[GenerateOutput, torch.LongTensor]:
476
+ position_ids = kwargs.pop("position_ids", None)
477
+ attention_mask = kwargs.pop("attention_mask", None)
478
+ if "inputs_embeds" in kwargs:
479
+ raise NotImplementedError("`inputs_embeds` is not supported")
480
+
481
+ if images is not None:
482
+ (
483
+ inputs,
484
+ position_ids,
485
+ attention_mask,
486
+ _,
487
+ inputs_embeds,
488
+ _,
489
+ vision_tower_aux_feature_list,
490
+ vision_tower_aux_attention_masks_list,
491
+ final_vision_feature_size,
492
+ global_context_feature,
493
+ ) = self.prepare_inputs_labels_for_multimodal(
494
+ inputs,
495
+ position_ids,
496
+ attention_mask,
497
+ None,
498
+ None,
499
+ images,
500
+ image_sizes=image_sizes,
501
+ )
502
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
503
+ # `vision_tower_aux_feature_list`.
504
+ self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
505
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
506
+ # `vision_tower_aux_attention_masks_list`.
507
+ self.vision_tower_aux_attention_masks_list = (
508
+ vision_tower_aux_attention_masks_list
509
+ )
510
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
511
+ # `final_vision_feature_size`.
512
+ self.final_vision_feature_size = final_vision_feature_size
513
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
514
+ # `global_context_feature`.
515
+ self.global_context_feature = global_context_feature
516
+ else:
517
+ inputs_embeds = self.get_model().embed_tokens(inputs)
518
+
519
+ # pyre-fixme[16]: `LlamaForCausalLM` has no attribute `generate`.
520
+ return super().generate(
521
+ position_ids=position_ids,
522
+ attention_mask=attention_mask,
523
+ inputs_embeds=inputs_embeds,
524
+ **kwargs,
525
+ )
526
+
527
+ def prepare_inputs_for_generation(
528
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
529
+ ):
530
+ images = kwargs.pop("images", None)
531
+ image_sizes = kwargs.pop("image_sizes", None)
532
+ inputs = super().prepare_inputs_for_generation(
533
+ input_ids,
534
+ past_key_values=past_key_values,
535
+ inputs_embeds=inputs_embeds,
536
+ **kwargs,
537
+ )
538
+ if images is not None:
539
+ inputs["images"] = images
540
+ if image_sizes is not None:
541
+ inputs["image_sizes"] = image_sizes
542
+ return inputs
543
+
544
+
545
+ AutoConfig.register("cambrian_llama", CambrianConfig)
546
+ AutoModelForCausalLM.register(CambrianConfig, CambrianLlamaForCausalLM)
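With the two registrations above, a checkpoint whose config.json declares "model_type": "cambrian_llama" resolves to the classes in this file once the module has been imported. A hedged loading sketch (the checkpoint path is a placeholder):

import torch
from transformers import AutoModelForCausalLM

import longvu.language_model.cambrian_llama  # noqa: F401  # side effect: runs the registrations above

model = AutoModelForCausalLM.from_pretrained(
    "./checkpoint/",         # placeholder path
    torch_dtype=torch.float16,
)
model.eval()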
longvu/language_model/cambrian_qwen.py ADDED
@@ -0,0 +1,471 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import CrossEntropyLoss
22
+
23
+ from transformers import AutoConfig, AutoModelForCausalLM
24
+ from transformers.cache_utils import Cache, DynamicCache
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast,
30
+ )
31
+ from transformers.utils import logging
32
+
33
+ from ..cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
34
+
35
+ IS_XLA_AVAILABLE = False
36
+
37
+ from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class CambrianConfig(Qwen2Config):
43
+ model_type = "cambrian_qwen"
44
+
45
+ debug = "debug"
46
+
47
+
48
+ class CambrianQwenModel(CambrianMetaModel, Qwen2Model):
49
+ config_class = CambrianConfig
50
+
51
+ def __init__(self, config: Qwen2Config):
52
+ super(CambrianQwenModel, self).__init__(config)
53
+
54
+ def forward(
55
+ self,
56
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
57
+ input_ids: torch.LongTensor = None,
58
+ attention_mask: Optional[torch.Tensor] = None,
59
+ position_ids: Optional[torch.LongTensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ use_cache: Optional[bool] = None,
63
+ output_attentions: Optional[bool] = None,
64
+ output_hidden_states: Optional[bool] = None,
65
+ return_dict: Optional[bool] = None,
66
+ cache_position: Optional[torch.LongTensor] = None,
67
+ vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
68
+ vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
69
+ final_vision_feature_size: Optional[List[tuple]] = None,
70
+ global_context_feature: Optional[torch.Tensor] = None,
71
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
72
+ output_attentions = (
73
+ output_attentions
74
+ if output_attentions is not None
75
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `config`.
76
+ else self.config.output_attentions
77
+ )
78
+ output_hidden_states = (
79
+ output_hidden_states
80
+ if output_hidden_states is not None
81
+ else self.config.output_hidden_states
82
+ )
83
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
84
+
85
+ return_dict = (
86
+ return_dict if return_dict is not None else self.config.use_return_dict
87
+ )
88
+
89
+ if (input_ids is None) ^ (inputs_embeds is not None):
90
+ raise ValueError(
91
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
92
+ )
93
+
94
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `gradient_checkpointing`.
95
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `training`.
96
+ if self.gradient_checkpointing and self.training:
97
+ if use_cache:
98
+ logger.warning_once(
99
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
100
+ )
101
+ use_cache = False
102
+
103
+ use_legacy_cache = False
104
+ if use_cache and not isinstance(past_key_values, Cache):
105
+ use_legacy_cache = True
106
+ # pyre-fixme[6]: For 1st argument expected
107
+ # `Optional[Tuple[Tuple[FloatTensor]]]` but got
108
+ # `Optional[List[FloatTensor]]`.
109
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
110
+ logger.warning_once(
111
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
112
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
113
+ )
114
+
115
+ if inputs_embeds is None:
116
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `embed_tokens`.
117
+ inputs_embeds = self.embed_tokens(input_ids)
118
+
119
+ if cache_position is None:
120
+ past_seen_tokens = (
121
+ # pyre-fixme[16]: Item `List` of `Union[List[torch._C.FloatTensor],
122
+ # DynamicCache]` has no attribute `get_seq_length`.
123
+ past_key_values.get_seq_length() if past_key_values is not None else 0
124
+ )
125
+ cache_position = torch.arange(
126
+ past_seen_tokens,
127
+ past_seen_tokens + inputs_embeds.shape[1],
128
+ device=inputs_embeds.device,
129
+ )
130
+ if position_ids is None:
131
+ position_ids = cache_position.unsqueeze(0)
132
+
133
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `_update_causal_mask`.
134
+ causal_mask = self._update_causal_mask(
135
+ attention_mask,
136
+ inputs_embeds,
137
+ cache_position,
138
+ past_key_values,
139
+ output_attentions,
140
+ )
141
+
142
+ hidden_states = inputs_embeds
143
+
144
+ # decoder layers
145
+ all_hidden_states = () if output_hidden_states else None
146
+ all_self_attns = () if output_attentions else None
147
+ next_decoder_cache = None
148
+
149
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `layers`.
150
+ for i, decoder_layer in enumerate(self.layers):
151
+ if output_hidden_states:
152
+ all_hidden_states += (hidden_states,)
153
+
154
+ if self.gradient_checkpointing and self.training:
155
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute
156
+ # `_gradient_checkpointing_func`.
157
+ layer_outputs = self._gradient_checkpointing_func(
158
+ decoder_layer.__call__,
159
+ hidden_states,
160
+ causal_mask,
161
+ position_ids,
162
+ past_key_values,
163
+ output_attentions,
164
+ use_cache,
165
+ cache_position,
166
+ )
167
+ else:
168
+ layer_outputs = decoder_layer(
169
+ hidden_states,
170
+ attention_mask=causal_mask,
171
+ position_ids=position_ids,
172
+ past_key_value=past_key_values,
173
+ output_attentions=output_attentions,
174
+ use_cache=use_cache,
175
+ cache_position=cache_position,
176
+ )
177
+
178
+ hidden_states = layer_outputs[0]
179
+
180
+ if use_cache:
181
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
182
+
183
+ if output_attentions:
184
+ all_self_attns += (layer_outputs[1],)
185
+
186
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `norm`.
187
+ hidden_states = self.norm(hidden_states)
188
+
189
+ # add hidden states from the last decoder layer
190
+ if output_hidden_states:
191
+ all_hidden_states += (hidden_states,)
192
+
193
+ next_cache = None
194
+ if use_cache:
195
+ next_cache = (
196
+ next_decoder_cache.to_legacy_cache()
197
+ if use_legacy_cache
198
+ else next_decoder_cache
199
+ )
200
+
201
+ if not return_dict:
202
+ return tuple(
203
+ v
204
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
205
+ if v is not None
206
+ )
207
+ return BaseModelOutputWithPast(
208
+ last_hidden_state=hidden_states,
209
+ past_key_values=next_cache,
210
+ hidden_states=all_hidden_states,
211
+ attentions=all_self_attns,
212
+ )
213
+
214
+
215
+ class CambrianQwenForCausalLM(Qwen2ForCausalLM, CambrianMetaForCausalLM):
216
+ config_class = CambrianConfig
217
+
218
+ def __init__(self, config):
219
+ # super(Qwen2ForCausalLM, self).__init__(config)
220
+ Qwen2ForCausalLM.__init__(self, config)
221
+ config.model_type = "cambrian_qwen"
222
+ config.rope_scaling = None
223
+
224
+ self.model = CambrianQwenModel(config)
225
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
226
+ # Initialize weights and apply final processing
227
+ self.post_init()
228
+
229
+ def get_model(self):
230
+ return self.model
231
+
232
+ def forward(
233
+ self,
234
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
235
+ input_ids: torch.LongTensor = None,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
239
+ inputs_embeds: Optional[torch.FloatTensor] = None,
240
+ labels: Optional[torch.LongTensor] = None,
241
+ use_cache: Optional[bool] = None,
242
+ output_attentions: Optional[bool] = None,
243
+ output_hidden_states: Optional[bool] = None,
244
+ images: Optional[torch.FloatTensor] = None,
245
+ image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
246
+ image_sizes: Optional[List[List[int]]] = None,
247
+ return_dict: Optional[bool] = None,
248
+ modalities: Optional[List[str]] = ["image"],
249
+ dpo_forward: Optional[bool] = False,
250
+ cache_position=None,
251
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
252
+
253
+ input_image_features = None
254
+ highres_image_features = None
255
+ frame_split_sizes = None
256
+
257
+ if inputs_embeds is None:
258
+ (
259
+ input_ids,
260
+ position_ids,
261
+ attention_mask,
262
+ past_key_values,
263
+ inputs_embeds,
264
+ labels,
265
+ vision_tower_aux_feature_list,
266
+ vision_tower_aux_attention_masks_list,
267
+ final_vision_feature_size,
268
+ global_context_feature,
269
+ ) = self.prepare_inputs_labels_for_multimodal(
270
+ input_ids,
271
+ position_ids,
272
+ attention_mask,
273
+ past_key_values,
274
+ labels,
275
+ images,
276
+ image_aux_attention_masks_list,
277
+ image_sizes,
278
+ )
279
+
280
+ if dpo_forward:
281
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
282
+ outputs = self.model(
283
+ input_ids=input_ids,
284
+ attention_mask=attention_mask,
285
+ position_ids=position_ids,
286
+ past_key_values=past_key_values,
287
+ inputs_embeds=inputs_embeds,
288
+ use_cache=use_cache,
289
+ output_attentions=output_attentions,
290
+ output_hidden_states=output_hidden_states,
291
+ return_dict=return_dict,
292
+ )
293
+
294
+ hidden_states = outputs[0]
295
+ logits = self.lm_head(hidden_states)
296
+ return logits, labels
297
+
298
+ else:
299
+ if hasattr(self, "vision_tower_aux_feature_list"):
300
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
301
+ outputs = self.model(
302
+ input_ids=input_ids,
303
+ attention_mask=attention_mask,
304
+ position_ids=position_ids,
305
+ past_key_values=past_key_values,
306
+ inputs_embeds=inputs_embeds,
307
+ use_cache=use_cache,
308
+ output_attentions=output_attentions,
309
+ output_hidden_states=output_hidden_states,
310
+ return_dict=return_dict,
311
+ vision_tower_aux_feature_list=(
312
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is
313
+ # undefined, or not always defined.
314
+ vision_tower_aux_feature_list
315
+ if inputs_embeds is None
316
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
317
+ # `vision_tower_aux_feature_list`.
318
+ else self.vision_tower_aux_feature_list
319
+ ),
320
+ vision_tower_aux_attention_masks_list=(
321
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
322
+ # undefined, or not always defined.
323
+ vision_tower_aux_attention_masks_list
324
+ if inputs_embeds is None
325
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
326
+ # `vision_tower_aux_attention_masks_list`.
327
+ else self.vision_tower_aux_attention_masks_list
328
+ ),
329
+ final_vision_feature_size=(
330
+ # pyre-fixme[61]: `final_vision_feature_size` is undefined,
331
+ # or not always defined.
332
+ final_vision_feature_size
333
+ if inputs_embeds is None
334
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
335
+ # `final_vision_feature_size`.
336
+ else self.final_vision_feature_size
337
+ ),
338
+ global_context_feature=(
339
+ # pyre-fixme[61]: `global_context_feature` is undefined, or
340
+ # not always defined.
341
+ global_context_feature
342
+ if inputs_embeds is None
343
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
344
+ # `global_context_feature`.
345
+ else self.global_context_feature
346
+ ),
347
+ )
348
+ else:
349
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
350
+ outputs = self.model(
351
+ input_ids=input_ids,
352
+ attention_mask=attention_mask,
353
+ position_ids=position_ids,
354
+ past_key_values=past_key_values,
355
+ inputs_embeds=inputs_embeds,
356
+ use_cache=use_cache,
357
+ output_attentions=output_attentions,
358
+ output_hidden_states=output_hidden_states,
359
+ return_dict=return_dict,
360
+ # final_vision_feature_size=final_vision_feature_size,
361
+ )
362
+
363
+ hidden_states = outputs[0]
364
+ logits = self.lm_head(hidden_states)
365
+ logits = logits.float()
366
+
367
+ loss = None
368
+ if labels is not None:
369
+ # Shift so that tokens < n predict n
370
+ shift_logits = logits[..., :-1, :].contiguous()
371
+ shift_labels = labels[..., 1:].contiguous()
372
+ # Flatten the tokens
373
+ loss_fct = CrossEntropyLoss()
374
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute `config`.
375
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
376
+ shift_labels = shift_labels.view(-1)
377
+ # Enable model parallelism
378
+ shift_labels = shift_labels.to(shift_logits.device)
379
+ loss = loss_fct(shift_logits, shift_labels)
380
+
381
+ if not return_dict:
382
+ output = (logits,) + outputs[1:]
383
+ return (loss,) + output if loss is not None else output
384
+
385
+ return CausalLMOutputWithPast(
386
+ loss=loss,
387
+ logits=logits,
388
+ past_key_values=outputs.past_key_values,
389
+ hidden_states=outputs.hidden_states,
390
+ attentions=outputs.attentions,
391
+ )
392
+
393
+ @torch.no_grad()
394
+ def generate(
395
+ self,
396
+ inputs: Optional[torch.Tensor] = None,
397
+ images: Optional[torch.Tensor] = None,
398
+ image_sizes: Optional[torch.Tensor] = None,
399
+ **kwargs,
400
+ ) -> Union[GenerateOutput, torch.LongTensor]:
401
+ position_ids = kwargs.pop("position_ids", None)
402
+ attention_mask = kwargs.pop("attention_mask", None)
403
+ if "inputs_embeds" in kwargs:
404
+ raise NotImplementedError("`inputs_embeds` is not supported")
405
+
406
+ if images is not None:
407
+ (
408
+ inputs,
409
+ position_ids,
410
+ attention_mask,
411
+ _,
412
+ inputs_embeds,
413
+ _,
414
+ vision_tower_aux_feature_list,
415
+ vision_tower_aux_attention_masks_list,
416
+ final_vision_feature_size,
417
+ global_context_feature,
418
+ ) = self.prepare_inputs_labels_for_multimodal(
419
+ inputs,
420
+ position_ids,
421
+ attention_mask,
422
+ None,
423
+ None,
424
+ images,
425
+ image_sizes=image_sizes,
426
+ )
427
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
428
+ # `vision_tower_aux_feature_list`.
429
+ self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
430
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
431
+ # `vision_tower_aux_attention_masks_list`.
432
+ self.vision_tower_aux_attention_masks_list = (
433
+ vision_tower_aux_attention_masks_list
434
+ )
435
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
436
+ # `final_vision_feature_size`.
437
+ self.final_vision_feature_size = final_vision_feature_size
438
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
439
+ # `global_context_feature`.
440
+ self.global_context_feature = global_context_feature
441
+ else:
442
+ inputs_embeds = self.get_model().embed_tokens(inputs)
443
+
444
+ # pyre-fixme[16]: `Qwen2ForCausalLM` has no attribute `generate`.
445
+ return super().generate(
446
+ position_ids=position_ids,
447
+ attention_mask=attention_mask,
448
+ inputs_embeds=inputs_embeds,
449
+ **kwargs,
450
+ )
451
+
452
+ def prepare_inputs_for_generation(
453
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
454
+ ):
455
+ images = kwargs.pop("images", None)
456
+ image_sizes = kwargs.pop("image_sizes", None)
457
+ inputs = super().prepare_inputs_for_generation(
458
+ input_ids,
459
+ past_key_values=past_key_values,
460
+ inputs_embeds=inputs_embeds,
461
+ **kwargs,
462
+ )
463
+ if images is not None:
464
+ inputs["images"] = images
465
+ if image_sizes is not None:
466
+ inputs["image_sizes"] = image_sizes
467
+ return inputs
468
+
469
+
470
+ AutoConfig.register("cambrian_qwen", CambrianConfig)
471
+ AutoModelForCausalLM.register(CambrianConfig, CambrianQwenForCausalLM)
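A hedged wrapper around the generate() defined above. It assumes input_ids already contain the IMAGE_TOKEN_INDEX placeholder and that images/image_sizes come from the repo's preprocessing (see process_images in mm_datautils.py below); none of that preprocessing is reproduced here.

import torch

def run_video_qa(model, input_ids: torch.Tensor, images, image_sizes):
    # Greedy decoding through CambrianQwenForCausalLM.generate; the keyword
    # arguments mirror the signature shown above.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images,
            image_sizes=image_sizes,
            do_sample=False,
            max_new_tokens=128,
            use_cache=True,
        )
    return output_ids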
longvu/make_delta.py ADDED
@@ -0,0 +1,66 @@
1
+ # pyre-unsafe
2
+ """
3
+ Usage:
4
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
5
+ """
6
+
7
+ import argparse
8
+
9
+ import torch
10
+ from tqdm import tqdm
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ from .utils import auto_upgrade
14
+
15
+
16
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
17
+ print("Loading base model")
18
+ base = AutoModelForCausalLM.from_pretrained(
19
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
20
+ )
21
+
22
+ print("Loading target model")
23
+ auto_upgrade(target_model_path)
24
+ target = AutoModelForCausalLM.from_pretrained(
25
+ target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
26
+ )
27
+
28
+ print("Calculating delta")
29
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
30
+ if name not in base.state_dict():
31
+ assert name in [
32
+ "model.mm_projector.weight",
33
+ "model.mm_projector.bias",
34
+ ], f"{name} not in base model"
35
+ continue
36
+ if param.data.shape == base.state_dict()[name].shape:
37
+ param.data -= base.state_dict()[name]
38
+ else:
39
+ assert name in [
40
+ "model.embed_tokens.weight",
41
+ "lm_head.weight",
42
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
43
+ bparam = base.state_dict()[name]
44
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
45
+
46
+ print("Saving delta")
47
+ if hub_repo_id:
48
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
49
+ else:
50
+ kwargs = {}
51
+ target.save_pretrained(delta_path, **kwargs)
52
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
53
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
54
+
55
+
56
+ if __name__ == "__main__":
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument("--base-model-path", type=str, required=True)
59
+ parser.add_argument("--target-model-path", type=str, required=True)
60
+ parser.add_argument("--delta-path", type=str, required=True)
61
+ parser.add_argument("--hub-repo-id", type=str, default=None)
62
+ args = parser.parse_args()
63
+
64
+ make_delta(
65
+ args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
66
+ )
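The module docstring shows the original llava CLI invocation; make_delta can also be called directly, as in this sketch (all paths are placeholders):

from longvu.make_delta import make_delta

make_delta(
    base_model_path="path/to/base-llm",      # placeholder
    target_model_path="path/to/finetuned",   # placeholder
    delta_path="path/to/delta-out",          # placeholder
    hub_repo_id=None,                        # set to a repo id to push the delta to the Hub
)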
longvu/mm_datautils.py ADDED
@@ -0,0 +1,1688 @@
1
+ # pyre-strict
2
+ import copy
3
+ import json
4
+ import os
5
+ import random
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Sequence
8
+
9
+ import numpy as np
10
+ import tokenizers
11
+
12
+ import torch
13
+
14
+ import transformers
15
+
16
+ from longvu import conversation as conversation_lib
17
+
18
+ from longvu.constants import (
19
+ DEFAULT_IM_END_TOKEN,
20
+ DEFAULT_IM_START_TOKEN,
21
+ DEFAULT_IMAGE_TOKEN,
22
+ IGNORE_INDEX,
23
+ IMAGE_TOKEN_INDEX,
24
+ )
25
+
26
+ # pyre-fixme[21]: Could not find module `decord`.
27
+ from decord import cpu, VideoReader # @manual=fbsource//third-party/pypi/decord:decord
28
+
29
+ from packaging import version
30
+ from PIL import Image
31
+ from torch import distributed as dist
32
+ from torch.distributed.fsdp import (
33
+ FullStateDictConfig,
34
+ FullyShardedDataParallel as FSDP,
35
+ StateDictType,
36
+ )
37
+ from torch.utils.data import Dataset
38
+
39
+ # pyre-fixme
40
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
41
+ "0.14"
42
+ )
43
+ from transformers import StoppingCriteria
44
+
45
+ from longvu.mm_utils import KeywordsStoppingCriteria
46
+
47
+
48
+ # pyre-fixme[3]: Return type must be annotated.
49
+ # pyre-fixme[2]: Parameter must be annotated.
50
+ def maybe_zero_3(param, ignore_status: bool = False, name=None):
51
+ # NO deepspeed
52
+
53
+ # from deepspeed import zero
54
+ # from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
55
+ # if hasattr(param, "ds_id"):
56
+ # if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
57
+ # if not ignore_status:
58
+ # print(name, 'no ignore status')
59
+ # with zero.GatheredParameters([param]):
60
+ # param = param.data.detach().cpu().clone()
61
+ # else:
62
+ # param = param.detach().cpu().clone()
63
+ return param.detach().cpu().clone()
64
+
65
+
66
+ # pyre-fixme[3]: Return type must be annotated.
67
+ # pyre-fixme[2]: Parameter must be annotated.
68
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
69
+ to_return = {
70
+ k: t
71
+ for k, t in named_params
72
+ if any(key_match in k for key_match in keys_to_match)
73
+ }
74
+ to_return = {
75
+ k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
76
+ for k, v in to_return.items()
77
+ }
78
+ return to_return
79
+
80
+
81
+ # pyre-fixme[3]: Return type must be annotated.
82
+ # pyre-fixme[2]: Parameter must be annotated.
83
+ def find_all_linear_names(model):
84
+ cls = torch.nn.Linear
85
+ lora_module_names = set()
86
+ multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
87
+ for name, module in model.named_modules():
88
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
89
+ continue
90
+ if isinstance(module, cls):
91
+ names = name.split(".")
92
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
93
+
94
+ if "lm_head" in lora_module_names: # needed for 16-bit
95
+ lora_module_names.remove("lm_head")
96
+ return list(lora_module_names)
97
+
98
+
99
+ def safe_save_model_for_hf_trainer(
100
+ trainer: transformers.Trainer, output_dir: str
101
+ ) -> None:
102
+ """Collects the state dict and dump to disk."""
103
+ global_rank = dist.get_rank()
104
+ save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
105
+ # pyre-fixme[16]: `Trainer` has no attribute `args`.
106
+ if len(trainer.args.fsdp) == 0:
107
+ # pyre-fixme[16]: `Trainer` has no attribute `model`.
108
+ cpu_state_dict = trainer.model.state_dict()
109
+ else:
110
+ with FSDP.state_dict_type(
111
+ trainer.model, StateDictType.FULL_STATE_DICT, save_policy
112
+ ):
113
+ cpu_state_dict = trainer.model.state_dict()
114
+
115
+ for key in cpu_state_dict.keys():
116
+ cpu_state_dict[key] = cpu_state_dict[key].to(torch.bfloat16)
117
+
118
+ if global_rank == 0:
119
+ trainer.model.config.save_pretrained(output_dir)
120
+ current_folder = output_dir.split("/")[-1]
121
+ parent_folder = os.path.dirname(output_dir)
122
+ save_path = os.path.join(output_dir, "pytorch_model.bin")
123
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False) and not getattr(
124
+ trainer.args, "tune_text_decoder", False
125
+ ):
126
+ # Only save Adapter
127
+ keys_to_match = ["mm_projector"]
128
+ if getattr(trainer.args, "use_im_start_end", False):
129
+ keys_to_match.extend(["embed_tokens", "embed_in"])
130
+
131
+ freeze_layer_remove = []
132
+ for key in cpu_state_dict.keys():
133
+ remove = True
134
+ for key_match in keys_to_match:
135
+ if key_match in key:
136
+ remove = False
137
+ break
138
+ if remove:
139
+ freeze_layer_remove.append(key)
140
+ for key in freeze_layer_remove:
141
+ del cpu_state_dict[key]
142
+
143
+ if current_folder.startswith("checkpoint-"):
144
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
145
+ os.makedirs(mm_projector_folder, exist_ok=True)
146
+ save_path = os.path.join(mm_projector_folder, f"{current_folder}.bin")
147
+ else:
148
+ save_path = os.path.join(output_dir, f"mm_projector.bin")
149
+ torch.save(cpu_state_dict, save_path)
150
+
151
+
152
+ def smart_tokenizer_and_embedding_resize(
153
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
154
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
155
+ special_tokens_dict: Dict,
156
+ tokenizer: transformers.PreTrainedTokenizer,
157
+ model: transformers.PreTrainedModel,
158
+ ) -> None:
159
+ """Resize tokenizer and embedding.
160
+
161
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
162
+ """
163
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
164
+ # pyre-fixme[16]: `PreTrainedModel` has no attribute `resize_token_embeddings`.
165
+ model.resize_token_embeddings(len(tokenizer))
166
+
167
+ if num_new_tokens > 0:
168
+ # pyre-fixme[16]: `PreTrainedModel` has no attribute `get_input_embeddings`.
169
+ input_embeddings = model.get_input_embeddings().weight.data
170
+ # pyre-fixme[16]: `PreTrainedModel` has no attribute `get_output_embeddings`.
171
+ output_embeddings = model.get_output_embeddings().weight.data
172
+
173
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
174
+ dim=0, keepdim=True
175
+ )
176
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
177
+ dim=0, keepdim=True
178
+ )
179
+
180
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
181
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
182
+
183
+
184
+ def _tokenize_fn(
185
+ strings: Sequence[str],
186
+ tokenizer: transformers.PreTrainedTokenizer,
187
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
188
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
189
+ ) -> Dict:
190
+ """Tokenize a list of strings."""
191
+ tokenized_list = [
192
+ tokenizer(
193
+ text,
194
+ return_tensors="pt",
195
+ padding="longest",
196
+ max_length=tokenizer.model_max_length,
197
+ truncation=True,
198
+ )
199
+ for text in strings
200
+ ]
201
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
202
+ input_ids_lens = labels_lens = [
203
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
204
+ for tokenized in tokenized_list
205
+ ]
206
+ return dict(
207
+ input_ids=input_ids,
208
+ labels=labels,
209
+ input_ids_lens=input_ids_lens,
210
+ labels_lens=labels_lens,
211
+ )
212
+
213
+
214
+ # pyre-fixme[2]: Parameter must be annotated.
215
+ def _mask_targets(target, tokenized_lens, speakers) -> None:
216
+ # cur_idx = 0
217
+ cur_idx = tokenized_lens[0]
218
+ tokenized_lens = tokenized_lens[1:]
219
+ target[:cur_idx] = IGNORE_INDEX
220
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
221
+ if speaker == "human":
222
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
223
+ cur_idx += tokenized_len
224
+
225
+
226
+ # pyre-fixme[3]: Return type must be annotated.
227
+ # pyre-fixme[2]: Parameter must be annotated.
228
+ def _add_speaker_and_signal(header, source, get_conversation: bool = True):
229
+ """Add speaker and start/end signal on each round."""
230
+ BEGIN_SIGNAL = "### "
231
+ END_SIGNAL = "\n"
232
+ conversation = header
233
+ for sentence in source:
234
+ from_str = sentence["from"]
235
+ if from_str.lower() == "human":
236
+ from_str = conversation_lib.default_conversation.roles[0]
237
+ elif from_str.lower() == "gpt":
238
+ from_str = conversation_lib.default_conversation.roles[1]
239
+ else:
240
+ from_str = "unknown"
241
+ sentence["value"] = (
242
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
243
+ )
244
+ if get_conversation:
245
+ conversation += sentence["value"]
246
+ conversation += BEGIN_SIGNAL
247
+ return conversation
248
+
249
+
250
+ # pyre-fixme[3]: Return type must be annotated.
251
+ # pyre-fixme[2]: Parameter must be annotated.
252
+ def expand2square(pil_img, background_color):
253
+ width, height = pil_img.size
254
+ if width == height:
255
+ return pil_img
256
+ elif width > height:
257
+ result = Image.new(pil_img.mode, (width, width), background_color)
258
+ result.paste(pil_img, (0, (width - height) // 2))
259
+ return result
260
+ else:
261
+ result = Image.new(pil_img.mode, (height, height), background_color)
262
+ result.paste(pil_img, ((height - width) // 2, 0))
263
+ return result
264
+
265
+
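+ # Illustrative example: expand2square pads a 1280x720 frame with the given background color
+ # (the callers below pass the processor's per-channel mean scaled to 0-255) into a centered
+ # 1280x1280 square before any further resizing.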
266
+ # pyre-fixme[3]: Return type must be annotated.
267
+ # pyre-fixme[2]: Parameter must be annotated.
268
+ def process_images(images, image_processor, model_cfg):
269
+ if isinstance(image_processor, list):
270
+ processor_aux_list = image_processor
271
+ new_images_aux_list = []
272
+ for image in images:
273
+ if isinstance(image, np.ndarray):
274
+ image = Image.fromarray(image)
275
+ image_aux_list = []
276
+ for processor_aux in processor_aux_list:
277
+ image_aux = image
278
+ if hasattr(processor_aux, "image_mean"):
279
+ try:
280
+ target_resolution = processor_aux.crop_size["height"]
281
+ except Exception:
282
+ target_resolution = processor_aux.size["height"]
283
+ image_aux = expand2square(
284
+ image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
285
+ ).resize((target_resolution, target_resolution))
286
+ image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
287
+ "pixel_values"
288
+ ][0]
289
+ image_aux_list.append(image_aux)
290
+ new_images_aux_list.append(image_aux_list)
291
+ new_images_aux_list = [
292
+ list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
293
+ ]
294
+ new_images_aux_list = [
295
+ torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
296
+ ]
297
+ return new_images_aux_list
298
+ else:
299
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
300
+ new_images = []
301
+ if image_aspect_ratio == "pad":
302
+ for image in images:
303
+ image = expand2square(
304
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
305
+ )
306
+ image = image_processor.preprocess(image, return_tensors="pt")[
307
+ "pixel_values"
308
+ ][0]
309
+ new_images.append(image)
310
+ else:
311
+ return image_processor(images, return_tensors="pt")["pixel_values"]
312
+ if all(x.shape == new_images[0].shape for x in new_images):
313
+ new_images = torch.stack(new_images, dim=0)
314
+ return new_images
315
+
316
+
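+ # Note on process_images above: when image_processor is a list (one processor per vision tower),
+ # each frame is square-padded, resized, and preprocessed per tower, and the result is a list with
+ # one stacked half-precision CUDA tensor per tower, roughly (num_frames, 3, H_i, W_i) each.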
317
+ # pyre-fixme[2]: Parameter must be annotated.
318
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
319
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
320
+ def preprocess_multimodal(sources: Sequence[str], data_args) -> Dict:
321
+ is_multimodal = data_args.is_multimodal
322
+ if not is_multimodal:
323
+ # pyre-fixme[7]: Expected `Dict[typing.Any, typing.Any]` but got
324
+ # `Sequence[str]`.
325
+ return sources
326
+
327
+ for source in sources:
328
+ for sentence in source:
329
+ if (
330
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
331
+ # but got `str`.
332
+ DEFAULT_IMAGE_TOKEN in sentence["value"]
333
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
334
+ # but got `str`.
335
+ or "<video>" in sentence["value"]
336
+ ):
337
+ # pyre-fixme[16]: `str` has no attribute `__setitem__`.
338
+ sentence["value"] = (
339
+ # pyre-fixme[6]: For 1st argument expected `Union[slice,
340
+ # SupportsIndex]` but got `str`.
341
+ sentence["value"]
342
+ .replace(DEFAULT_IMAGE_TOKEN, "")
343
+ .replace("<video>", "")
344
+ .strip()
345
+ )
346
+ # pyre-fixme[6]: For 1st argument expected `Union[slice,
347
+ # SupportsIndex]` but got `str`.
348
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
349
+ # pyre-fixme[6]: For 1st argument expected `Union[slice,
350
+ # SupportsIndex]` but got `str`.
351
+ sentence["value"] = sentence["value"].strip()
352
+ if "mmtag" in conversation_lib.default_conversation.version:
353
+ # pyre-fixme[6]: For 1st argument expected `Union[slice,
354
+ # SupportsIndex]` but got `str`.
355
+ sentence["value"] = sentence["value"].replace(
356
+ DEFAULT_IMAGE_TOKEN,
357
+ "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
358
+ )
359
+ replace_token = DEFAULT_IMAGE_TOKEN
360
+ if data_args.mm_use_im_start_end:
361
+ replace_token = (
362
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
363
+ )
364
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
365
+ # but got `str`.
366
+ sentence["value"] = sentence["value"].replace(
367
+ DEFAULT_IMAGE_TOKEN, replace_token
368
+ )
369
+
370
+ # pyre-fixme[7]: Expected `Dict[typing.Any, typing.Any]` but got `Sequence[str]`.
371
+ return sources
372
+
373
+
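+ # Illustrative example (assuming DEFAULT_IMAGE_TOKEN is "<image>"): a turn such as
+ # "What happens in this clip? <video>" is rewritten to "<image>\nWhat happens in this clip?",
+ # i.e. any image/video placeholder is stripped and a single image token is moved to the front.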
374
+ def preprocess_llama_2(
375
+ # pyre-fixme[2]: Parameter must be annotated.
376
+ sources,
377
+ tokenizer: transformers.PreTrainedTokenizer,
378
+ has_image: bool = False,
379
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
380
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
381
+ ) -> Dict:
382
+ conv = conversation_lib.default_conversation.copy()
383
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
384
+
385
+ # Apply prompt templates
386
+ conversations = []
387
+ for i, source in enumerate(sources):
388
+ if roles[source[0]["from"]] != conv.roles[0]:
389
+ # Skip the first one if it is not from human
390
+ source = source[1:]
391
+
392
+ conv.messages = []
393
+ for j, sentence in enumerate(source):
394
+ role = roles[sentence["from"]]
395
+ assert role == conv.roles[j % 2], f"{i}"
396
+ conv.append_message(role, sentence["value"])
397
+ conversations.append(conv.get_prompt())
398
+
399
+ # Tokenize conversations
400
+
401
+ if has_image:
402
+ input_ids = torch.stack(
403
+ [
404
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
405
+ for prompt in conversations
406
+ ],
407
+ dim=0,
408
+ )
409
+ else:
410
+ input_ids = tokenizer(
411
+ conversations,
412
+ return_tensors="pt",
413
+ padding="longest",
414
+ max_length=tokenizer.model_max_length,
415
+ truncation=True,
416
+ ).input_ids
417
+
418
+ targets = input_ids.clone()
419
+
420
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
421
+
422
+ # Mask targets
423
+ sep = "[/INST] "
424
+ for conversation, target in zip(conversations, targets):
425
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
426
+
427
+ rounds = conversation.split(conv.sep2)
428
+ cur_len = 1
429
+ target[:cur_len] = IGNORE_INDEX
430
+ for i, rou in enumerate(rounds):
431
+ if rou == "":
432
+ break
433
+
434
+ parts = rou.split(sep)
435
+ if len(parts) != 2:
436
+ break
437
+ parts[0] += sep
438
+
439
+ if has_image:
440
+ round_len = len(tokenizer_image_token(rou, tokenizer))
441
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
442
+ else:
443
+ round_len = len(tokenizer(rou).input_ids)
444
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
445
+
446
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
447
+
448
+ cur_len += round_len
449
+ target[cur_len:] = IGNORE_INDEX
450
+
451
+ if cur_len < tokenizer.model_max_length:
452
+ if cur_len != total_len:
453
+ target[:] = IGNORE_INDEX
454
+ print(
455
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
456
+ f" (ignored)"
457
+ )
458
+
459
+ return dict(
460
+ input_ids=input_ids,
461
+ labels=targets,
462
+ )
463
+
464
+
465
+ def preprocess_v1(
466
+ # pyre-fixme[2]: Parameter must be annotated.
467
+ sources,
468
+ tokenizer: transformers.PreTrainedTokenizer,
469
+ has_image: bool = False,
470
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
471
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
472
+ ) -> Dict:
473
+ conv = conversation_lib.default_conversation.copy()
474
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
475
+
476
+ # Apply prompt templates
477
+ conversations = []
478
+ for i, source in enumerate(sources):
479
+ if roles[source[0]["from"]] != conv.roles[0]:
480
+ # Skip the first one if it is not from human
481
+ source = source[1:]
482
+
483
+ conv.messages = []
484
+ for j, sentence in enumerate(source):
485
+ role = roles[sentence["from"]]
486
+ assert role == conv.roles[j % 2], f"{i}"
487
+ conv.append_message(role, sentence["value"])
488
+ conversations.append(conv.get_prompt())
489
+
490
+ # Tokenize conversations
491
+
492
+ if has_image:
493
+ input_ids = torch.stack(
494
+ [
495
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
496
+ for prompt in conversations
497
+ ],
498
+ dim=0,
499
+ )
500
+ else:
501
+ input_ids = tokenizer(
502
+ conversations,
503
+ return_tensors="pt",
504
+ padding="longest",
505
+ max_length=tokenizer.model_max_length,
506
+ truncation=True,
507
+ ).input_ids
508
+
509
+ targets = input_ids.clone()
510
+
511
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
512
+
513
+ # Mask targets
514
+ sep = conv.sep + conv.roles[1] + ": "
515
+ for conversation, target in zip(conversations, targets):
516
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
517
+
518
+ rounds = conversation.split(conv.sep2)
519
+ cur_len = 1
520
+ target[:cur_len] = IGNORE_INDEX
521
+ for i, rou in enumerate(rounds):
522
+ if rou == "":
523
+ break
524
+
525
+ parts = rou.split(sep)
526
+ if len(parts) != 2:
527
+ break
528
+ parts[0] += sep
529
+
530
+ if has_image:
531
+ round_len = len(tokenizer_image_token(rou, tokenizer))
532
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
533
+ else:
534
+ round_len = len(tokenizer(rou).input_ids)
535
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
536
+ # pyre-fixme
537
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
538
+ round_len -= 1
539
+ instruction_len -= 1
540
+
541
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
542
+
543
+ cur_len += round_len
544
+ target[cur_len:] = IGNORE_INDEX
545
+
546
+ if cur_len < tokenizer.model_max_length:
547
+ if cur_len != total_len:
548
+ target[:] = IGNORE_INDEX
549
+ print(
550
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
551
+ f" (ignored)"
552
+ )
553
+
554
+ return dict(
555
+ input_ids=input_ids,
556
+ labels=targets,
557
+ )
558
+
559
+
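+ # Note on preprocess_v1 above: each conversation is split into rounds on conv.sep2, the
+ # instruction portion of every round (everything before the assistant reply) is masked with
+ # IGNORE_INDEX, and the final length check masks the whole sample if the re-tokenized lengths
+ # drift from the original tokenization.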
560
+ # pyre-fixme[3]: Return type must be annotated.
561
+ def tokenizer_image_token(
562
+ # pyre-fixme[2]: Parameter must be annotated.
563
+ prompt,
564
+ # pyre-fixme[2]: Parameter must be annotated.
565
+ tokenizer,
566
+ # pyre-fixme[2]: Parameter must be annotated.
567
+ image_token_index=IMAGE_TOKEN_INDEX,
568
+ # pyre-fixme[2]: Parameter must be annotated.
569
+ return_tensors=None,
570
+ ):
571
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
572
+
573
+ # pyre-fixme[3]: Return type must be annotated.
574
+ # pyre-fixme[2]: Parameter must be annotated.
575
+ def insert_separator(X, sep):
576
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
577
+
578
+ input_ids = []
579
+ offset = 0
580
+ if (
581
+ len(prompt_chunks) > 0
582
+ and len(prompt_chunks[0]) > 0
583
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
584
+ ):
585
+ offset = 1
586
+ input_ids.append(prompt_chunks[0][0])
587
+
588
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
589
+ input_ids.extend(x[offset:])
590
+
591
+ if return_tensors is not None:
592
+ if return_tensors == "pt":
593
+ return torch.tensor(input_ids, dtype=torch.long)
594
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
595
+ return input_ids
596
+
597
+
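+ # Illustrative example: for a prompt like "USER: <image>\nWhat is shown?", the chunks around
+ # "<image>" are tokenized separately and re-joined with image_token_index (IMAGE_TOKEN_INDEX by
+ # default) inserted between them, preserving a single leading bos token if the tokenizer adds one.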
598
+ # pyre-fixme[3]: Return type must be annotated.
599
+ def tokenizer_image_token_llama3(
600
+ # pyre-fixme[2]: Parameter must be annotated.
601
+ prompt,
602
+ # pyre-fixme[2]: Parameter must be annotated.
603
+ tokenizer,
604
+ # pyre-fixme[2]: Parameter must be annotated.
605
+ image_token_index=IMAGE_TOKEN_INDEX,
606
+ # pyre-fixme[2]: Parameter must be annotated.
607
+ return_tensors=None,
608
+ ):
609
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
610
+
611
+ # pyre-fixme[3]: Return type must be annotated.
612
+ # pyre-fixme[2]: Parameter must be annotated.
613
+ def insert_separator(X, sep):
614
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
615
+
616
+ input_ids = []
617
+ for x in insert_separator(prompt_chunks, [image_token_index]):
618
+ input_ids.extend(x)
619
+
620
+ if return_tensors is not None:
621
+ if return_tensors == "pt":
622
+ return torch.tensor(input_ids, dtype=torch.long)
623
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
624
+ return input_ids
625
+
626
+
627
+ def preprocess_qwen(
628
+ # pyre-fixme[2]: Parameter must be annotated.
629
+ sources,
630
+ tokenizer: transformers.PreTrainedTokenizer,
631
+ has_image: bool = False,
632
+ system_message: str = "You are a helpful assistant.",
633
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
634
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
635
+ ) -> Dict:
636
+ # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
637
+ roles = {"human": "user", "gpt": "assistant"}
638
+
639
+ # Add the image token to the tokenizer as a special token
640
+ # Use a deepcopy of the tokenizer so that we don't modify the original
641
+ tokenizer = copy.deepcopy(tokenizer)
642
+ # When there is actually an image, we add the image tokens as a special token
643
+ if has_image:
644
+ tokenizer.add_tokens(["<image>"], special_tokens=True)
645
+
646
+ image_token_index = tokenizer.convert_tokens_to_ids("<image>")
647
+ im_start, im_end = tokenizer.additional_special_tokens_ids
648
+ # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
649
+ unmask_tokens_idx = [198, im_start, im_end]
650
+ nl_tokens = tokenizer("\n").input_ids
651
+
652
+ # Reset the Qwen chat template so that it won't prepend the system message every time we apply it
653
+ chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
654
+ tokenizer.chat_template = chat_template
655
+
656
+ # _system = tokenizer("system").input_ids + nl_tokens
657
+ # _user = tokenizer("user").input_ids + nl_tokens
658
+ # _assistant = tokenizer("assistant").input_ids + nl_tokens
659
+
660
+ # Apply prompt templates
661
+ input_ids, targets = [], []
662
+ for i, source in enumerate(sources):
663
+ if roles[source[0]["from"]] != roles["human"]:
664
+ source = source[1:]
665
+
666
+ input_id, target = [], []
667
+
668
+ # New version, use apply chat template
669
+ # Build system message for each sentence
670
+ input_id += tokenizer.apply_chat_template(
671
+ [{"role": "system", "content": system_message}]
672
+ )
673
+ target += [IGNORE_INDEX] * len(input_id)
674
+
675
+ for conv in source:
676
+ # Make sure llava data can load
677
+ try:
678
+ role = conv["role"]
679
+ content = conv["content"]
680
+ except KeyError:
681
+ role = conv["from"]
682
+ content = conv["value"]
683
+
684
+ role = roles.get(role, role)
685
+
686
+ conv = [{"role": role, "content": content}]
687
+ encode_id = tokenizer.apply_chat_template(conv)
688
+ input_id += encode_id
689
+ if role in ["user", "system"]:
690
+ target += [IGNORE_INDEX] * len(encode_id)
691
+ else:
692
+ target += encode_id
693
+
694
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
695
+ for idx, encode_id in enumerate(input_id):
696
+ if encode_id in unmask_tokens_idx:
697
+ target[idx] = encode_id
698
+ if encode_id == image_token_index:
699
+ input_id[idx] = IMAGE_TOKEN_INDEX
700
+ input_ids.append(input_id)
701
+ targets.append(target)
702
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
703
+ targets = torch.tensor(targets, dtype=torch.long)
704
+
705
+ return dict(
706
+ input_ids=input_ids, # tensor(bs x seq_len)
707
+ labels=targets, # tensor(bs x seq_len)
708
+ )
709
+
710
+
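+ # Note on preprocess_qwen above: "<image>" is registered as a special token only on a deepcopy
+ # of the tokenizer, the chat template is applied per turn, user/system tokens are masked with
+ # IGNORE_INDEX, and any "<image>" ids are swapped back to IMAGE_TOKEN_INDEX afterwards, so the
+ # real tokenizer and the model vocabulary are never resized.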
711
+ def preprocess_llama3(
712
+ # pyre-fixme[2]: Parameter must be annotated.
713
+ sources,
714
+ tokenizer: transformers.PreTrainedTokenizer,
715
+ has_image: bool = False,
716
+ system_message: str = "You are a helpful assistant.",
717
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
718
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
719
+ ) -> Dict:
720
+ # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
721
+ roles = {"human": "user", "gpt": "assistant"}
722
+
723
+ # Add the image token to the tokenizer as a special token
724
+ # Use a deepcopy of the tokenizer so that we don't modify the original
725
+ tokenizer = copy.deepcopy(tokenizer)
726
+ # When there is actually an image, we add the image tokens as a special token
727
+ if has_image:
728
+ tokenizer.add_tokens(["<image>"], special_tokens=True)
729
+ image_token_index = tokenizer.convert_tokens_to_ids("<image>")
730
+ bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
731
+ start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
732
+ end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
733
+ eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
734
+
735
+ unmask_tokens = [
736
+ "<|begin_of_text|>",
737
+ "<|start_header_id|>",
738
+ "<|end_header_id|>",
739
+ "<|eot_id|>",
740
+ "\n\n",
741
+ ]
742
+ unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
743
+
744
+ # After the tokenizers update, calling the llama3 tokenizer will
746
+ # automatically prepend the bos id to the tokens. ヽ(`⌒´)ノ
746
+ # pyre-fixme[53]: Captured variable `bos_token_id` is not annotated.
747
+ # pyre-fixme[3]: Return type must be annotated.
748
+ # pyre-fixme[2]: Parameter must be annotated.
749
+ def safe_tokenizer_llama3(text):
750
+ input_ids = tokenizer(text).input_ids
751
+ if input_ids[0] == bos_token_id:
752
+ input_ids = input_ids[1:]
753
+ return input_ids
754
+
755
+ nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
756
+
757
+ # chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{%- endif %}"
758
+ chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
759
+ tokenizer.chat_template = chat_template
760
+
761
+ # Apply prompt templates
762
+ input_ids, targets = [], []
763
+ for i, source in enumerate(sources):
764
+ if roles[source[0]["from"]] != roles["human"]:
765
+ source = source[1:]
766
+
767
+ input_id, target = [], []
768
+
769
+ # New version, use apply chat template
770
+ # Build system message for each sentence
771
+ input_id += tokenizer.apply_chat_template(
772
+ [{"role": "system", "content": system_message}]
773
+ # pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got `slice`.
774
+ )[:-4]
775
+
776
+ target += [IGNORE_INDEX] * len(input_id)
777
+
778
+ for conv in source:
779
+ # Make sure llava data can load
780
+ try:
781
+ role = conv["role"]
782
+ content = conv["content"]
783
+ except KeyError:
784
+ role = conv["from"]
785
+ content = conv["value"]
786
+
787
+ role = roles.get(role, role)
788
+
789
+ conv = [{"role": role, "content": content}]
790
+ # First is bos token we don't need here
791
+ # pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got
792
+ # `slice`.
793
+ encode_id = tokenizer.apply_chat_template(conv)[1:-4]
794
+ input_id += encode_id
795
+ if role in ["user", "system"]:
796
+ target += [IGNORE_INDEX] * len(encode_id)
797
+ else:
798
+ target += encode_id
799
+
800
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
801
+ for idx, encode_id in enumerate(input_id):
802
+ if encode_id in unmask_tokens_idx:
803
+ target[idx] = encode_id
804
+ if encode_id == image_token_index:
805
+ input_id[idx] = IMAGE_TOKEN_INDEX
806
+ input_ids.append(input_id)
807
+ targets.append(target)
808
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
809
+ targets = torch.tensor(targets, dtype=torch.long)
810
+
811
+ print("input_ids", input_ids, flush=True)
812
+ print("targets", targets, flush=True)
813
+
814
+ return dict(
815
+ input_ids=input_ids, # tensor(bs x seq_len)
816
+ labels=targets, # tensor(bs x seq_len)
817
+ )
818
+
819
+
820
+ def preprocess_llama_3_1(
821
+ # pyre-fixme[2]: Parameter must be annotated.
822
+ sources,
823
+ tokenizer: transformers.PreTrainedTokenizer,
824
+ has_image: bool = False,
825
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
826
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
827
+ ) -> Dict:
828
+ conv = conversation_lib.default_conversation.copy()
829
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
830
+
831
+ # Apply prompt templates
832
+ conversations = []
833
+ for i, source in enumerate(sources):
834
+ if roles[source[0]["from"]] != conv.roles[0]:
835
+ # Skip the first one if it is not from human
836
+ source = source[1:]
837
+
838
+ conv.messages = []
839
+ for j, sentence in enumerate(source):
840
+ if sentence["from"] == "Answer":
841
+ sentence["from"] = "gpt" # data bug
842
+ role = roles[sentence["from"]]
843
+ # assert role == conv.roles[j % 2], f"{i}"
844
+ conv.append_message(role, sentence["value"])
845
+ conversations.append(conv.get_prompt())
846
+
847
+ # Tokenize conversations
848
+
849
+ if has_image:
850
+ input_ids = torch.stack(
851
+ [
852
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
853
+ for prompt in conversations
854
+ ],
855
+ dim=0,
856
+ )
857
+ else:
858
+ input_ids = tokenizer(
859
+ conversations,
860
+ return_tensors="pt",
861
+ padding="longest",
862
+ max_length=tokenizer.model_max_length,
863
+ truncation=True,
864
+ ).input_ids
865
+
866
+ # remove the first bos token
867
+ if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
868
+ input_ids = input_ids[:, 1:]
869
+ targets = input_ids.clone()
870
+
871
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_1
872
+
873
+ # Mask targets
874
+ sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
875
+ # sep = conv.sep + conv.roles[1] + ": "
876
+ for conversation, target in zip(conversations, targets):
877
+ total_len = int(target.shape[0])
878
+
879
+ rounds = conversation.split(conv.tokenizer.eos_token)
880
+ rounds = [rounds[0]] + [
881
+ rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
882
+ ]
883
+
884
+ cur_len = 1
885
+ target[:cur_len] = IGNORE_INDEX
886
+ for i, rou in enumerate(rounds):
887
+ if rou == "":
888
+ break
889
+
890
+ parts = rou.split(sep)
891
+ if len(parts) != 2 and i != 0:
892
+ break
893
+
894
+ if i == 0:
895
+ round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
896
+ instruction_len = len(
897
+ tokenizer(rou, add_special_tokens=False).input_ids
898
+ )
899
+
900
+ else:
901
+ parts[0] += sep
902
+ if has_image:
903
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
904
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
905
+ else:
906
+ round_len = len(tokenizer(rou).input_ids) + 1
907
+ instruction_len = len(tokenizer(parts[0]).input_ids)
908
+
909
+ # if i > 0: round_len += 1
910
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
911
+ cur_len += round_len
912
+
913
+ target[cur_len:] = IGNORE_INDEX
914
+ cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
915
+
916
+ # if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
917
+ if cur_len < tokenizer.model_max_length:
918
+ if cur_len != total_len:
919
+ target[:] = IGNORE_INDEX
920
+ print(
921
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
922
+ f" (ignored)"
923
+ )
924
+
925
+ return dict(
926
+ input_ids=input_ids,
927
+ labels=targets,
928
+ )
929
+
930
+
931
+ def preprocess_llama_3_2(
932
+ # pyre-fixme[2]: Parameter must be annotated.
933
+ sources,
934
+ tokenizer: transformers.PreTrainedTokenizer,
935
+ has_image: bool = False,
936
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
937
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
938
+ ) -> Dict:
939
+ conv = conversation_lib.default_conversation.copy()
940
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
941
+
942
+ # Apply prompt templates
943
+ conversations = []
944
+ for i, source in enumerate(sources):
945
+ if roles[source[0]["from"]] != conv.roles[0]:
946
+ # Skip the first one if it is not from human
947
+ source = source[1:]
948
+
949
+ conv.messages = []
950
+ for j, sentence in enumerate(source):
951
+ role = roles[sentence["from"]]
952
+ assert role == conv.roles[j % 2], f"{i}"
953
+ conv.append_message(role, sentence["value"])
954
+ conversations.append(conv.get_prompt())
955
+
956
+ # Tokenize conversations
957
+
958
+ if has_image:
959
+ input_ids = torch.stack(
960
+ [
961
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
962
+ for prompt in conversations
963
+ ],
964
+ dim=0,
965
+ )
966
+ else:
967
+ input_ids = tokenizer(
968
+ conversations,
969
+ return_tensors="pt",
970
+ padding="longest",
971
+ max_length=tokenizer.model_max_length,
972
+ truncation=True,
973
+ ).input_ids
974
+
975
+ # remove the first bos token
976
+ if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
977
+ input_ids = input_ids[:, 1:]
978
+ targets = input_ids.clone()
979
+
980
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_2
981
+
982
+ # Mask targets
983
+ sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
984
+ # sep = conv.sep + conv.roles[1] + ": "
985
+ for conversation, target in zip(conversations, targets):
986
+ total_len = int(target.shape[0])
987
+
988
+ rounds = conversation.split(conv.tokenizer.eos_token)
989
+ rounds = [rounds[0]] + [
990
+ rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
991
+ ]
992
+
993
+ cur_len = 1
994
+ target[:cur_len] = IGNORE_INDEX
995
+ for i, rou in enumerate(rounds):
996
+ if rou == "":
997
+ break
998
+
999
+ parts = rou.split(sep)
1000
+ if len(parts) != 2 and i != 0:
1001
+ break
1002
+
1003
+ if i == 0:
1004
+ round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
1005
+ instruction_len = len(
1006
+ tokenizer(rou, add_special_tokens=False).input_ids
1007
+ )
1008
+
1009
+ else:
1010
+ parts[0] += sep
1011
+ if has_image:
1012
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
1013
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
1014
+ else:
1015
+ round_len = len(tokenizer(rou).input_ids) + 1
1016
+ instruction_len = len(tokenizer(parts[0]).input_ids)
1017
+
1018
+ # if i > 0: round_len += 1
1019
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1020
+ cur_len += round_len
1021
+
1022
+ target[cur_len:] = IGNORE_INDEX
1023
+ cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
1024
+
1025
+ # if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
1026
+ if cur_len < tokenizer.model_max_length:
1027
+ if cur_len != total_len:
1028
+ target[:] = IGNORE_INDEX
1029
+ print(
1030
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1031
+ f" (ignored)"
1032
+ )
1033
+
1034
+ return dict(
1035
+ input_ids=input_ids,
1036
+ labels=targets,
1037
+ )
1038
+
1039
+
1040
+ def preprocess_phi3(
1041
+ # pyre-fixme[2]: Parameter must be annotated.
1042
+ sources,
1043
+ tokenizer: transformers.PreTrainedTokenizer,
1044
+ has_image: bool = False,
1045
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
1046
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1047
+ ) -> Dict:
1048
+ conv = conversation_lib.conv_templates["phi3"].copy()
1049
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1050
+
1051
+ # Apply prompt templates
1052
+ conversations = []
1053
+ for i, source in enumerate(sources):
1054
+ if roles[source[0]["from"]] != conv.roles[0]:
1055
+ # Skip the first one if it is not from human
1056
+ source = source[1:]
1057
+
1058
+ conv.messages = []
1059
+ for j, sentence in enumerate(source):
1060
+ role = roles[sentence["from"]]
1061
+ assert role == conv.roles[j % 2], f"{i}"
1062
+ conv.append_message(role, sentence["value"])
1063
+ conversations.append(conv.get_prompt())
1064
+
1065
+ # Tokenize conversations
1066
+ if has_image:
1067
+ input_ids = torch.stack(
1068
+ [
1069
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1070
+ for prompt in conversations
1071
+ ],
1072
+ dim=0,
1073
+ )
1074
+ else:
1075
+ input_ids = tokenizer(
1076
+ conversations,
1077
+ return_tensors="pt",
1078
+ padding="longest",
1079
+ max_length=tokenizer.model_max_length,
1080
+ truncation=True,
1081
+ ).input_ids
1082
+
1083
+ targets = input_ids.clone()
1084
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
1085
+
1086
+ # Mask targets
1087
+ sep = conv.sep + conv.roles[1]
1088
+ for conversation, target in zip(conversations, targets):
1089
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
1090
+
1091
+ rounds = conversation.split(conv.sep)
1092
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
1093
+ for conv_idx in range(3, len(rounds), 2):
1094
+ re_rounds.append(
1095
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
1096
+ ) # user + gpt
1097
+ cur_len = 0
1098
+ target[:cur_len] = IGNORE_INDEX
1099
+ for i, rou in enumerate(re_rounds):
1100
+ if rou == "":
1101
+ break
1102
+
1103
+ parts = rou.split(sep)
1104
+ if len(parts) != 2:
1105
+ break
1106
+ parts[0] += sep
1107
+
1108
+ if has_image:
1109
+ round_len = len(tokenizer_image_token(rou, tokenizer))
1110
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
1111
+ else:
1112
+ round_len = len(tokenizer(rou).input_ids)
1113
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
1114
+
1115
+ if i == 0:
1116
+ round_len += 1
1117
+ instruction_len += 1
1118
+ else:
1119
+ round_len -= 2
1120
+ instruction_len -= 2
1121
+
1122
+ if (
1123
+ i != 0
1124
+ and getattr(tokenizer, "legacy", False)
1125
+ and IS_TOKENIZER_GREATER_THAN_0_14
1126
+ ):
1127
+ round_len += 1
1128
+ instruction_len += 1
1129
+
1130
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1131
+
1132
+ cur_len += round_len
1133
+ target[cur_len:] = IGNORE_INDEX
1134
+
1135
+ if cur_len < tokenizer.model_max_length:
1136
+ if cur_len != total_len:
1137
+ target[:] = IGNORE_INDEX
1138
+ print(
1139
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1140
+ f" (ignored)"
1141
+ )
1142
+
1143
+ return dict(
1144
+ input_ids=input_ids,
1145
+ labels=targets,
1146
+ )
1147
+
1148
+
1149
+ def preprocess_mpt(
1150
+ # pyre-fixme[2]: Parameter must be annotated.
1151
+ sources,
1152
+ tokenizer: transformers.PreTrainedTokenizer,
1153
+ has_image: bool = False,
1154
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1155
+ ) -> Dict:
1156
+ conv = conversation_lib.default_conversation.copy()
1157
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1158
+
1159
+ # Apply prompt templates
1160
+ conversations = []
1161
+ for i, source in enumerate(sources):
1162
+ if roles[source[0]["from"]] != conv.roles[0]:
1163
+ # Skip the first one if it is not from human
1164
+ source = source[1:]
1165
+
1166
+ conv.messages = []
1167
+ for j, sentence in enumerate(source):
1168
+ role = roles[sentence["from"]]
1169
+ assert role == conv.roles[j % 2], f"{i}"
1170
+ conv.append_message(role, sentence["value"])
1171
+ conversations.append(conv.get_prompt())
1172
+
1173
+ # Tokenize conversations
1174
+ if has_image:
1175
+ input_ids = torch.stack(
1176
+ [
1177
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1178
+ for prompt in conversations
1179
+ ],
1180
+ dim=0,
1181
+ )
1182
+ else:
1183
+ input_ids = tokenizer(
1184
+ conversations,
1185
+ return_tensors="pt",
1186
+ padding="longest",
1187
+ max_length=tokenizer.model_max_length,
1188
+ truncation=True,
1189
+ ).input_ids
1190
+
1191
+ targets = input_ids.clone()
1192
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
1193
+
1194
+ # Mask targets
1195
+ sep = conv.sep + conv.roles[1]
1196
+ for conversation, target in zip(conversations, targets):
1197
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
1198
+
1199
+ rounds = conversation.split(conv.sep)
1200
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
1201
+ for conv_idx in range(3, len(rounds), 2):
1202
+ re_rounds.append(
1203
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
1204
+ ) # user + gpt
1205
+ cur_len = 0
1206
+ target[:cur_len] = IGNORE_INDEX
1207
+ for i, rou in enumerate(re_rounds):
1208
+ if rou == "":
1209
+ break
1210
+
1211
+ parts = rou.split(sep)
1212
+ if len(parts) != 2:
1213
+ break
1214
+ parts[0] += sep
1215
+ if has_image:
1216
+ round_len = len(tokenizer_image_token(rou, tokenizer))
1217
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
1218
+ else:
1219
+ round_len = len(tokenizer(rou).input_ids)
1220
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
1221
+
1222
+ if (
1223
+ i != 0
1224
+ and getattr(tokenizer, "legacy", False)
1225
+ and IS_TOKENIZER_GREATER_THAN_0_14
1226
+ ):
1227
+ round_len += 1
1228
+ instruction_len += 1
1229
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1230
+
1231
+ cur_len += round_len
1232
+ target[cur_len:] = IGNORE_INDEX
1233
+
1234
+ if cur_len < tokenizer.model_max_length:
1235
+ if cur_len != total_len:
1236
+ target[:] = IGNORE_INDEX
1237
+ print(
1238
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1239
+ f" (ignored)"
1240
+ )
1241
+
1242
+ return dict(
1243
+ input_ids=input_ids,
1244
+ labels=targets,
1245
+ )
1246
+
1247
+
1248
+ def preprocess_plain(
1249
+ sources: Sequence[str],
1250
+ tokenizer: transformers.PreTrainedTokenizer,
1251
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
1252
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1253
+ ) -> Dict:
1254
+ # add end signal and concatenate together
1255
+ conversations = []
1256
+ for source in sources:
1257
+ assert len(source) == 2
1258
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
1259
+ # got `str`.
1260
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
1261
+ # pyre-fixme[16]: `str` has no attribute `__setitem__`.
1262
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
1263
+ conversation = (
1264
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
1265
+ # but got `str`.
1266
+ source[0]["value"]
1267
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
1268
+ # but got `str`.
1269
+ + source[1]["value"]
1270
+ + conversation_lib.default_conversation.sep
1271
+ )
1272
+ conversations.append(conversation)
1273
+ # tokenize conversations
1274
+ input_ids = [
1275
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1276
+ for prompt in conversations
1277
+ ]
1278
+ targets = copy.deepcopy(input_ids)
1279
+ for target, source in zip(targets, sources):
1280
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
1281
+ # got `str`.
1282
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
1283
+ target[:tokenized_len] = IGNORE_INDEX
1284
+
1285
+ return dict(input_ids=input_ids, labels=targets)
1286
+
1287
+
1288
+ def preprocess(
1289
+ sources: Sequence[str],
1290
+ tokenizer: transformers.PreTrainedTokenizer,
1291
+ has_image: bool = False,
1292
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
1293
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1294
+ ) -> Dict:
1295
+ """
1296
+ Given a list of sources, each is a conversation list. This transform:
1297
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
1298
+ 2. Concatenate conversations together;
1299
+ 3. Tokenize the concatenated conversation;
1300
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
1301
+ """
1302
+ if (
1303
+ conversation_lib.default_conversation.sep_style
1304
+ == conversation_lib.SeparatorStyle.PLAIN
1305
+ ):
1306
+ return preprocess_plain(sources, tokenizer)
1307
+ if (
1308
+ conversation_lib.default_conversation.sep_style
1309
+ == conversation_lib.SeparatorStyle.LLAMA_2
1310
+ ):
1311
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
1312
+ if conversation_lib.default_conversation.version.startswith("v1"):
1313
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
1314
+ if conversation_lib.default_conversation.version == "mpt":
1315
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
1316
+ if conversation_lib.default_conversation.version == "llama3":
1317
+ return preprocess_llama3(sources, tokenizer, has_image=has_image)
1318
+ if conversation_lib.default_conversation.version == "llama3_1":
1319
+ return preprocess_llama_3_1(sources, tokenizer, has_image=has_image)
1320
+ if conversation_lib.default_conversation.version == "llama3_2":
1321
+ return preprocess_llama_3_2(sources, tokenizer, has_image=has_image)
1322
+ if conversation_lib.default_conversation.version == "phi3":
1323
+ return preprocess_phi3(sources, tokenizer, has_image=has_image)
1324
+ if conversation_lib.default_conversation.version == "qwen":
1325
+ return preprocess_qwen(sources, tokenizer, has_image=has_image)
1326
+ # add end signal and concatenate together
1327
+ conversations = []
1328
+ for source in sources:
1329
+ header = f"{conversation_lib.default_conversation.system}\n\n"
1330
+ conversation = _add_speaker_and_signal(header, source)
1331
+ conversations.append(conversation)
1332
+
1333
+ # tokenize conversations
1334
+ # pyre-fixme[3]: Return type must be annotated.
1335
+ # pyre-fixme[2]: Parameter must be annotated.
1336
+ def get_tokenize_len(prompts):
1337
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
1338
+
1339
+ if has_image:
1340
+ input_ids = [
1341
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1342
+ for prompt in conversations
1343
+ ]
1344
+ else:
1345
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
1346
+ input_ids = conversations_tokenized["input_ids"]
1347
+
1348
+ targets = copy.deepcopy(input_ids)
1349
+ for target, source in zip(targets, sources):
1350
+ if has_image:
1351
+ # pyre-fixme[61]: `header` is undefined, or not always defined.
1352
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
1353
+ # but got `str`.
1354
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
1355
+ else:
1356
+ tokenized_lens = _tokenize_fn(
1357
+ # pyre-fixme[61]: `header` is undefined, or not always defined.
1358
+ # pyre-fixme[6]: For 1st argument expected `Union[slice,
1359
+ # SupportsIndex]` but got `str`.
1360
+ [header] + [s["value"] for s in source],
1361
+ tokenizer,
1362
+ )["input_ids_lens"]
1363
+ # pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
1364
+ # got `str`.
1365
+ speakers = [sentence["from"] for sentence in source]
1366
+ _mask_targets(target, tokenized_lens, speakers)
1367
+
1368
+ return dict(input_ids=input_ids, labels=targets)
1369
+
1370
+
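+ # Note on preprocess above: the dispatch is driven by conversation_lib.default_conversation
+ # (its sep_style and version), falling back to the legacy "### "-separated format built by
+ # _add_speaker_and_signal when no template matches.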
1371
+ class LazySupervisedDataset(Dataset):
1372
+ """Dataset for supervised fine-tuning."""
1373
+
1374
+ def __init__(
1375
+ self,
1376
+ data_path: str,
1377
+ tokenizer: transformers.PreTrainedTokenizer,
1378
+ # pyre-fixme[2]: Parameter must be annotated.
1379
+ data_args,
1380
+ ) -> None:
1381
+ super(LazySupervisedDataset, self).__init__()
1382
+ list_data_dict = json.load(open(data_path, "r"))
1383
+
1384
+ self.tokenizer = tokenizer
1385
+ # pyre-fixme[4]: Attribute must be annotated.
1386
+ self.list_data_dict = list_data_dict
1387
+ # pyre-fixme[4]: Attribute must be annotated.
1388
+ self.data_args = data_args
1389
+
1390
+ @property
1391
+ # pyre-fixme[3]: Return type must be annotated.
1392
+ def lengths(self):
1393
+ length_list = []
1394
+ for sample in self.list_data_dict:
1395
+ img_tokens = 128 if "image" in sample else 0
1396
+ length_list.append(
1397
+ sum(len(conv["value"].split()) for conv in sample["conversations"])
1398
+ + img_tokens
1399
+ )
1400
+ return length_list
1401
+
1402
+ @property
1403
+ def modality_lengths(self) -> List[int]:
1404
+ length_list = []
1405
+ for sample in self.list_data_dict:
1406
+ cur_len = sum(
1407
+ len(conv["value"].split()) for conv in sample["conversations"]
1408
+ )
1409
+ cur_len = (
1410
+ cur_len if ("image" in sample) or ("video" in sample) else -cur_len
1411
+ )
1412
+ length_list.append(cur_len)
1413
+ return length_list
1414
+
1415
+ def __len__(self) -> int:
1416
+ return len(self.list_data_dict)
1417
+
1418
+ def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
1419
+ sources = self.list_data_dict[i]
1420
+ if isinstance(i, int):
1421
+ sources = [sources]
1422
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
1423
+ has_image = True
1424
+ if "image" in sources[0]:
1425
+ image_file = self.list_data_dict[i]["image"]
1426
+ image_folder = self.data_args.image_folder
1427
+ processor = self.data_args.image_processor
1428
+ full_path = os.path.join(image_folder, image_file)
1429
+ if not os.path.exists(full_path):
1430
+ print(full_path)
1431
+ has_image = False
1432
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1433
+ else:
1434
+ image = Image.open(full_path).convert("RGB")
1435
+ if self.data_args.image_aspect_ratio == "sam":
1436
+ image = np.array(image)[:, :, ::-1]
1437
+ if self.data_args.image_aspect_ratio == "pad":
1438
+ # pyre-fixme[3]: Return type must be annotated.
1439
+ # pyre-fixme[2]: Parameter must be annotated.
1440
+ def expand2square(pil_img, background_color):
1441
+ width, height = pil_img.size
1442
+ if width == height:
1443
+ return pil_img
1444
+ elif width > height:
1445
+ result = Image.new(
1446
+ pil_img.mode, (width, width), background_color
1447
+ )
1448
+ result.paste(pil_img, (0, (width - height) // 2))
1449
+ return result
1450
+ else:
1451
+ result = Image.new(
1452
+ pil_img.mode, (height, height), background_color
1453
+ )
1454
+ result.paste(pil_img, ((height - width) // 2, 0))
1455
+ return result
1456
+
1457
+ image = expand2square(
1458
+ image, tuple(int(x * 255) for x in processor.image_mean)
1459
+ )
1460
+ image = processor.preprocess(image, return_tensors="pt")[
1461
+ "pixel_values"
1462
+ ][0]
1463
+ else:
1464
+ if self.data_args.image_aspect_ratio != "sam":
1465
+ image = processor.preprocess(image, return_tensors="pt")[
1466
+ "pixel_values"
1467
+ ][0]
1468
+ sources = preprocess_multimodal(
1469
+ copy.deepcopy([e["conversations"] for e in sources]), self.data_args
1470
+ )
1471
+ elif "video" in sources[0]:
1472
+ video_file = self.list_data_dict[i]["video"]
1473
+ video_folder = self.data_args.image_folder
1474
+ if "webvid" in video_folder:
1475
+ video_file = os.path.join(video_folder, "videos", video_file)
1476
+ elif "ActivityNet" in video_folder:
1477
+ video_file = os.path.join(video_folder, "train_val", video_file)
1478
+ else:
1479
+ video_file = os.path.join(video_folder, video_file)
1480
+ if not os.path.exists(video_file):
1481
+ print("nonexist: {}".format(video_file), flush=True)
1482
+ for sub_folder in os.listdir(video_folder):
1483
+ if os.path.isdir(os.path.join(video_folder, sub_folder)):
1484
+ for sub_sub_folder in os.listdir(
1485
+ os.path.join(video_folder, sub_folder)
1486
+ ):
1487
+ print("folder", sub_folder, sub_sub_folder)
1488
+ has_image = False
1489
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1490
+ else:
1491
+ if video_file.endswith(".webm"):
1492
+ has_image = False
1493
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1494
+ else:
1495
+ try:
1496
+ # if video_file.endswith(".webm"):
1497
+ # video_webm = VideoFileClip(video_file)
1498
+ # video_frames = np.array(list(video_webm.iter_frames()))
1499
+ # sample_fps = round(video_webm.fps / self.data_args.video_fps)
1500
+ # frame_idx = [i for i in range(0, len(video_frames), sample_fps)]
1501
+ # video = video_frames[frame_idx]
1502
+ # else:
1503
+ vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
1504
+ sample_fps = round(vr.get_avg_fps() / self.data_args.video_fps)
1505
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
1506
+ video = vr.get_batch(frame_idx).asnumpy()
1507
+ if self.data_args.image_aspect_ratio == "sam":
1508
+ image = video[:, :, :, ::-1][:100]
1509
+ else:
1510
+ processor = self.data_args.image_processor
1511
+ image = processor.preprocess(video, return_tensors="pt")[
1512
+ "pixel_values"
1513
+ ]
1514
+ sources = preprocess_multimodal(
1515
+ copy.deepcopy([e["conversations"] for e in sources]),
1516
+ self.data_args,
1517
+ )
1518
+ except Exception:
1519
+ has_image = False
1520
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1521
+ else:
1522
+ has_image = False
1523
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1524
+ data_dict = preprocess(
1525
+ # pyre-fixme[6]: For 1st argument expected `Sequence[str]` but got
1526
+ # `Union[Dict[typing.Any, typing.Any], List[typing.Any]]`.
1527
+ sources,
1528
+ self.tokenizer,
1529
+ has_image=has_image,
1530
+ )
1531
+ if isinstance(i, int):
1532
+ data_dict = dict(
1533
+ input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
1534
+ )
1535
+
1536
+ # image exist in the data
1537
+ if has_image:
1538
+ if "image" in self.list_data_dict[i]:
1539
+ # pyre-fixme[61]: Local variable `image` is undefined, or not always defined.
1540
+ data_dict["image"] = image
1541
+ elif "video" in self.list_data_dict[i]:
1542
+ # pyre-fixme[61]: Local variable `image` is undefined, or not always defined.
1543
+ data_dict["image"] = image
1544
+ elif self.data_args.is_multimodal:
1545
+ # image does not exist in the data, but the model is multimodal
1546
+ # crop_size = self.data_args.image_processor.crop_size
1547
+ # data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
1548
+ if self.data_args.image_aspect_ratio == "sam":
1549
+ if "video" in self.list_data_dict[i]:
1550
+ data_dict["image"] = np.zeros((1, 1024, 1024, 3)).astype(np.uint8)
1551
+ else:
1552
+ data_dict["image"] = np.zeros((1024, 1024, 3)).astype(np.uint8)
1553
+ else:
1554
+ crop_size = self.data_args.image_processor.crop_size
1555
+ if "video" in self.list_data_dict[i]:
1556
+ data_dict["image"] = torch.zeros(
1557
+ 1, 3, crop_size["height"], crop_size["width"]
1558
+ )
1559
+ else:
1560
+ data_dict["image"] = torch.zeros(
1561
+ 3, crop_size["height"], crop_size["width"]
1562
+ )
1563
+
1564
+ if has_image:
1565
+ if self.data_args.num_points > 0:
1566
+ if "box" in self.list_data_dict[i]:
1567
+ x1, y1, x2, y2 = self.list_data_dict[i]["box"]
1568
+ points = []
1569
+ x = random.uniform(x1, x2)
1570
+ y = random.uniform(y1, y2)
1571
+ points.append(torch.tensor([x, y, 1]))
1572
+ for _ in range(1, self.data_args.num_points):
1573
+ points.append(torch.tensor([0, 0, 0]))
1574
+ points = torch.stack(points, dim=0)
1575
+ data_dict["point"] = points
1576
+ else:
1577
+ if "point" in self.list_data_dict[i]:
1578
+ points = torch.tensor(self.list_data_dict[i]["point"])
1579
+ data_dict["point"] = points
1580
+ else:
1581
+ points = []
1582
+ grid = int(np.sqrt(self.data_args.num_points))
1583
+ height, width = image.shape[0], image.shape[1]
1584
+ for i in range(grid):
1585
+ for j in range(grid):
1586
+ points.append(
1587
+ torch.tensor(
1588
+ [
1589
+ width / grid / 2.0 + i / grid * width,
1590
+ height / grid / 2.0 + j / grid * height,
1591
+ 1,
1592
+ ]
1593
+ )
1594
+ )
1595
+ points = torch.stack(points, dim=0)
1596
+ data_dict["point"] = points
1597
+ elif self.data_args.is_multimodal:
1598
+ if self.data_args.num_points > 0:
1599
+ points = []
1600
+ grid = int(np.sqrt(self.data_args.num_points))
1601
+ height, width = data_dict["image"].shape[0], data_dict["image"].shape[1]
1602
+ for i in range(grid):
1603
+ for j in range(grid):
1604
+ points.append(
1605
+ torch.tensor(
1606
+ [
1607
+ width / grid / 2.0 + i / grid * width,
1608
+ height / grid / 2.0 + j / grid * height,
1609
+ 1,
1610
+ ]
1611
+ )
1612
+ )
1613
+ points = torch.stack(points, dim=0)
1614
+ data_dict["point"] = points
1615
+
1616
+ return data_dict
1617
+
1618
+
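+ # Note on __getitem__ above: videos are decoded with decord and subsampled to roughly
+ # data_args.video_fps frames per second (every round(avg_fps / video_fps)-th frame); unreadable
+ # or .webm files fall back to a text-only sample, and missing modalities get zero-filled image
+ # tensors (or uint8 arrays in the "sam" branch) plus an optional grid of point prompts.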
1619
+ @dataclass
1620
+ class DataCollatorForSupervisedDataset(object):
1621
+ """Collate examples for supervised fine-tuning."""
1622
+
1623
+ tokenizer: transformers.PreTrainedTokenizer
1624
+
1625
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
1626
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1627
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
1628
+ input_ids, labels = tuple(
1629
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
1630
+ )
1631
+ input_ids = torch.nn.utils.rnn.pad_sequence(
1632
+ input_ids,
1633
+ batch_first=True,
1634
+ # pyre-fixme[6]: For 3rd argument expected `float` but got `Optional[int]`.
1635
+ padding_value=self.tokenizer.pad_token_id,
1636
+ )
1637
+ labels = torch.nn.utils.rnn.pad_sequence(
1638
+ labels, batch_first=True, padding_value=IGNORE_INDEX
1639
+ )
1640
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
1641
+ labels = labels[:, : self.tokenizer.model_max_length]
1642
+ batch = dict(
1643
+ input_ids=input_ids,
1644
+ labels=labels,
1645
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[int]`.
1646
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
1647
+ )
1648
+
1649
+ # if "image" in instances[0]:
1650
+ # images = [instance["image"] for instance in instances]
1651
+ # if all(x is not None and x.shape == images[0].shape for x in images):
1652
+ # if type(images[0]) is torch.Tensor:
1653
+ # batch["images"] = torch.stack(images)
1654
+ # else:
1655
+ #
1656
+ # batch["images"] = np.stack(images)
1657
+ # else:
1658
+ #
1659
+ # # `List[typing.Any]`.
1660
+ # batch["images"] = images
1661
+
1662
+ if "image" in instances[0]:
1663
+ images = [instance["image"] for instance in instances]
1664
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `List[typing.Any]`.
1665
+ batch["images"] = images
1666
+
1667
+ if "point" in instances[0]:
1668
+ points = [instance["point"] for instance in instances]
1669
+ batch["points"] = torch.stack(points)
1670
+
1671
+ return batch
1672
+
1673
+
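+ # Note on the collator above: input_ids are right-padded with pad_token_id and labels with
+ # IGNORE_INDEX, both are truncated to model_max_length, and images are passed through as a
+ # plain list because per-sample shapes may differ (e.g. SAM-style numpy frames vs. CLIP-style
+ # tensors).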
1674
+ def make_supervised_data_module(
1675
+ tokenizer: transformers.PreTrainedTokenizer,
1676
+ # pyre-fixme[2]: Parameter must be annotated.
1677
+ data_args,
1678
+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
1679
+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
1680
+ ) -> Dict:
1681
+ """Make dataset and collator for supervised fine-tuning."""
1682
+ train_dataset = LazySupervisedDataset(
1683
+ tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args
1684
+ )
1685
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
1686
+ return dict(
1687
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
1688
+ )
longvu/mm_utils.py ADDED
@@ -0,0 +1,327 @@
1
+ import ast
2
+ import base64
3
+ import math
4
+ from io import BytesIO
5
+
6
+ import torch
7
+ from longvu.constants import IMAGE_TOKEN_INDEX
8
+ from PIL import Image
9
+
10
+ from transformers import StoppingCriteria
11
+
12
+
13
+ def select_best_resolution(original_size, possible_resolutions):
14
+ """
15
+ Selects the best resolution from a list of possible resolutions based on the original size.
16
+
17
+ Args:
18
+ original_size (tuple): The original size of the image in the format (width, height).
19
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
20
+
21
+ Returns:
22
+ tuple: The best fit resolution in the format (width, height).
23
+ """
24
+ original_width, original_height = original_size
25
+ best_fit = None
26
+ max_effective_resolution = 0
27
+ min_wasted_resolution = float("inf")
28
+
29
+ for width, height in possible_resolutions:
30
+ scale = min(width / original_width, height / original_height)
31
+ downscaled_width, downscaled_height = int(original_width * scale), int(
32
+ original_height * scale
33
+ )
34
+ effective_resolution = min(
35
+ downscaled_width * downscaled_height, original_width * original_height
36
+ )
37
+ wasted_resolution = (width * height) - effective_resolution
38
+
39
+ if effective_resolution > max_effective_resolution or (
40
+ effective_resolution == max_effective_resolution
41
+ and wasted_resolution < min_wasted_resolution
42
+ ):
43
+ max_effective_resolution = effective_resolution
44
+ min_wasted_resolution = wasted_resolution
45
+ best_fit = (width, height)
46
+
47
+ return best_fit
48
+
49
+
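+ # Worked example (hypothetical candidate list): select_best_resolution((1000, 500),
+ # [(672, 672), (1008, 336)]) returns (1008, 336); both candidates give the same effective
+ # resolution of 672x336 = 225792 px, but (1008, 336) wastes 112896 px versus 225792 px.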
50
+ def resize_and_pad_image(image, target_resolution):
51
+ """
52
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
53
+
54
+ Args:
55
+ image (PIL.Image.Image): The input image.
56
+ target_resolution (tuple): The target resolution (width, height) of the image.
57
+
58
+ Returns:
59
+ PIL.Image.Image: The resized and padded image.
60
+ """
61
+ original_width, original_height = image.size
62
+ target_width, target_height = target_resolution
63
+
64
+ scale_w = target_width / original_width
65
+ scale_h = target_height / original_height
66
+
67
+ if scale_w < scale_h:
68
+ new_width = target_width
69
+ new_height = min(math.ceil(original_height * scale_w), target_height)
70
+ else:
71
+ new_height = target_height
72
+ new_width = min(math.ceil(original_width * scale_h), target_width)
73
+
74
+ # Resize the image
75
+ resized_image = image.resize((new_width, new_height))
76
+
77
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
78
+ paste_x = (target_width - new_width) // 2
79
+ paste_y = (target_height - new_height) // 2
80
+ new_image.paste(resized_image, (paste_x, paste_y))
81
+
82
+ return new_image
83
+
84
+
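+ # Continuing the example above: resize_and_pad_image(img_1000x500, (1008, 336)) scales the image
+ # by 0.672 to 672x336 and pastes it onto a black 1008x336 canvas with 168 px of padding on the
+ # left and right.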
85
+ def divide_to_patches(image, patch_size):
86
+ """
87
+ Divides an image into patches of a specified size.
88
+
89
+ Args:
90
+ image (PIL.Image.Image): The input image.
91
+ patch_size (int): The size of each patch.
92
+
93
+ Returns:
94
+ list: A list of PIL.Image.Image objects representing the patches.
95
+ """
96
+ patches = []
97
+ width, height = image.size
98
+ for i in range(0, height, patch_size):
99
+ for j in range(0, width, patch_size):
100
+ box = (j, i, j + patch_size, i + patch_size)
101
+ patch = image.crop(box)
102
+ patches.append(patch)
103
+
104
+ return patches
105
+
106
+
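+ # Continuing the example: divide_to_patches on a 1008x336 image with patch_size=336 walks the
+ # image row by row and yields 3 patches of 336x336, left to right.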
107
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
108
+ """
109
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
110
+
111
+ Args:
112
+ image_size (tuple): The size of the input image in the format (width, height).
113
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
114
+ patch_size (int): The size of each image patch.
115
+
116
+ Returns:
117
+ tuple: The shape of the image patch grid in the format (width, height).
118
+ """
119
+ if type(grid_pinpoints) is list:
120
+ possible_resolutions = grid_pinpoints
121
+ else:
122
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
123
+ width, height = select_best_resolution(image_size, possible_resolutions)
124
+ return width // patch_size, height // patch_size
125
+
126
+
127
+ def process_anyres_image(image, processor, grid_pinpoints):
128
+ """
129
+ Process an image with variable resolutions.
130
+
131
+ Args:
132
+ image (PIL.Image.Image): The input image to be processed.
133
+ processor: The image processor object.
134
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
135
+
136
+ Returns:
137
+ torch.Tensor: A tensor containing the processed image patches.
138
+ """
139
+ if type(grid_pinpoints) is list:
140
+ possible_resolutions = grid_pinpoints
141
+ else:
142
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
143
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
144
+ image_padded = resize_and_pad_image(image, best_resolution)
145
+
146
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
147
+
148
+ image_original_resize = image.resize(
149
+ (processor.size["shortest_edge"], processor.size["shortest_edge"])
150
+ )
151
+
152
+ image_patches = [image_original_resize] + patches
153
+ image_patches = [
154
+ processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
155
+ for image_patch in image_patches
156
+ ]
157
+ return torch.stack(image_patches, dim=0)
158
+
159
+
160
+ def load_image_from_base64(image):
161
+ return Image.open(BytesIO(base64.b64decode(image)))
162
+
163
+
164
+ def expand2square(pil_img, background_color):
165
+ width, height = pil_img.size
166
+ if width == height:
167
+ return pil_img
168
+ elif width > height:
169
+ result = Image.new(pil_img.mode, (width, width), background_color)
170
+ result.paste(pil_img, (0, (width - height) // 2))
171
+ return result
172
+ else:
173
+ result = Image.new(pil_img.mode, (height, height), background_color)
174
+ result.paste(pil_img, ((height - width) // 2, 0))
175
+ return result
176
+
177
+
178
+ # def process_images(images, image_processor, model_cfg):
179
+ # image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
180
+ # new_images = []
181
+ # if image_aspect_ratio == 'pad':
182
+ # for image in images:
183
+ # image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
184
+ # image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
185
+ # new_images.append(image)
186
+ # elif image_aspect_ratio == "anyres":
187
+ # for image in images:
188
+ # image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
189
+ # new_images.append(image)
190
+ # else:
191
+ # return image_processor(images, return_tensors='pt')['pixel_values']
192
+ # if all(x.shape == new_images[0].shape for x in new_images):
193
+ # new_images = torch.stack(new_images, dim=0)
194
+ # return new_images
195
+
196
+
197
+ # multiple vision towers
198
+ def process_images(images, image_processor, model_cfg):
199
+ processor_aux_list = image_processor
200
+ new_images_aux_list = []
201
+ for image in images:
202
+ image_aux_list = []
203
+ for processor_aux in processor_aux_list:
204
+ image_aux = image
205
+ if hasattr(processor_aux, "image_mean"):
206
+ try:
207
+ target_resolution = processor_aux.crop_size["height"]
208
+ except:
209
+ target_resolution = processor_aux.size["height"]
210
+ image_aux = expand2square(
211
+ image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
212
+ ).resize((target_resolution, target_resolution))
213
+ image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
214
+ "pixel_values"
215
+ ][0]
216
+ image_aux_list.append(image_aux)
217
+ new_images_aux_list.append(image_aux_list)
218
+ new_images_aux_list = [
219
+ list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
220
+ ]
221
+ new_images_aux_list = [
222
+ torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
223
+ ]
224
+ return new_images_aux_list
225
+
226
+
227
+ def tokenizer_image_token(
228
+ prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
229
+ ):
230
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
231
+
232
+ def insert_separator(X, sep):
233
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
234
+
235
+ input_ids = []
236
+ offset = 0
237
+ if (
238
+ len(prompt_chunks) > 0
239
+ and len(prompt_chunks[0]) > 0
240
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
241
+ ):
242
+ offset = 1
243
+ input_ids.append(prompt_chunks[0][0])
244
+
245
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
246
+ input_ids.extend(x[offset:])
247
+
248
+ if return_tensors is not None:
249
+ if return_tensors == "pt":
250
+ return torch.tensor(input_ids, dtype=torch.long)
251
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
252
+ return input_ids
253
+
254
+
255
+ def tokenizer_image_token_llama3(
256
+ prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
257
+ ):
258
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
259
+
260
+ def insert_separator(X, sep):
261
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
262
+
263
+ input_ids = []
264
+ for x in insert_separator(prompt_chunks, [image_token_index]):
265
+ input_ids.extend(x)
266
+
267
+ if return_tensors is not None:
268
+ if return_tensors == "pt":
269
+ return torch.tensor(input_ids, dtype=torch.long)
270
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
271
+ return input_ids
272
+
273
+
274
+ def get_model_name_from_path(model_path):
275
+ model_path = model_path.strip("/")
276
+ model_paths = model_path.split("/")
277
+ if model_paths[-1].startswith("checkpoint-"):
278
+ return model_paths[-2] + "_" + model_paths[-1]
279
+ else:
280
+ return model_paths[-1]
281
+
282
+
283
+ class KeywordsStoppingCriteria(StoppingCriteria):
284
+ def __init__(self, keywords, tokenizer, input_ids):
285
+ self.keywords = keywords
286
+ self.keyword_ids = []
287
+ self.max_keyword_len = 0
288
+ for keyword in keywords:
289
+ cur_keyword_ids = tokenizer(keyword).input_ids
290
+ if (
291
+ len(cur_keyword_ids) > 1
292
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
293
+ ):
294
+ cur_keyword_ids = cur_keyword_ids[1:]
295
+ if len(cur_keyword_ids) > self.max_keyword_len:
296
+ self.max_keyword_len = len(cur_keyword_ids)
297
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
298
+ self.tokenizer = tokenizer
299
+ self.start_len = input_ids.shape[1]
300
+
301
+ def call_for_batch(
302
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
303
+ ) -> bool:
304
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
305
+ self.keyword_ids = [
306
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
307
+ ]
308
+ for keyword_id in self.keyword_ids:
309
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0] :]
310
+ if torch.equal(truncated_output_ids, keyword_id):
311
+ return True
312
+ outputs = self.tokenizer.batch_decode(
313
+ output_ids[:, -offset:], skip_special_tokens=True
314
+ )[0]
315
+ for keyword in self.keywords:
316
+ if keyword in outputs:
317
+ return True
318
+ return False
319
+
320
+ def __call__(
321
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
322
+ ) -> bool:
323
+ outputs = []
324
+ for i in range(output_ids.shape[0]):
325
+ # pyre-fixme[6]: For 1st argument expected `LongTensor` but got `Tensor`.
326
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
327
+ return all(outputs)
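The helpers above cover the full "anyres" preprocessing path: pick the best candidate resolution, letterbox the frame to it, then split it into square patches. A minimal, self-contained sketch of how they compose (the candidate grid and the 384px patch size below are made up for illustration; in the real pipeline they come from the model config and the image processor):

    from PIL import Image
    from longvu.mm_utils import (
        select_best_resolution,
        resize_and_pad_image,
        divide_to_patches,
    )

    img = Image.new("RGB", (1000, 600), (30, 90, 200))   # dummy 1000x600 frame
    grid = [(768, 384), (384, 768), (768, 768)]          # hypothetical pinpoints
    best = select_best_resolution(img.size, grid)        # (768, 768) for this size
    padded = resize_and_pad_image(img, best)             # letterboxed to 768x768
    patches = divide_to_patches(padded, 384)             # 4 square 384x384 patches
    print(best, len(patches))                            # (768, 768) 4
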
longvu/multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc ADDED
Binary file (4.33 kB). View file
 
longvu/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1 kB). View file
 
longvu/multimodal_encoder/__pycache__/dino_encoder.cpython-310.pyc ADDED
Binary file (3.67 kB). View file
 
longvu/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc ADDED
Binary file (2.6 kB). View file
 
longvu/multimodal_encoder/base_encoder.py ADDED
@@ -0,0 +1,135 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ProcessorWrapper:
8
+ def __init__(
9
+ self,
10
+ transform,
11
+ height=378,
12
+ width=378,
13
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
14
+ ):
15
+ self._crop_size = {
16
+ "height": height,
17
+ "width": width,
18
+ }
19
+ self._transforms = transform
20
+ # print(transform)
21
+ self.image_mean = image_mean
22
+
23
+ @property
24
+ def crop_size(self):
25
+ return self._crop_size
26
+
27
+ def preprocess(self, image, return_tensors="pt"):
28
+ # Ensure image is a PIL Image
29
+ output = {}
30
+ output["pixel_values"] = [self._transforms(image)]
31
+ return output
32
+
33
+
34
+ class BaseVisionTower(nn.Module):
35
+ def __init__(self, vision_tower_name, args, delay_load=False):
36
+ super().__init__()
37
+
38
+ self.is_loaded = False
39
+ self.args = args
40
+
41
+ self.vision_tower_name = vision_tower_name
42
+ self.select_layer = args.mm_vision_select_layer
43
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
44
+ self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
45
+ self.delay_load = delay_load
46
+
47
+ @abstractmethod
48
+ def load_model(self, device_map=None):
49
+ raise NotImplementedError("Subclasses must implement load_model")
50
+
51
+ @abstractmethod
52
+ def _forward(self, images):
53
+ raise NotImplementedError("Subclasses must implement forward")
54
+
55
+ def forward(self, images):
56
+ if type(images) is list:
57
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
58
+ else:
59
+ image_features = self._forward(images)
60
+
61
+ return image_features
62
+
63
+ @property
64
+ def dummy_feature(self):
65
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
66
+
67
+ @property
68
+ def dtype(self):
69
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
70
+ if hasattr(self.vision_tower, "dtype"):
71
+ return self.vision_tower.dtype
72
+ else:
73
+ params = list(self.vision_tower.parameters())
74
+ return (
75
+ params[0].dtype if len(params) > 0 else torch.float32
76
+ ) # Default to torch.float32 if no parameters
77
+
78
+ @property
79
+ def device(self):
80
+ # Dynamically infer the device from the first parameter, if not explicitly specified
81
+ if hasattr(self.vision_tower, "device"):
82
+ return self.vision_tower.device
83
+ else:
84
+ params = list(self.vision_tower.parameters())
85
+ return (
86
+ params[0].device if len(params) > 0 else torch.device("cpu")
87
+ ) # Default to CPU if no parameters
88
+
89
+ @property
90
+ def config(self):
91
+ if self.is_loaded:
92
+ return self.vision_tower.config
93
+ else:
94
+ return self.cfg_only
95
+
96
+ @property
97
+ def hidden_size(self):
98
+ try:
99
+ return self.config.hidden_size
100
+ except:
101
+ return self._hidden_size
102
+
103
+ @property
104
+ def image_size(self): # resolution
105
+ # return self.config.image_size
106
+ try:
107
+ return self.config.image_size
108
+ except:
109
+ return self._image_size
110
+
111
+ @property
112
+ def patch_size(self):
113
+ # return self.config.patch_size
114
+ try:
115
+ return self.config.patch_size
116
+ except:
117
+ return self._patch_size
118
+
119
+ @property
120
+ def num_patches_per_side(self):
121
+ if self._interp_size is not None:
122
+ return int(self._interp_size**0.5)
123
+ try:
124
+ return self.image_size // self.patch_size
125
+ except:
126
+ return self._num_patches_per_side
127
+
128
+ @property
129
+ def num_patches(self):
130
+ if self._interp_size is not None:
131
+ return self._interp_size
132
+ try:
133
+ return self.num_patches_per_side**2
134
+ except:
135
+ return self._num_patches
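ProcessorWrapper is what lets a plain torchvision transform stand in for a Hugging Face image processor: it exposes crop_size plus a preprocess() that returns pixel_values. A small usage sketch; the normalization values are illustrative only, not taken from this repo:

    from PIL import Image
    from torchvision import transforms
    from longvu.multimodal_encoder.base_encoder import ProcessorWrapper

    tfm = transforms.Compose([
        transforms.Resize((378, 378)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])
    proc = ProcessorWrapper(tfm, height=378, width=378)
    out = proc.preprocess(Image.new("RGB", (512, 512)))
    print(proc.crop_size, out["pixel_values"][0].shape)
    # {'height': 378, 'width': 378} torch.Size([3, 378, 378])
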
longvu/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,37 @@
1
+ # pyre-unsafe
2
+ import copy
3
+
4
+ from .dino_encoder import DinoVisionTower
5
+ from .siglip_encoder import SiglipVisionTower
6
+
7
+
8
+ def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
9
+ vision_tower_aux_name_list = getattr(
10
+ vision_tower_cfg,
11
+ "mm_vision_tower_aux_list",
12
+ getattr(vision_tower_cfg, "vision_tower_aux_list", None),
13
+ )
14
+ vision_tower_aux_token_len_list = getattr(
15
+ vision_tower_cfg,
16
+ "mm_vision_tower_aux_token_len_list",
17
+ getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
18
+ )
19
+ vision_tower_aux_list = []
20
+ for vision_tower_aux_name, vision_tower_aux_token_len in zip(
21
+ vision_tower_aux_name_list, vision_tower_aux_token_len_list
22
+ ):
23
+ config = copy.deepcopy(vision_tower_cfg)
24
+ vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
25
+ if "siglip" in vision_tower_aux_name.lower():
26
+ vision_tower_aux_list.append(
27
+ SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
28
+ )
29
+
30
+ # SSL-based Vision Towers
31
+ elif "dinov2" in vision_tower_aux_name.lower():
32
+ vision_tower_aux_list.append(
33
+ DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
34
+ )
35
+ else:
36
+ raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
37
+ return vision_tower_aux_list
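build_vision_tower_aux_list only reads a handful of attributes off the config object, so a bare namespace is enough to exercise it. The tower names and token lengths below are hypothetical, and the SigLIP/DINOv2 weights are expected under the ./checkpoints paths hard-coded in the encoder classes:

    from types import SimpleNamespace
    from longvu.multimodal_encoder.builder import build_vision_tower_aux_list

    cfg = SimpleNamespace(
        mm_vision_tower_aux_list=["siglip-so400m-patch14-384", "dinov2-giant"],
        mm_vision_tower_aux_token_len_list=[576, 576],
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
        unfreeze_mm_vision_tower=False,
    )
    towers = build_vision_tower_aux_list(cfg, delay_load=True)
    print([type(t).__name__ for t in towers])
    # ['SiglipVisionTower', 'DinoVisionTower']
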
longvu/multimodal_encoder/dino_encoder.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model
5
+
6
+ from .base_encoder import BaseVisionTower, ProcessorWrapper
7
+
8
+
9
+ class DinoVisionTower(BaseVisionTower):
10
+ def __init__(self, vision_tower, args, delay_load=False):
11
+ super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)
12
+
13
+ model_path = "./checkpoints/dinov2-giant"
14
+ base_model_name, res, interp = model_path, 378, 576
15
+ self._vision_tower_name = vision_tower
16
+ self.vision_tower_name = base_model_name
17
+ self._image_size = res
18
+ self._interp_size = interp
19
+ self._patch_size = 14 # default patch size
20
+
21
+ if not self.delay_load:
22
+ self.load_model()
23
+ else:
24
+ self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)
25
+
26
+ def load_model(self, device_map=None):
27
+
28
+ self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
29
+ """ValueError: Dinov2Model does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute."""
30
+ self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]
31
+
32
+ _image_size = self.vision_tower.config.image_size
33
+ if self._image_size is None:
34
+ self._image_size = _image_size
35
+
36
+ # increase shortest edge to prevent edge case crops
37
+ default_shortest_ratio = 8 / 7  # 256/224
38
+ # shortest_edge = int(default_shortest_ratio * self._image_size)
39
+ shortest_edge = self._image_size
40
+
41
+ processor = AutoImageProcessor.from_pretrained(
42
+ self.vision_tower_name,
43
+ crop_size=dict(height=self._image_size, width=self._image_size),
44
+ size=dict(shortest_edge=shortest_edge),
45
+ )
46
+ self.image_processor = processor
47
+
48
+ # Assign the output channels of the projection convolution as the hidden size
49
+ self._hidden_size = (
50
+ self.vision_tower.embeddings.patch_embeddings.projection.out_channels
51
+ )
52
+ # Assign the first value of the stride of the projection convolution as the patch size
53
+ self._patch_size = (
54
+ self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
55
+ )
56
+
57
+ # print(self._hidden_size, self._patch_size)
58
+
59
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
60
+ self.is_loaded = True
61
+
62
+ @property
63
+ def image_size(self):
64
+ return self._image_size
65
+
66
+ def feature_select(self, outputs):
67
+ sequence_output = outputs[
68
+ "last_hidden_state"
69
+ ] # batch_size, sequence_length, hidden_size
70
+
71
+ if self.select_feature == "cls_patch":
72
+ image_features = sequence_output
73
+ elif self.select_feature == "patch":
74
+ image_features = sequence_output[:, 1:]
75
+ elif self.select_feature == "cls":
76
+ image_features = sequence_output[:, 0]
77
+ else:
78
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
79
+ return image_features
80
+
81
+ def interpolate(self, image_features):
82
+ if self._interp_size is None:
83
+ return image_features
84
+
85
+ b, num_tokens, dim = image_features.shape
86
+
87
+ if num_tokens != self.num_patches:
88
+ target_h = target_w = int(self._interp_size**0.5)
89
+ h = w = int(num_tokens**0.5)
90
+
91
+ image_features = image_features.view(b, h, w, dim)
92
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
93
+
94
+ image_features = F.interpolate(
95
+ image_features.to(torch.float32),
96
+ size=(target_h, target_w),
97
+ mode="bilinear",
98
+ align_corners=False,
99
+ ).to(image_features.dtype)
100
+
101
+ # Permute the dimensions back to (b, target_h, target_w, dim)
102
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
103
+
104
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
105
+ image_features = image_features.flatten(1, 2)
106
+
107
+ return image_features
108
+
109
+ def _forward(self, images):
110
+ # logger.warning(f"images shape: {images.shape}")
111
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
112
+ image_forward_outs = self.vision_tower.forward(
113
+ images.to(device=self.device, dtype=self.dtype)
114
+ )
115
+ # logger.warning(f"image_forward_outs shape: {image_forward_outs['last_hidden_state'].shape}")
116
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
117
+ # logger.warning(f"image_features shape: {image_features.shape}")
118
+ interp_features = self.interpolate(image_features)
119
+ # logger.warning(f"interp_features shape: {interp_features.shape}")
120
+ return interp_features
121
+
122
+ @property
123
+ def num_patches_per_side(self):
124
+ return int(self.num_patches**0.5)
125
+
126
+ @property
127
+ def num_patches(self):
128
+ if self._interp_size is None:
129
+ return (self._image_size // self._patch_size) ** 2
130
+ else:
131
+ return self._interp_size
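The interpolate() step above is the trick shared by both towers: whatever the native patch count, features are resampled to a fixed token budget. A standalone shape check of the same reshape-resize-flatten sequence (values are random, dimensions are illustrative):

    import torch
    import torch.nn.functional as F

    feats = torch.randn(2, 729, 1536)        # e.g. a 27x27 grid of patch features
    b, n, d = feats.shape
    side, target_side = int(n ** 0.5), 24    # resample 27x27 -> 24x24 = 576 tokens
    grid = feats.view(b, side, side, d).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(target_side, target_side),
                         mode="bilinear", align_corners=False)
    resampled = grid.permute(0, 2, 3, 1).flatten(1, 2)
    print(resampled.shape)                   # torch.Size([2, 576, 1536])
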
longvu/multimodal_encoder/drop.py ADDED
@@ -0,0 +1,41 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Drop regularization layers."""
19
+
20
+ from torch import nn
21
+
22
+
23
+ class DropPathV(nn.Module):
24
+ """Set examples to zero randomly."""
25
+
26
+ def __init__(self, p=0.1, inplace=False):
27
+ super(DropPathV, self).__init__()
28
+ self.p = p
29
+ self.inplace = inplace
30
+
31
+ def forward(self, input):
32
+ if not self.training or self.p <= 0:
33
+ return input
34
+ keep_p = 1 - self.p
35
+ shape = (input.shape[0],) + (1,) * (input.dim() - 1)
36
+ scale = input.new_empty(shape).bernoulli_(keep_p).div_(keep_p)
37
+ return input.mul_(scale) if self.inplace else input.mul(scale)
38
+
39
+ def extra_repr(self):
40
+ inplace_str = ", inplace" if self.inplace else ""
41
+ return "p={}{}".format(self.p, inplace_str)
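DropPathV implements stochastic depth at the example level: during training each sample in the batch is either zeroed entirely or scaled by 1/keep_prob so the expectation is unchanged, and at eval time it is the identity. A quick illustration:

    import torch
    from longvu.multimodal_encoder.drop import DropPathV

    drop = DropPathV(p=0.5)
    drop.train()
    x = torch.ones(6, 3)
    print(drop(x))        # each row is either all zeros or all 2.0
    drop.eval()
    print(drop(x))        # unchanged at inference time
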
longvu/multimodal_encoder/image.py ADDED
@@ -0,0 +1,80 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Image utilities."""
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+
24
+
25
+ def im_resize(img, size=None, scale=None, mode="linear"):
26
+ """Resize image by the scale or size."""
27
+ if size is None:
28
+ if not isinstance(scale, (tuple, list)):
29
+ scale = (scale, scale)
30
+ h, w = img.shape[:2]
31
+ size = int(h * scale[0] + 0.5), int(w * scale[1] + 0.5)
32
+ else:
33
+ if not isinstance(size, (tuple, list)):
34
+ size = (size, size)
35
+ resize_modes = {"linear": PIL.Image.BILINEAR}
36
+ from torchvision.transforms import ToPILImage
37
+
38
+ to_pil = ToPILImage()
39
+ img = to_pil(img.to(torch.float32).cpu())
40
+ # img = PIL.Image.fromarray(img)
41
+ return np.array(img.resize(size[::-1], resize_modes[mode]))
42
+
43
+
44
+ def im_rescale(img, scales, max_size=0):
45
+ """Rescale image to match the detecting scales."""
46
+ im_shape = img.shape
47
+ img_list, img_scales = [], []
48
+ size_min = np.min(im_shape[:2])
49
+ size_max = np.max(im_shape[:2])
50
+ for target_size in scales:
51
+ im_scale = float(target_size) / float(size_min)
52
+ target_size_max = max_size if max_size > 0 else target_size
53
+ if np.round(im_scale * size_max) > target_size_max:
54
+ im_scale = float(target_size_max) / float(size_max)
55
+ img_list.append(im_resize(img, scale=im_scale))
56
+ img_scales.append((im_scale, im_scale))
57
+ return img_list, img_scales
58
+
59
+
60
+ def im_vstack(arrays, fill_value=None, dtype=None, size=None, align=None):
61
+ """Stack image arrays in sequence vertically."""
62
+ if fill_value is None:
63
+ return np.vstack(arrays)
64
+ # Compute the max stack shape.
65
+ max_shape = np.max(np.stack([arr.shape for arr in arrays]), 0)
66
+ if size is not None and min(size) > 0:
67
+ max_shape[: len(size)] = size
68
+ if align is not None and min(align) > 0:
69
+ align_size = np.ceil(max_shape[: len(align)] / align)
70
+ max_shape[: len(align)] = align_size.astype("int64") * align
71
+ # Fill output with the given value.
72
+ output_dtype = dtype or arrays[0].dtype
73
+ output_shape = [len(arrays)] + list(max_shape)
74
+ output = np.empty(output_shape, output_dtype)
75
+ output[:] = fill_value
76
+ # Copy arrays.
77
+ for i, arr in enumerate(arrays):
78
+ copy_slices = (slice(0, d) for d in arr.shape)
79
+ output[(i,) + tuple(copy_slices)] = arr
80
+ return output
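im_vstack pads a list of differently sized HWC arrays to a common (optionally alignment-rounded) shape before stacking, which is how variable-resolution frames can share one batch array. Illustrative shapes:

    import numpy as np
    from longvu.multimodal_encoder.image import im_vstack

    frames = [np.zeros((360, 480, 3), "uint8"), np.zeros((400, 300, 3), "uint8")]
    batch = im_vstack(frames, fill_value=0, align=(32, 32))
    print(batch.shape)    # (2, 416, 480, 3): padded to the max size, aligned to 32
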
longvu/multimodal_encoder/logging.py ADDED
@@ -0,0 +1,131 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Logging utilities."""
19
+
20
+ import inspect
21
+ import logging as _logging
22
+ import os
23
+ import sys as _sys
24
+ import threading
25
+
26
+
27
+ _logger = None
28
+ _logger_lock = threading.Lock()
29
+
30
+
31
+ def get_logger():
32
+ global _logger
33
+ # Use double-checked locking to avoid taking lock unnecessarily.
34
+ if _logger:
35
+ return _logger
36
+ _logger_lock.acquire()
37
+ try:
38
+ if _logger:
39
+ return _logger
40
+ logger = _logging.getLogger("tokenize-anything")
41
+ logger.setLevel("INFO")
42
+ logger.propagate = False
43
+ logger._is_root = True
44
+ if True:
45
+ # Determine whether we are in an interactive environment.
46
+ _interactive = False
47
+ try:
48
+ # This is only defined in interactive shells.
49
+ if _sys.ps1:
50
+ _interactive = True
51
+ except AttributeError:
52
+ # Even now, we may be in an interactive shell with `python -i`.
53
+ _interactive = _sys.flags.interactive
54
+ # If we are in an interactive environment (like Jupyter), set loglevel
55
+ # to INFO and pipe the output to stdout.
56
+ if _interactive:
57
+ logger.setLevel("INFO")
58
+ _logging_target = _sys.stdout
59
+ else:
60
+ _logging_target = _sys.stderr
61
+ # Add the output handler.
62
+ _handler = _logging.StreamHandler(_logging_target)
63
+ _handler.setFormatter(_logging.Formatter("%(levelname)s %(message)s"))
64
+ logger.addHandler(_handler)
65
+ _logger = logger
66
+ return _logger
67
+ finally:
68
+ _logger_lock.release()
69
+
70
+
71
+ def _detailed_msg(msg):
72
+ file, lineno = inspect.stack()[:3][2][1:3]
73
+ return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
74
+
75
+
76
+ def log(level, msg, *args, **kwargs):
77
+ get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
78
+
79
+
80
+ def debug(msg, *args, **kwargs):
81
+ if is_root():
82
+ get_logger().debug(_detailed_msg(msg), *args, **kwargs)
83
+
84
+
85
+ def error(msg, *args, **kwargs):
86
+ get_logger().error(_detailed_msg(msg), *args, **kwargs)
87
+ assert 0
88
+
89
+
90
+ def fatal(msg, *args, **kwargs):
91
+ get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
92
+ assert 0
93
+
94
+
95
+ def info(msg, *args, **kwargs):
96
+ if is_root():
97
+ get_logger().info(_detailed_msg(msg), *args, **kwargs)
98
+
99
+
100
+ def warning(msg, *args, **kwargs):
101
+ if is_root():
102
+ get_logger().warning(_detailed_msg(msg), *args, **kwargs)
103
+
104
+
105
+ def get_verbosity():
106
+ """Return how much logging output will be produced."""
107
+ return get_logger().getEffectiveLevel()
108
+
109
+
110
+ def set_verbosity(v):
111
+ """Set the threshold for what messages will be logged."""
112
+ get_logger().setLevel(v)
113
+
114
+
115
+ def set_formatter(fmt=None, datefmt=None):
116
+ """Set the formatter."""
117
+ handler = _logging.StreamHandler(_sys.stderr)
118
+ handler.setFormatter(_logging.Formatter(fmt, datefmt))
119
+ logger = get_logger()
120
+ logger.removeHandler(logger.handlers[0])
121
+ logger.addHandler(handler)
122
+
123
+
124
+ def set_root(is_root=True):
125
+ """Set logger to the root."""
126
+ get_logger()._is_root = is_root
127
+
128
+
129
+ def is_root():
130
+ """Return logger is the root."""
131
+ return get_logger()._is_root
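These helpers all route through one lazily created, process-wide logger, and only the "root" process emits info/debug. A short, illustrative usage (the message text is arbitrary):

    from longvu.multimodal_encoder import logging as enc_logging

    enc_logging.info("vision tower loaded")   # e.g. "INFO my_script.py:3] vision tower loaded"
    enc_logging.set_verbosity("WARNING")      # silence info() from here on
    enc_logging.set_root(False)               # e.g. on non-zero ranks under DDP
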
longvu/multimodal_encoder/loss.py ADDED
@@ -0,0 +1,96 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Loss layers."""
19
+
20
+ from torch import nn
21
+
22
+
23
+ def reduce_loss(loss, reduction="mean"):
24
+ """Reduce the loss."""
25
+ if reduction == "mean" or reduction == "sum":
26
+ return getattr(loss, reduction)()
27
+ if reduction == "batch_mean":
28
+ return loss.sum().mul_(1.0 / loss.size(0))
29
+ return loss
30
+
31
+
32
+ class BinaryFocalLoss(nn.Module):
33
+ """Binary focal loss."""
34
+
35
+ def __init__(self, alpha=0.25, reduction="none"):
36
+ super(BinaryFocalLoss, self).__init__()
37
+ self.alpha = alpha
38
+ self.reduction = reduction
39
+
40
+ def forward(self, input, target):
41
+ alpha, p = self.alpha, input.sigmoid()
42
+ neg_alpha, neg_target = 1.0 - alpha, 1.0 - target
43
+ alpha_weight = target.mul(alpha).add_(neg_target.mul(neg_alpha))
44
+ focal_weight = (1.0 - p).mul_(target).add_(p.mul(neg_target)).square()
45
+ loss = nn.functional.binary_cross_entropy_with_logits(
46
+ input, target, reduction="none"
47
+ )
48
+ return reduce_loss(loss * focal_weight.mul_(alpha_weight), self.reduction)
49
+
50
+
51
+ class BinaryDiceLoss(nn.Module):
52
+ """Binary dice loss."""
53
+
54
+ def __init__(self, eps=1.0, reduction="none"):
55
+ super(BinaryDiceLoss, self).__init__()
56
+ self.eps = eps
57
+ self.reduction = reduction
58
+
59
+ def forward(self, input, target):
60
+ input = input.sigmoid()
61
+ num = input.mul(target).sum(-1).mul_(2).add_(self.eps)
62
+ den = input.add(target).sum(-1).add_(self.eps)
63
+ return reduce_loss(1.0 - num / den, self.reduction)
64
+
65
+
66
+ class CrossEntropyLoss(nn.Module):
67
+ """Cross entropy loss with label smoothing."""
68
+
69
+ def __init__(self, epsilon=0, reduction="none"):
70
+ super(CrossEntropyLoss, self).__init__()
71
+ self.epsilon = epsilon
72
+ self.reduction = reduction
73
+
74
+ def forward_dense(self, input, target):
75
+ dim, target = input.shape[-1], target.squeeze_()
76
+ x = nn.functional.log_softmax(input, dim=-1)
77
+ y = nn.functional.one_hot(target, dim).float()
78
+ x = (
79
+ x.permute([0, x.dim() - 1] + list(range(x.dim()))[1:-1])
80
+ if x.dim() > 2
81
+ else x
82
+ )
83
+ y = (
84
+ y.permute([0, y.dim() - 1] + list(range(y.dim()))[1:-1])
85
+ if y.dim() > 2
86
+ else y
87
+ )
88
+ loss = nn.functional.cross_entropy(
89
+ x, y, reduction="none", label_smoothing=self.epsilon
90
+ )
91
+ return reduce_loss(loss, self.reduction)
92
+
93
+ def forward(self, input, target):
94
+ if self.epsilon > 0:
95
+ return self.forward_dense(input, target)
96
+ return nn.functional.cross_entropy(input, target, reduction=self.reduction)
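Shape-wise, the binary losses expect logits and {0,1} targets of the same shape; dice reduces over the last dimension while focal is element-wise before the final reduction. A random-tensor sketch of how they might be combined per mask:

    import torch
    from longvu.multimodal_encoder.loss import BinaryDiceLoss, BinaryFocalLoss

    logits = torch.randn(4, 1024)                       # 4 masks, 1024 "pixels" each
    target = torch.randint(0, 2, (4, 1024)).float()
    focal = BinaryFocalLoss(alpha=0.25, reduction="batch_mean")(logits, target)
    dice = BinaryDiceLoss(eps=1.0, reduction="batch_mean")(logits, target)
    print(focal.item(), dice.item())
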
longvu/multimodal_encoder/registry.py ADDED
@@ -0,0 +1,56 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Registry utilities."""
19
+
20
+ import collections
21
+ import functools
22
+
23
+
24
+ class Registry(object):
25
+ """Registry class."""
26
+
27
+ def __init__(self, name):
28
+ self.name = name
29
+ self.registry = collections.OrderedDict()
30
+
31
+ def has(self, key):
32
+ return key in self.registry
33
+
34
+ def register(self, name, func=None, **kwargs):
35
+ def decorated(inner_function):
36
+ for key in name if isinstance(name, (tuple, list)) else [name]:
37
+ self.registry[key] = functools.partial(inner_function, **kwargs)
38
+ return inner_function
39
+
40
+ if func is not None:
41
+ return decorated(func)
42
+ return decorated
43
+
44
+ def get(self, name, default=None):
45
+ if name is None:
46
+ return None
47
+ if not self.has(name):
48
+ if default is not None:
49
+ return default
50
+ raise KeyError("`%s` is not registered in <%s>." % (name, self.name))
51
+ return self.registry[name]
52
+
53
+ def try_get(self, name):
54
+ if self.has(name):
55
+ return self.get(name)
56
+ return None
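Registry maps string keys to factory callables, with any keyword arguments bound at registration time, so variants can be selected by name from a config. A toy example (the "toy" entry is invented for illustration):

    from longvu.multimodal_encoder.registry import Registry

    ENCODERS = Registry("encoders")

    @ENCODERS.register("toy", scale=2)
    def make_toy(scale=1):
        return {"scale": scale}

    print(ENCODERS.get("toy")())         # {'scale': 2}: kwargs were bound by register()
    print(ENCODERS.try_get("missing"))   # None instead of a KeyError
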
longvu/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
5
+
6
+ from .base_encoder import BaseVisionTower, ProcessorWrapper
7
+
8
+
9
+ class SiglipVisionTower(BaseVisionTower):
10
+ def __init__(self, vision_tower_name, args, delay_load=False):
11
+ super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
12
+
13
+ model_path = "./checkpoints/siglip-so400m-patch14-384"
14
+ base_model_name, res, interp = model_path, 384, 576
15
+ self.vision_tower_name = base_model_name
16
+ self._image_size = res if res is not None else 512
17
+ self._interp_size = interp
18
+ if not self.delay_load:
19
+ self.load_model()
20
+ elif self.unfreeze_mm_vision_tower:
21
+ self.load_model()
22
+ else:
23
+ self._hidden_size = 1152
24
+
25
+ def load_model(self, device_map=None):
26
+ self.vision_model = "siglip"
27
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
28
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
29
+
30
+ # self.vision_tower = clip_model.visual.trunk
31
+ self.vision_tower.output_tokens = True
32
+
33
+ self._hidden_size = self.vision_tower.config.hidden_size
34
+ self._image_size = self.vision_tower.config.image_size
35
+ self._patch_size = self.vision_tower.config.patch_size
36
+ self.image_processor = SiglipImageProcessor.from_pretrained(
37
+ self.vision_tower_name
38
+ )
39
+
40
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
41
+ self.is_loaded = True
42
+
43
+ def interpolate(self, image_features):
44
+ if self._interp_size is None:
45
+ return image_features
46
+
47
+ b, num_tokens, dim = image_features.shape
48
+
49
+ if num_tokens != self.num_patches:
50
+ target_h = target_w = int(self._interp_size**0.5)
51
+ h = w = int(num_tokens**0.5)
52
+
53
+ image_features = image_features.view(b, h, w, dim)
54
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
55
+
56
+ image_features = F.interpolate(
57
+ image_features.to(torch.float32),
58
+ size=(target_h, target_w),
59
+ mode="bilinear",
60
+ align_corners=False,
61
+ ).to(image_features.dtype)
62
+
63
+ # Permute the dimensions back to (b, target_h, target_w, dim)
64
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
65
+
66
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
67
+ image_features = image_features.flatten(1, 2)
68
+
69
+ return image_features
70
+
71
+ def _forward(self, images, interpolate_token=576):
72
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
73
+ image_features = self.vision_tower.forward(
74
+ images.to(device=self.device, dtype=self.dtype),
75
+ output_hidden_states=True,
76
+ ).hidden_states[-1]
77
+ interp_features = self.interpolate(image_features)
78
+ return interp_features
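With delay_load set, SiglipVisionTower still advertises its hidden size and token budget before any weights are touched, which downstream builders can rely on. A sketch with an assumed args namespace (the ./checkpoints path hard-coded above must exist before load_model() is actually called):

    from types import SimpleNamespace
    from longvu.multimodal_encoder.siglip_encoder import SiglipVisionTower

    args = SimpleNamespace(mm_vision_select_layer=-2, unfreeze_mm_vision_tower=False)
    tower = SiglipVisionTower("siglip-so400m-patch14-384-interp576", args, delay_load=True)
    print(tower.hidden_size, tower.num_patches)   # 1152 576, without loading weights
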
longvu/multimodal_encoder/utils.py ADDED
@@ -0,0 +1,66 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+
17
+ # pyre-unsafe
18
+ """Layer utilities."""
19
+
20
+ import cv2
21
+ import numpy as np
22
+ import torch
23
+
24
+
25
+ def init_cross_conv(blocks):
26
+ """Initialize convolutional cross attention."""
27
+ for m in blocks.modules():
28
+ if isinstance(m, torch.nn.Conv2d):
29
+ torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
30
+ for blk in blocks:
31
+ torch.nn.init.constant_(blk.norm3.weight, 0)
32
+
33
+
34
+ def set_dropout(module, dropout):
35
+ """Initialize dropout."""
36
+ for m in [m for m in module.modules() if isinstance(m, torch.nn.Dropout)]:
37
+ m.p = dropout
38
+
39
+
40
+ def set_drop_path(blocks, drop_path):
41
+ """Initialize drop path."""
42
+ if not isinstance(blocks, torch.nn.ModuleList):
43
+ blocks = getattr(blocks, "blocks", getattr(blocks, "layers", None))
44
+ for i, blk in enumerate(blocks):
45
+ for m in [m for m in blk.modules() if type(m).__name__ == "DropPath"]:
46
+ m.p = i * drop_path / (len(blocks) - 1)
47
+
48
+
49
+ def set_sync_batch_norm(module, ddp_group):
50
+ """Set data parallelism group for sync batch norm."""
51
+ for m in module.modules():
52
+ if isinstance(m, torch.nn.SyncBatchNorm):
53
+ m.process_group = ddp_group
54
+
55
+
56
+ def resize_pos_embed(weight, out_len):
57
+ """Resize position embedding weights."""
58
+ out_h = out_w = int(out_len**0.5)
59
+ h = w = int(weight.shape[0] ** 0.5)
60
+ weight = weight.reshape((h, w, weight.shape[1]))
61
+ out_weight = [
62
+ cv2.resize(x, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
63
+ for x in np.split(weight.astype("float32", copy=False), 4, axis=-1)
64
+ ]
65
+ out_weight = np.concatenate(out_weight, axis=-1)
66
+ return out_weight.reshape((-1, weight.shape[-1])).astype(weight.dtype, copy=False)
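resize_pos_embed is the numpy-side counterpart of the token interpolation above: it bicubically resizes a square grid of position embeddings to a new grid length. Illustrative only; the 4-way channel split requires the embedding dimension to be divisible by 4:

    import numpy as np
    from longvu.multimodal_encoder.utils import resize_pos_embed

    pos = np.random.randn(576, 1024).astype("float32")   # 24x24 grid
    print(resize_pos_embed(pos, 729).shape)              # (729, 1024), i.e. 27x27
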
longvu/multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.01 kB). View file
 
longvu/multimodal_projector/builder.py ADDED
@@ -0,0 +1,52 @@
1
+ # pyre-unsafe
2
+ import re
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": "identity"}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+ config.mm_hidden_size = 256
36
+
37
+ if projector_type == "linear":
38
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
39
+
40
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+
49
+ if projector_type == "identity":
50
+ return IdentityMap()
51
+
52
+ raise ValueError(f"Unknown projector type: {projector_type}")
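build_vision_projector reads mm_projector_type off the config (note that mm_hidden_size is overridden to 256 above) and returns a plain nn.Module. A sketch with a hypothetical config namespace:

    from types import SimpleNamespace
    from longvu.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", hidden_size=3072)
    proj = build_vision_projector(cfg)
    print(proj)   # Sequential: Linear(256->3072), GELU, Linear(3072->3072)
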
longvu/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ # pyre-unsafe
2
+ from transformers import AutoConfig
3
+
4
+
5
+ def auto_upgrade(config):
6
+ cfg = AutoConfig.from_pretrained(config)
7
+ if "llava" in config and "llava" not in cfg.model_type:
8
+ assert cfg.model_type == "llama"
9
+ print(
10
+ "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
11
+ )
12
+ print(
13
+ "You must upgrade the checkpoint to the new code base (this can be done automatically)."
14
+ )
15
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
16
+ if confirm.lower() in ["y", "yes"]:
17
+ print("Upgrading checkpoint...")
18
+ assert len(cfg.architectures) == 1
19
+ setattr(cfg.__class__, "model_type", "llava")
20
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
21
+ cfg.save_pretrained(config)
22
+ print("Checkpoint upgraded.")
23
+ else:
24
+ print("Checkpoint upgrade aborted.")
25
+ exit(1)
longvu/vision_sampler.py ADDED
@@ -0,0 +1,566 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+
8
+
9
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
10
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
11
+ """
12
+ grid_size: int of the grid height and width
13
+ return:
14
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
15
+ """
16
+ grid_h = np.arange(grid_size, dtype=np.float32)
17
+ grid_w = np.arange(grid_size, dtype=np.float32)
18
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
19
+ grid = np.stack(grid, axis=0)
20
+
21
+ grid = grid.reshape([2, 1, grid_size, grid_size])
22
+
23
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24
+ if cls_token:
25
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
26
+ return pos_embed
27
+
28
+
29
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
30
+ assert embed_dim % 2 == 0
31
+
32
+ # use half of dimensions to encode grid_h
33
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
34
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
35
+
36
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
37
+ return emb
38
+
39
+
40
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
41
+ """
42
+ embed_dim: output dimension for each position
43
+ pos: a list of positions to be encoded: size (M,)
44
+ out: (M, D)
45
+ """
46
+ assert embed_dim % 2 == 0
47
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
48
+ omega /= embed_dim / 2.0
49
+ omega = 1.0 / 10000**omega # (D/2,)
50
+
51
+ pos = pos.reshape(-1) # (M,)
52
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
53
+
54
+ emb_sin = np.sin(out) # (M, D/2)
55
+ emb_cos = np.cos(out) # (M, D/2)
56
+
57
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
58
+ return emb
59
+
60
+
61
+ class CrossAttention(nn.Module):
62
+
63
+ def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
64
+ super().__init__()
65
+ self.hidden_dim = hidden_dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = self.hidden_dim // self.num_heads
68
+
69
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
70
+ raise ValueError(
71
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
72
+ f" and `num_heads`: {self.num_heads})."
73
+ )
74
+
75
+ self.q_proj = nn.Sequential(
76
+ nn.LayerNorm(q_dim),
77
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
78
+ )
79
+ self.k_proj = nn.Sequential(
80
+ nn.LayerNorm(kv_dim),
81
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
82
+ )
83
+ self.v_proj = nn.Sequential(
84
+ nn.LayerNorm(kv_dim),
85
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
86
+ )
87
+ self.o_proj = nn.Linear(
88
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
89
+ )
90
+
91
+ def forward(self, vision_latents, queries, attention_mask):
92
+
93
+ bsz, q_len, _ = queries.size()
94
+ bsz, v_len, _ = vision_latents.size()
95
+
96
+ query_states = self.q_proj(queries)
97
+ key_states = self.k_proj(vision_latents)
98
+ value_states = self.v_proj(vision_latents)
99
+
100
+ query_states = query_states.view(
101
+ bsz, q_len, self.num_heads, self.head_dim
102
+ ).transpose(1, 2)
103
+ key_states = key_states.view(
104
+ bsz, v_len, self.num_heads, self.head_dim
105
+ ).transpose(1, 2)
106
+ value_states = value_states.view(
107
+ bsz, v_len, self.num_heads, self.head_dim
108
+ ).transpose(1, 2)
109
+
110
+ if attention_mask is not None:
111
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
112
+ raise ValueError(
113
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
114
+ )
115
+
116
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
117
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
118
+ if query_states.device.type == "cuda" and attention_mask is not None:
119
+ query_states = query_states.contiguous()
120
+ key_states = key_states.contiguous()
121
+ value_states = value_states.contiguous()
122
+
123
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
124
+ query_states,
125
+ key_states,
126
+ value_states,
127
+ attn_mask=attention_mask,
128
+ )
129
+
130
+ attn_output = attn_output.transpose(1, 2).contiguous()
131
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
132
+
133
+ attn_output = self.o_proj(attn_output)
134
+
135
+ return attn_output
136
+
137
+
138
+ class AggregationBlock(nn.Module):
139
+ def __init__(
140
+ self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
141
+ ):
142
+ super().__init__()
143
+ self.hidden_dim = hidden_dim
144
+ self.num_heads = num_heads
145
+ self.head_dim = self.hidden_dim // self.num_heads
146
+
147
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
148
+ raise ValueError(
149
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
150
+ f" and `num_heads`: {self.num_heads})."
151
+ )
152
+
153
+ self.attention = attention
154
+ if attention:
155
+ self.attention_layer = CrossAttention(
156
+ q_dim, kv_dim, hidden_dim, num_heads, attention_bias
157
+ )
158
+ else:
159
+ self.attention_layer = MLP(kv_dim, q_dim, q_dim)
160
+
161
+ def forward(self, vision_latents, queries, attention_mask):
162
+ if self.attention:
163
+ queries = self.attention_layer(vision_latents, queries, attention_mask)
164
+ else:
165
+ queries = self.attention_layer(vision_latents)
166
+
167
+ return queries
168
+
169
+
170
+ class MultiKVCrossAttention(nn.Module):
171
+
172
+ def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
173
+ super().__init__()
174
+
175
+ self.hidden_dim = hidden_dim
176
+ self.num_heads = num_heads
177
+ self.head_dim = self.hidden_dim // self.num_heads
178
+
179
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
180
+ raise ValueError(
181
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
182
+ f" and `num_heads`: {self.num_heads})."
183
+ )
184
+
185
+ self.q_proj = nn.Sequential(
186
+ nn.LayerNorm(q_dim),
187
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
188
+ )
189
+ self.num_of_kvs = len(kv_dim_list)
190
+ for i, kv_dim in enumerate(kv_dim_list):
191
+ setattr(
192
+ self,
193
+ "k_proj_{}".format(i),
194
+ nn.Sequential(
195
+ nn.LayerNorm(kv_dim),
196
+ nn.Linear(
197
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
198
+ ),
199
+ ),
200
+ )
201
+ setattr(
202
+ self,
203
+ "v_proj_{}".format(i),
204
+ nn.Sequential(
205
+ nn.LayerNorm(kv_dim),
206
+ nn.Linear(
207
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
208
+ ),
209
+ ),
210
+ )
211
+ self.o_proj = nn.Linear(
212
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
213
+ )
214
+
215
+ def forward(
216
+ self,
217
+ queries,
218
+ *vision_latents_attention_mask_list,
219
+ ):
220
+
221
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
222
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
223
+
224
+ bsz, q_len, _ = queries.size()
225
+
226
+ query_states = self.q_proj(queries)
227
+ key_states = torch.cat(
228
+ [
229
+ getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
230
+ for i in range(self.num_of_kvs)
231
+ ],
232
+ dim=1,
233
+ )
234
+ value_states = torch.cat(
235
+ [
236
+ getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
237
+ for i in range(self.num_of_kvs)
238
+ ],
239
+ dim=1,
240
+ )
241
+
242
+ v_len = key_states.shape[1]
243
+
244
+ query_states = query_states.view(
245
+ bsz, q_len, self.num_heads, self.head_dim
246
+ ).transpose(1, 2)
247
+ key_states = key_states.view(
248
+ bsz, v_len, self.num_heads, self.head_dim
249
+ ).transpose(1, 2)
250
+ value_states = value_states.view(
251
+ bsz, v_len, self.num_heads, self.head_dim
252
+ ).transpose(1, 2)
253
+
254
+ # if kv_weight is not None:
255
+ # kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
256
+
257
+ attention_mask = torch.cat(attention_mask_list, dim=-1)
258
+
259
+ if attention_mask is not None:
260
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
261
+ raise ValueError(
262
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
263
+ )
264
+
265
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
266
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
267
+ if query_states.device.type == "cuda" and attention_mask is not None:
268
+ query_states = query_states.contiguous()
269
+ key_states = key_states.contiguous()
270
+ value_states = value_states.contiguous()
271
+
272
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
273
+ query_states,
274
+ key_states,
275
+ value_states,
276
+ attn_mask=attention_mask,
277
+ )
278
+ # attn_output = spda(
279
+ # query_states,
280
+ # key_states,
281
+ # value_states,
282
+ # attn_mask=attention_mask,
283
+ # additional_score=kv_weight
284
+ # )
285
+
286
+ attn_output = attn_output.transpose(1, 2).contiguous()
287
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
288
+
289
+ attn_output = self.o_proj(attn_output)
290
+
291
+ return attn_output
292
+
293
+
294
+ class MLP(nn.Module):
295
+     def __init__(self, d_in, d_hidden, d_out):
+         super().__init__()
+         self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
+         self.act = nn.GELU()
+         self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)
+
+     def forward(self, x):
+         return self.linear_2(self.act(self.linear_1(x)))
+
+
+ class VisionCrossAttentionLayer(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         hidden_dim=1024,
+         layer_idx=0,
+     ):
+         super().__init__()
+         num_heads = 16
+         self.num_of_kvs = len(kv_dim_list)
+
+         self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
+         self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
+         # if self.num_of_kvs > 1:
+         #     self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
+         #     self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
+         self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
+
+         self.norm = nn.LayerNorm(hidden_dim)
+
+         self.cross_attn = MultiKVCrossAttention(
+             hidden_dim, kv_dim_list, hidden_dim, num_heads
+         )
+         self.kv_size_list = kv_size_list
+         for i, kv_size in enumerate(kv_size_list):
+             if kv_size > 1:
+                 setattr(
+                     self,
+                     "pos_embed_{}".format(i),
+                     nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
+                 )
+                 # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)
+
+     def forward(
+         self,
+         queries,
+         context_feature,
+         *vision_latents_attention_mask_list,
+     ) -> torch.FloatTensor:
+
+         residual = queries
+         # queries = self.proj_in(queries)
+         context_feature = self.proj_context(context_feature)
+         # queries = queries + context_feature
+         queries = torch.cat([queries, context_feature], -1)
+
+         # if self.num_of_kvs > 1:
+         #     kv_weight = self.weight_mlp(queries)  # B * 1 * num_tower
+         #     kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
+         #     kv_weight = kv_weight.softmax(-1)
+         #     kv_number_list = [size**2 for size in self.kv_size_list]
+         #     kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
+         # else:
+         #     kv_weight = None
+
+         queries = self.proj_in(queries)
+
+         vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
+         attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
+
+         attention_mask_list_reshaped = []
+         if attention_mask_list is not None:
+             for attention_mask in attention_mask_list:
+                 attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+                 attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
+                 attention_mask_list_reshaped.append(attention_mask)
+
+         vision_latents_pos_list = []
+         for i, vision_latents in enumerate(vision_latents_list):
+             if vision_latents.shape[1] > 1:
+                 vision_latents_pos_list.append(
+                     vision_latents
+                     + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
+                         vision_latents.dtype
+                     )
+                 )
+             else:
+                 vision_latents_pos_list.append(vision_latents)
+
+         # Cross Attention
+         attention_output = self.cross_attn(
+             queries, *vision_latents_pos_list, *attention_mask_list_reshaped
+         )
+
+         # attention_output = (attention_output * combination_weight).sum(2)
+         queries = queries + attention_output
+
+         queries = self.norm(queries)
+
+         queries = self.proj_out(queries)
+
+         queries = queries + residual
+
+         return queries
+
+
+ class VisionAggregationLayer(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         hidden_dim=1024,
+         layer_idx=0,
+     ):
+         super().__init__()
+         num_heads = 16
+         self.num_of_kvs = len(kv_dim_list)
+
+         self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
+         self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
+
+         self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
+
+         self.norm = nn.LayerNorm(hidden_dim)
+
+         if self.num_of_kvs > 1:
+             self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)
+
+         for i, kv_size in enumerate(kv_size_list):
+             if kv_size > 1:
+                 setattr(
+                     self,
+                     "pos_embed_{}".format(i),
+                     nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
+                 )
+                 setattr(
+                     self,
+                     "aggregate_{}".format(i),
+                     AggregationBlock(
+                         True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
+                     ),
+                 )
+             else:
+                 setattr(
+                     self,
+                     "aggregate_{}".format(i),
+                     AggregationBlock(
+                         False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
+                     ),
+                 )
+
+     def forward(
+         self,
+         queries,
+         context_feature,
+         *vision_latents_attention_mask_list,
+     ) -> torch.FloatTensor:
+
+         residual = queries
+         # queries = self.proj_in(queries)
+         context_feature = self.proj_context(context_feature)
+         # queries = queries + context_feature
+         queries = torch.cat([queries, context_feature], -1)
+
+         if self.num_of_kvs > 1:
+             combination_weight = self.weight_mlp(queries).softmax(
+                 -1
+             )  # B * 1 * num_tower
+             combination_weight = combination_weight.unsqueeze(-1)
+         else:
+             combination_weight = 1
+
+         queries = self.proj_in(queries)
+
+         vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
+         attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
+
+         attention_mask_list_reshaped = []
+         if attention_mask_list is not None:
+             for attention_mask in attention_mask_list:
+                 attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+                 attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
+                 attention_mask_list_reshaped.append(attention_mask)
+
+         vision_latents_pos_list = []
+         for i, vision_latents in enumerate(vision_latents_list):
+             if vision_latents.shape[1] > 1:
+                 vision_latents_pos_list.append(
+                     vision_latents
+                     + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
+                         vision_latents.dtype
+                     )
+                 )
+             else:
+                 vision_latents_pos_list.append(vision_latents)
+
+         aggregated_vision_latents_list = []
+         for i, (vision_latents, attention_mask) in enumerate(
+             zip(vision_latents_pos_list, attention_mask_list_reshaped)
+         ):
+             aggregated_vision_latents_list.append(
+                 getattr(self, "aggregate_{}".format(i))(
+                     vision_latents, queries, attention_mask
+                 )
+             )
+
+         aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)
+
+         queries = queries + (aggregated_vision_latents * combination_weight).sum(2)
+
+         queries = self.norm(queries)
+
+         queries = self.proj_out(queries)
+
+         queries = queries + residual
+
+         return queries
+
+
+ class VisionTokenSampler(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         vision_hidden_size,
+         num_of_layers=1,
+         layer_type="joint",
+     ):
+         super().__init__()
+         assert layer_type in ["joint", "sep"]
+         if layer_type == "joint":
+             self.layers = nn.ModuleList(
+                 [
+                     VisionCrossAttentionLayer(
+                         q_dim,
+                         context_dim,
+                         kv_dim_list,
+                         kv_size_list,
+                         vision_hidden_size,
+                         idx,
+                     )
+                     for idx in range(num_of_layers)
+                 ]
+             )
+         else:
+             self.layers = nn.ModuleList(
+                 [
+                     VisionAggregationLayer(
+                         q_dim,
+                         context_dim,
+                         kv_dim_list,
+                         kv_size_list,
+                         vision_hidden_size,
+                         idx,
+                     )
+                     for idx in range(num_of_layers)
+                 ]
+             )
+
+     def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
+         for layer in self.layers:
+             queries = layer(
+                 queries, context_feature, *vision_latents_attention_mask_list
+             )
+         return queries
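Note: the classes above resample a set of query tokens against one or more sets of vision latents via cross-attention, then add the result back to the original queries through a residual connection. The snippet below is a minimal, self-contained sketch of that pattern for orientation only; it is not part of this upload, it substitutes torch.nn.MultiheadAttention for the repo's MultiKVCrossAttention, collapses the multi-source handling to a single vision source, and uses illustrative placeholder dimensions.

import torch
import torch.nn as nn


class TinyVisionResampler(nn.Module):
    # Single-source stand-in for the multi-source layers above (illustrative only).
    def __init__(self, q_dim=1024, kv_dim=1152, hidden_dim=1024, num_heads=16):
        super().__init__()
        self.proj_q = nn.Linear(q_dim, hidden_dim, bias=False)    # queries -> hidden space
        self.proj_kv = nn.Linear(kv_dim, hidden_dim, bias=False)  # vision latents -> hidden space
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, q_dim, bias=False)  # hidden -> query dim

    def forward(self, queries, vision_latents):
        residual = queries
        q = self.proj_q(queries)
        kv = self.proj_kv(vision_latents)
        attn_out, _ = self.attn(q, kv, kv)   # queries attend to vision tokens
        q = self.norm(q + attn_out)
        return residual + self.proj_out(q)


queries = torch.randn(2, 64, 1024)       # batch x num_query_tokens x q_dim
vision = torch.randn(2, 24 * 24, 1152)   # batch x num_vision_tokens x kv_dim
out = TinyVisionResampler()(queries, vision)
print(out.shape)  # torch.Size([2, 64, 1024])

In the actual VisionTokenSampler this idea is applied per layer and per vision source, with learned positional embeddings added to any spatial latent grid (kv_size > 1) before attention.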
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ huggingface_hub==0.22.2
+ torch==2.1.2
+ numpy==1.26.4
+ torchvision
+ transformers==4.42.4
+ tokenizers==0.15.2
+ sentencepiece==0.1.99
+ shortuuid
+ accelerate==0.34.2
+ peft==0.4.0
+ bitsandbytes==0.41.0
+ pydantic<2,>=1
+ markdown2
+ scikit-learn==1.2.2
+ gradio==3.35.2
+ gradio_client==0.2.9
+ requests
+ httpx==0.24.0
+ uvicorn
+ fastapi
+ einops==0.6.1
+ einops-exts==0.0.4
+ timm==0.9.16
+ decord
+ ninja
+ deepspeed==0.12.2
+ protobuf
+ iopath
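The pins above target the Gradio 3.x / Torch 2.1 stack that app.py imports. As an optional sanity check (not part of this upload), a few of these pins can be compared against the installed environment using only the standard library; the three package names and versions below are taken from the list above.

from importlib.metadata import PackageNotFoundError, version

# Spot-check a few pins from requirements.txt against the active environment.
pins = {"torch": "2.1.2", "transformers": "4.42.4", "gradio": "3.35.2"}
for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{name} {installed}: {status}")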