Spaces:
Runtime error
Runtime error
Vision-CAIR
commited on
Commit
•
85efb5b
1
Parent(s):
679922c
Upload 39 files
Browse files- .gitattributes +2 -0
- app.py +317 -0
- examples/video1.mp4 +3 -0
- examples/video2.mp4 +3 -0
- inference.py +94 -0
- longvu/.DS_Store +0 -0
- longvu/__init__.py +3 -0
- longvu/apply_delta.py +59 -0
- longvu/builder.py +249 -0
- longvu/cambrian_arch.py +1705 -0
- longvu/consolidate.py +33 -0
- longvu/constants.py +13 -0
- longvu/conversation.py +606 -0
- longvu/file_io.py +11 -0
- longvu/language_model/__pycache__/cambrian_llama.cpython-310.pyc +0 -0
- longvu/language_model/__pycache__/cambrian_qwen.cpython-310.pyc +0 -0
- longvu/language_model/cambrian_llama.py +546 -0
- longvu/language_model/cambrian_qwen.py +471 -0
- longvu/make_delta.py +66 -0
- longvu/mm_datautils.py +1688 -0
- longvu/mm_utils.py +327 -0
- longvu/multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc +0 -0
- longvu/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
- longvu/multimodal_encoder/__pycache__/dino_encoder.cpython-310.pyc +0 -0
- longvu/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
- longvu/multimodal_encoder/base_encoder.py +135 -0
- longvu/multimodal_encoder/builder.py +37 -0
- longvu/multimodal_encoder/dino_encoder.py +131 -0
- longvu/multimodal_encoder/drop.py +41 -0
- longvu/multimodal_encoder/image.py +80 -0
- longvu/multimodal_encoder/logging.py +131 -0
- longvu/multimodal_encoder/loss.py +96 -0
- longvu/multimodal_encoder/registry.py +56 -0
- longvu/multimodal_encoder/siglip_encoder.py +78 -0
- longvu/multimodal_encoder/utils.py +66 -0
- longvu/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
- longvu/multimodal_projector/builder.py +52 -0
- longvu/utils.py +25 -0
- longvu/vision_sampler.py +566 -0
- requirements.txt +28 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/video1.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/video2.mp4 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
import sys
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from longvu.builder import load_pretrained_model
|
15 |
+
from longvu.constants import (
|
16 |
+
DEFAULT_IMAGE_TOKEN,
|
17 |
+
IMAGE_TOKEN_INDEX,
|
18 |
+
)
|
19 |
+
from longvu.conversation import conv_templates, SeparatorStyle
|
20 |
+
from longvu.mm_datautils import (
|
21 |
+
KeywordsStoppingCriteria,
|
22 |
+
process_images,
|
23 |
+
tokenizer_image_token,
|
24 |
+
)
|
25 |
+
from decord import cpu, VideoReader
|
26 |
+
|
27 |
+
|
28 |
+
title_markdown = ("""
|
29 |
+
LongVU
|
30 |
+
""")
|
31 |
+
|
32 |
+
block_css = """
|
33 |
+
#buttons button {
|
34 |
+
min-width: min(120px,100%);
|
35 |
+
color: #9C276A
|
36 |
+
}
|
37 |
+
"""
|
38 |
+
|
39 |
+
plum_color = gr.themes.colors.Color(
|
40 |
+
name='plum',
|
41 |
+
c50='#F8E4EF',
|
42 |
+
c100='#E9D0DE',
|
43 |
+
c200='#DABCCD',
|
44 |
+
c300='#CBA8BC',
|
45 |
+
c400='#BC94AB',
|
46 |
+
c500='#AD809A',
|
47 |
+
c600='#9E6C89',
|
48 |
+
c700='#8F5878',
|
49 |
+
c800='#804467',
|
50 |
+
c900='#713056',
|
51 |
+
c950='#662647',
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
class Chat:
|
56 |
+
|
57 |
+
def __init__(self):
|
58 |
+
self.version = "qwen"
|
59 |
+
model_name = "cambrian_qwen"
|
60 |
+
model_path = "./checkpoints/longvu_qwen"
|
61 |
+
device = "cuda:7"
|
62 |
+
|
63 |
+
self.tokenizer, self.model, self.processor, _ = load_pretrained_model(model_path, None, model_name, device=device)
|
64 |
+
self.model.eval()
|
65 |
+
|
66 |
+
def remove_after_last_dot(self, s):
|
67 |
+
last_dot_index = s.rfind('.')
|
68 |
+
if last_dot_index == -1:
|
69 |
+
return s
|
70 |
+
return s[:last_dot_index + 1]
|
71 |
+
|
72 |
+
@spaces.GPU(duration=120)
|
73 |
+
@torch.inference_mode()
|
74 |
+
def generate(self, data: list, message, temperature, top_p, max_output_tokens):
|
75 |
+
# TODO: support multiple turns of conversation.
|
76 |
+
assert len(data) == 1
|
77 |
+
|
78 |
+
tensor, image_sizes, modal = data[0]
|
79 |
+
|
80 |
+
conv = conv_templates[self.version].copy()
|
81 |
+
|
82 |
+
if isinstance(message, str):
|
83 |
+
conv.append_message("user", DEFAULT_IMAGE_TOKEN + '\n' + message)
|
84 |
+
elif isinstance(message, list):
|
85 |
+
if DEFAULT_IMAGE_TOKEN not in message[0]['content']:
|
86 |
+
message[0]['content'] = DEFAULT_IMAGE_TOKEN + '\n' + message[0]['content']
|
87 |
+
for mes in message:
|
88 |
+
conv.append_message(mes["role"], mes["content"])
|
89 |
+
|
90 |
+
conv.append_message("assistant", None)
|
91 |
+
|
92 |
+
prompt = conv.get_prompt()
|
93 |
+
|
94 |
+
input_ids = (
|
95 |
+
tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
96 |
+
.unsqueeze(0)
|
97 |
+
.to(self.model.device)
|
98 |
+
)
|
99 |
+
|
100 |
+
if "llama3" in self.version:
|
101 |
+
input_ids = input_ids[0][1:].unsqueeze(0) # remove bos
|
102 |
+
|
103 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
104 |
+
keywords = [stop_str]
|
105 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
|
106 |
+
with torch.inference_mode():
|
107 |
+
output_ids = self.model.generate(
|
108 |
+
input_ids,
|
109 |
+
images=tensor,
|
110 |
+
image_sizes=image_sizes,
|
111 |
+
do_sample=True,
|
112 |
+
temperature=temperature,
|
113 |
+
max_new_tokens=max_output_tokens,
|
114 |
+
use_cache=True,
|
115 |
+
top_p=top_p,
|
116 |
+
stopping_criteria=[stopping_criteria],
|
117 |
+
)
|
118 |
+
|
119 |
+
pred = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
120 |
+
|
121 |
+
return self.remove_after_last_dot(pred)
|
122 |
+
|
123 |
+
|
124 |
+
@spaces.GPU(duration=120)
|
125 |
+
def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
|
126 |
+
data = []
|
127 |
+
|
128 |
+
processor = handler.processor
|
129 |
+
try:
|
130 |
+
if image is not None:
|
131 |
+
data.append((processor['image'](image).to(handler.model.device, dtype=dtype), None, '<image>'))
|
132 |
+
elif video is not None:
|
133 |
+
vr = VideoReader(video, ctx=cpu(0), num_threads=1)
|
134 |
+
fps = float(vr.get_avg_fps())
|
135 |
+
frame_indices = np.array(
|
136 |
+
[
|
137 |
+
i
|
138 |
+
for i in range(
|
139 |
+
0,
|
140 |
+
len(vr),
|
141 |
+
round(fps),
|
142 |
+
)
|
143 |
+
]
|
144 |
+
)
|
145 |
+
video_tensor = []
|
146 |
+
for frame_index in frame_indices:
|
147 |
+
img = vr[frame_index].asnumpy()
|
148 |
+
video_tensor.append(img)
|
149 |
+
video_tensor = np.stack(video_tensor)
|
150 |
+
image_sizes = [video_tensor[0].shape[:2]]
|
151 |
+
video_tensor = process_images(video_tensor, processor, handler.model.config)
|
152 |
+
video_tensor = [item.unsqueeze(0).to(handler.model.device, dtype=dtype) for item in video_tensor]
|
153 |
+
data.append((video_tensor, image_sizes, '<video>'))
|
154 |
+
elif image is None and video is None:
|
155 |
+
data.append((None, None, '<text>'))
|
156 |
+
else:
|
157 |
+
raise NotImplementedError("Not support image and video at the same time")
|
158 |
+
except Exception as e:
|
159 |
+
traceback.print_exc()
|
160 |
+
return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot
|
161 |
+
|
162 |
+
assert len(message) % 2 == 0, "The message should be a pair of user and system message."
|
163 |
+
|
164 |
+
show_images = ""
|
165 |
+
if image is not None:
|
166 |
+
show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
|
167 |
+
if video is not None:
|
168 |
+
show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'
|
169 |
+
|
170 |
+
one_turn_chat = [textbox_in, None]
|
171 |
+
|
172 |
+
# 1. first run case
|
173 |
+
if len(chatbot) == 0:
|
174 |
+
one_turn_chat[0] += "\n" + show_images
|
175 |
+
# 2. not first run case
|
176 |
+
else:
|
177 |
+
# scanning the last image or video
|
178 |
+
length = len(chatbot)
|
179 |
+
for i in range(length - 1, -1, -1):
|
180 |
+
previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
|
181 |
+
previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
|
182 |
+
|
183 |
+
if len(previous_image) > 0:
|
184 |
+
previous_image = previous_image[-1]
|
185 |
+
# 2.1 new image append or pure text input will start a new conversation
|
186 |
+
if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
|
187 |
+
message.clear()
|
188 |
+
one_turn_chat[0] += "\n" + show_images
|
189 |
+
break
|
190 |
+
elif len(previous_video) > 0:
|
191 |
+
previous_video = previous_video[-1]
|
192 |
+
# 2.2 new video append or pure text input will start a new conversation
|
193 |
+
if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
|
194 |
+
message.clear()
|
195 |
+
one_turn_chat[0] += "\n" + show_images
|
196 |
+
break
|
197 |
+
|
198 |
+
message.append({'role': 'user', 'content': textbox_in})
|
199 |
+
text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
|
200 |
+
message.append({'role': 'assistant', 'content': text_en_out})
|
201 |
+
|
202 |
+
one_turn_chat[1] = text_en_out
|
203 |
+
chatbot.append(one_turn_chat)
|
204 |
+
|
205 |
+
return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot
|
206 |
+
|
207 |
+
|
208 |
+
def regenerate(message, chatbot):
|
209 |
+
message.pop(-1), message.pop(-1)
|
210 |
+
chatbot.pop(-1)
|
211 |
+
return message, chatbot
|
212 |
+
|
213 |
+
|
214 |
+
def clear_history(message, chatbot):
|
215 |
+
message.clear(), chatbot.clear()
|
216 |
+
return (gr.update(value=None, interactive=True),
|
217 |
+
gr.update(value=None, interactive=True),
|
218 |
+
message, chatbot,
|
219 |
+
gr.update(value=None, interactive=True))
|
220 |
+
|
221 |
+
handler = Chat()
|
222 |
+
|
223 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
224 |
+
|
225 |
+
theme = gr.themes.Default(primary_hue=plum_color)
|
226 |
+
# theme.update_color("primary", plum_color.c500)
|
227 |
+
theme.set(slider_color="#9C276A")
|
228 |
+
theme.set(block_title_text_color="#9C276A")
|
229 |
+
theme.set(block_label_text_color="#9C276A")
|
230 |
+
theme.set(button_primary_text_color="#9C276A")
|
231 |
+
|
232 |
+
with gr.Blocks(title='LongVU', theme=theme, css=block_css) as demo:
|
233 |
+
gr.Markdown(title_markdown)
|
234 |
+
message = gr.State([])
|
235 |
+
|
236 |
+
with gr.Row():
|
237 |
+
with gr.Column(scale=3):
|
238 |
+
image = gr.State(None)
|
239 |
+
video = gr.Video(label="Input Video")
|
240 |
+
|
241 |
+
with gr.Accordion("Parameters", open=True) as parameter_row:
|
242 |
+
|
243 |
+
temperature = gr.Slider(
|
244 |
+
minimum=0.1,
|
245 |
+
maximum=1.0,
|
246 |
+
value=0.2,
|
247 |
+
step=0.1,
|
248 |
+
interactive=True,
|
249 |
+
label="Temperature",
|
250 |
+
)
|
251 |
+
|
252 |
+
top_p = gr.Slider(
|
253 |
+
minimum=0.0,
|
254 |
+
maximum=1.0,
|
255 |
+
value=0.7,
|
256 |
+
step=0.1,
|
257 |
+
interactive=True,
|
258 |
+
label="Top P",
|
259 |
+
)
|
260 |
+
|
261 |
+
max_output_tokens = gr.Slider(
|
262 |
+
minimum=64,
|
263 |
+
maximum=512,
|
264 |
+
value=128,
|
265 |
+
step=64,
|
266 |
+
interactive=True,
|
267 |
+
label="Max output tokens",
|
268 |
+
)
|
269 |
+
|
270 |
+
with gr.Column(scale=7):
|
271 |
+
chatbot = gr.Chatbot(label="LongVU", bubble_full_width=True, height=420)
|
272 |
+
with gr.Row():
|
273 |
+
with gr.Column(scale=8):
|
274 |
+
textbox.render()
|
275 |
+
with gr.Column(scale=1, min_width=50):
|
276 |
+
submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
|
277 |
+
with gr.Row(elem_id="buttons") as button_row:
|
278 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
279 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
280 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
281 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
282 |
+
|
283 |
+
with gr.Row():
|
284 |
+
with gr.Column():
|
285 |
+
gr.Examples(
|
286 |
+
examples=[
|
287 |
+
[
|
288 |
+
f"./examples/video1.mp4",
|
289 |
+
"Describe this video in detail.",
|
290 |
+
],
|
291 |
+
[
|
292 |
+
f"./examples/video2.mp4",
|
293 |
+
"Which country does the boy in the video probably come from?",
|
294 |
+
]
|
295 |
+
],
|
296 |
+
inputs=[video, textbox],
|
297 |
+
)
|
298 |
+
|
299 |
+
submit_btn.click(
|
300 |
+
generate,
|
301 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
302 |
+
[image, video, message, chatbot])
|
303 |
+
|
304 |
+
regenerate_btn.click(
|
305 |
+
regenerate,
|
306 |
+
[message, chatbot],
|
307 |
+
[message, chatbot]).then(
|
308 |
+
generate,
|
309 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
310 |
+
[image, video, message, chatbot])
|
311 |
+
|
312 |
+
clear_btn.click(
|
313 |
+
clear_history,
|
314 |
+
[message, chatbot],
|
315 |
+
[image, video, message, chatbot, textbox])
|
316 |
+
|
317 |
+
demo.launch()
|
examples/video1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0431514403beba3c269a318b4e5eb6c08cd6940edb10d5014b4745c5fee31ac0
|
3 |
+
size 1171735
|
examples/video2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3069129741c1ed79524a3eeb138ce4f99a9b7fedf36652afd8ad7d83b1d6008b
|
3 |
+
size 1606730
|
inference.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from longvu.builder import load_pretrained_model
|
6 |
+
from longvu.constants import (
|
7 |
+
DEFAULT_IMAGE_TOKEN,
|
8 |
+
IMAGE_TOKEN_INDEX,
|
9 |
+
)
|
10 |
+
from longvu.conversation import conv_templates, SeparatorStyle
|
11 |
+
from longvu.mm_datautils import (
|
12 |
+
KeywordsStoppingCriteria,
|
13 |
+
process_images,
|
14 |
+
tokenizer_image_token,
|
15 |
+
)
|
16 |
+
from decord import cpu, VideoReader
|
17 |
+
|
18 |
+
version = "qwen"
|
19 |
+
model_name = "cambrian_qwen"
|
20 |
+
input_model_local_path = "./checkpoints/longvu_qwen"
|
21 |
+
|
22 |
+
device = "cuda:7"
|
23 |
+
|
24 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
25 |
+
input_model_local_path, None, model_name, device=device
|
26 |
+
)
|
27 |
+
model.get_model().config.tokenizer_model_max_length = 8192
|
28 |
+
model.get_model().config.inference_max_length = 128
|
29 |
+
model.config.use_cache = True
|
30 |
+
print(model.device)
|
31 |
+
|
32 |
+
model.eval()
|
33 |
+
|
34 |
+
video_path = "./examples/video1.mp4"
|
35 |
+
qs = "Describe this video in detail"
|
36 |
+
|
37 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
38 |
+
fps = float(vr.get_avg_fps())
|
39 |
+
frame_indices = np.array(
|
40 |
+
[
|
41 |
+
i
|
42 |
+
for i in range(
|
43 |
+
0,
|
44 |
+
len(vr),
|
45 |
+
round(fps),
|
46 |
+
)
|
47 |
+
]
|
48 |
+
)
|
49 |
+
video = []
|
50 |
+
for frame_index in frame_indices:
|
51 |
+
img = vr[frame_index].asnumpy()
|
52 |
+
video.append(img)
|
53 |
+
video = np.stack(video)
|
54 |
+
image_sizes = [video[0].shape[:2]]
|
55 |
+
video = process_images(video, image_processor, model.config)
|
56 |
+
video = [item.unsqueeze(0) for item in video]
|
57 |
+
|
58 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
59 |
+
|
60 |
+
conv = conv_templates[version].copy()
|
61 |
+
conv.append_message(conv.roles[0], qs)
|
62 |
+
conv.append_message(conv.roles[1], None)
|
63 |
+
prompt = conv.get_prompt()
|
64 |
+
|
65 |
+
input_ids = (
|
66 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
67 |
+
.unsqueeze(0)
|
68 |
+
.to(model.device)
|
69 |
+
)
|
70 |
+
|
71 |
+
if "llama3" in version:
|
72 |
+
input_ids = input_ids[0][1:].unsqueeze(0) # remove bos
|
73 |
+
|
74 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
75 |
+
keywords = [stop_str]
|
76 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
77 |
+
with torch.inference_mode():
|
78 |
+
output_ids = model.generate(
|
79 |
+
input_ids,
|
80 |
+
images=video,
|
81 |
+
image_sizes=image_sizes,
|
82 |
+
do_sample=False,
|
83 |
+
temperature=0.2,
|
84 |
+
max_new_tokens=128,
|
85 |
+
use_cache=True,
|
86 |
+
stopping_criteria=[stopping_criteria],
|
87 |
+
)
|
88 |
+
|
89 |
+
if isinstance(output_ids, tuple):
|
90 |
+
output_ids = output_ids[0]
|
91 |
+
|
92 |
+
pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
93 |
+
|
94 |
+
print("pred: ", pred, flush=True)
|
longvu/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
longvu/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
from .language_model.cambrian_qwen import CambrianQwenModel
|
3 |
+
from .language_model.cambrian_llama import CambrianLlamaModel
|
longvu/apply_delta.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
"""
|
3 |
+
Usage:
|
4 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
+
|
13 |
+
from . import LlavaLlamaForCausalLM
|
14 |
+
|
15 |
+
|
16 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
17 |
+
print("Loading base model")
|
18 |
+
base = AutoModelForCausalLM.from_pretrained(
|
19 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
20 |
+
)
|
21 |
+
|
22 |
+
print("Loading delta")
|
23 |
+
delta = LlavaLlamaForCausalLM.from_pretrained(
|
24 |
+
delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
25 |
+
)
|
26 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
27 |
+
|
28 |
+
print("Applying delta")
|
29 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
30 |
+
if name not in base.state_dict():
|
31 |
+
assert name in [
|
32 |
+
"model.mm_projector.weight",
|
33 |
+
"model.mm_projector.bias",
|
34 |
+
], f"{name} not in base model"
|
35 |
+
continue
|
36 |
+
if param.data.shape == base.state_dict()[name].shape:
|
37 |
+
param.data += base.state_dict()[name]
|
38 |
+
else:
|
39 |
+
assert name in [
|
40 |
+
"model.embed_tokens.weight",
|
41 |
+
"lm_head.weight",
|
42 |
+
], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
43 |
+
bparam = base.state_dict()[name]
|
44 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
|
45 |
+
|
46 |
+
print("Saving target model")
|
47 |
+
delta.save_pretrained(target_model_path)
|
48 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
parser = argparse.ArgumentParser()
|
53 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
54 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
55 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
56 |
+
|
57 |
+
args = parser.parse_args()
|
58 |
+
|
59 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
longvu/builder.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pyre-unsafe
|
16 |
+
|
17 |
+
|
18 |
+
import os
|
19 |
+
import shutil
|
20 |
+
import warnings
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from longvu.constants import (
|
24 |
+
DEFAULT_IM_END_TOKEN,
|
25 |
+
DEFAULT_IM_START_TOKEN,
|
26 |
+
DEFAULT_IMAGE_PATCH_TOKEN,
|
27 |
+
)
|
28 |
+
|
29 |
+
from longvu.language_model.cambrian_llama import CambrianLlamaForCausalLM
|
30 |
+
from longvu.language_model.cambrian_qwen import CambrianQwenForCausalLM
|
31 |
+
|
32 |
+
from transformers import (
|
33 |
+
AutoConfig,
|
34 |
+
AutoModelForCausalLM,
|
35 |
+
AutoTokenizer,
|
36 |
+
BitsAndBytesConfig,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def load_pretrained_model(
|
41 |
+
model_path,
|
42 |
+
model_base,
|
43 |
+
model_name,
|
44 |
+
load_8bit=False,
|
45 |
+
load_4bit=False,
|
46 |
+
device_map="auto",
|
47 |
+
device="cuda",
|
48 |
+
use_flash_attn=False,
|
49 |
+
model_args=None,
|
50 |
+
**kwargs,
|
51 |
+
):
|
52 |
+
kwargs = {"device_map": device_map, **kwargs}
|
53 |
+
|
54 |
+
if device != "cuda":
|
55 |
+
kwargs["device_map"] = {"": device}
|
56 |
+
|
57 |
+
if load_8bit:
|
58 |
+
kwargs["load_in_8bit"] = True
|
59 |
+
elif load_4bit:
|
60 |
+
kwargs["load_in_4bit"] = True
|
61 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
62 |
+
load_in_4bit=True,
|
63 |
+
bnb_4bit_compute_dtype=torch.float16,
|
64 |
+
bnb_4bit_use_double_quant=True,
|
65 |
+
bnb_4bit_quant_type="nf4",
|
66 |
+
)
|
67 |
+
else:
|
68 |
+
kwargs["torch_dtype"] = torch.float16
|
69 |
+
|
70 |
+
if use_flash_attn:
|
71 |
+
kwargs["attn_implementation"] = "flash_attention_2"
|
72 |
+
|
73 |
+
if "cambrian" in model_name.lower():
|
74 |
+
# Load Cambrian model
|
75 |
+
if "lora" in model_name.lower() and model_base is None:
|
76 |
+
warnings.warn(
|
77 |
+
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
|
78 |
+
)
|
79 |
+
if "lora" in model_name.lower() and model_base is not None:
|
80 |
+
# pyre-fixme[21]: Could not find module
|
81 |
+
# `core_ai.llava.language_model.cambrian_llama`.
|
82 |
+
from core_ai.llava.language_model.cambrian_llama import CambrianConfig
|
83 |
+
|
84 |
+
lora_cfg_pretrained = CambrianConfig.from_pretrained(model_path)
|
85 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
86 |
+
print("Loading Cambrian from base model...")
|
87 |
+
model = CambrianLlamaForCausalLM.from_pretrained(
|
88 |
+
model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
|
89 |
+
)
|
90 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
91 |
+
if model.lm_head.weight.shape[0] != token_num:
|
92 |
+
model.lm_head.weight = torch.nn.Parameter(
|
93 |
+
torch.empty(
|
94 |
+
token_num, tokem_dim, device=model.device, dtype=model.dtype
|
95 |
+
)
|
96 |
+
)
|
97 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(
|
98 |
+
torch.empty(
|
99 |
+
token_num, tokem_dim, device=model.device, dtype=model.dtype
|
100 |
+
)
|
101 |
+
)
|
102 |
+
|
103 |
+
print("Loading additional Cambrian weights...")
|
104 |
+
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
|
105 |
+
non_lora_trainables = torch.load(
|
106 |
+
os.path.join(model_path, "non_lora_trainables.bin"),
|
107 |
+
map_location="cpu",
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
# this is probably from HF Hub
|
111 |
+
from huggingface_hub import hf_hub_download
|
112 |
+
|
113 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
114 |
+
cache_file = hf_hub_download(
|
115 |
+
repo_id=repo_id, filename=filename, subfolder=subfolder
|
116 |
+
)
|
117 |
+
return torch.load(cache_file, map_location="cpu")
|
118 |
+
|
119 |
+
non_lora_trainables = load_from_hf(
|
120 |
+
model_path, "non_lora_trainables.bin"
|
121 |
+
)
|
122 |
+
non_lora_trainables = {
|
123 |
+
(k[11:] if k.startswith("base_model.") else k): v
|
124 |
+
for k, v in non_lora_trainables.items()
|
125 |
+
}
|
126 |
+
if any(k.startswith("model.model.") for k in non_lora_trainables):
|
127 |
+
non_lora_trainables = {
|
128 |
+
(k[6:] if k.startswith("model.") else k): v
|
129 |
+
for k, v in non_lora_trainables.items()
|
130 |
+
}
|
131 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
132 |
+
|
133 |
+
from peft import PeftModel
|
134 |
+
|
135 |
+
print("Loading LoRA weights...")
|
136 |
+
model = PeftModel.from_pretrained(model, model_path)
|
137 |
+
print("Merging LoRA weights...")
|
138 |
+
model = model.merge_and_unload()
|
139 |
+
print("Model is loaded...")
|
140 |
+
elif model_base is not None:
|
141 |
+
# this may be mm projector only
|
142 |
+
print(f"Loading Cambrian-1 from base model... {model_base}")
|
143 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
144 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
145 |
+
model = CambrianLlamaForCausalLM.from_pretrained(
|
146 |
+
model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
|
147 |
+
)
|
148 |
+
|
149 |
+
mm_projector_weights = torch.load(
|
150 |
+
os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
|
151 |
+
)
|
152 |
+
mm_projector_weights = {
|
153 |
+
k: v.to(torch.float16) for k, v in mm_projector_weights.items()
|
154 |
+
}
|
155 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
156 |
+
else:
|
157 |
+
if "qwen" in model_name.lower():
|
158 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
159 |
+
model = CambrianQwenForCausalLM.from_pretrained(
|
160 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
print(f"Loading Cambrian from {model_path}")
|
164 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
165 |
+
model = CambrianLlamaForCausalLM.from_pretrained(
|
166 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
# Load language model
|
170 |
+
if model_base is not None:
|
171 |
+
# PEFT model
|
172 |
+
from peft import PeftModel
|
173 |
+
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
175 |
+
model = AutoModelForCausalLM.from_pretrained(
|
176 |
+
model_base, low_cpu_mem_usage=True, **kwargs
|
177 |
+
)
|
178 |
+
print(f"Loading LoRA weights from {model_path}")
|
179 |
+
model = PeftModel.from_pretrained(model, model_path)
|
180 |
+
print(f"Merging weights")
|
181 |
+
model = model.merge_and_unload()
|
182 |
+
print("Convert to FP16...")
|
183 |
+
model.to(torch.float16)
|
184 |
+
else:
|
185 |
+
use_fast = False
|
186 |
+
if "mpt" in model_name.lower():
|
187 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
188 |
+
model = AutoModelForCausalLM.from_pretrained(
|
189 |
+
model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
193 |
+
model = AutoModelForCausalLM.from_pretrained(
|
194 |
+
model_path, low_cpu_mem_usage=True, **kwargs
|
195 |
+
)
|
196 |
+
|
197 |
+
image_processor = None
|
198 |
+
|
199 |
+
if "llava" in model_name.lower():
|
200 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
201 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
202 |
+
if mm_use_im_patch_token:
|
203 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
204 |
+
if mm_use_im_start_end:
|
205 |
+
tokenizer.add_tokens(
|
206 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
207 |
+
)
|
208 |
+
model.resize_token_embeddings(len(tokenizer))
|
209 |
+
|
210 |
+
vision_tower = model.get_vision_tower()
|
211 |
+
if not vision_tower.is_loaded:
|
212 |
+
try:
|
213 |
+
vision_tower.load_model(device_map=device_map)
|
214 |
+
except ValueError:
|
215 |
+
# ClipVisionTower doesn't support loading with device_map 'auto'
|
216 |
+
vision_tower.load_model()
|
217 |
+
vision_tower.to(device="cuda", dtype=torch.float16)
|
218 |
+
if device_map != "auto":
|
219 |
+
vision_tower.to(device=device_map, dtype=torch.float16)
|
220 |
+
image_processor = vision_tower.image_processor
|
221 |
+
elif "cambrian" in model_name.lower():
|
222 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
223 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
224 |
+
if mm_use_im_patch_token:
|
225 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
226 |
+
if mm_use_im_start_end:
|
227 |
+
tokenizer.add_tokens(
|
228 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
229 |
+
)
|
230 |
+
model.resize_token_embeddings(len(tokenizer))
|
231 |
+
|
232 |
+
vision_tower_aux_list = model.get_vision_tower_aux_list()
|
233 |
+
|
234 |
+
for vision_tower_aux in vision_tower_aux_list:
|
235 |
+
if not vision_tower_aux.is_loaded:
|
236 |
+
vision_tower_aux.load_model(device_map=device_map)
|
237 |
+
vision_tower_aux.to(device=device, dtype=torch.float16)
|
238 |
+
|
239 |
+
image_processor = [
|
240 |
+
vision_tower_aux.image_processor
|
241 |
+
for vision_tower_aux in vision_tower_aux_list
|
242 |
+
]
|
243 |
+
|
244 |
+
if hasattr(model.config, "max_sequence_length"):
|
245 |
+
context_len = model.config.max_sequence_length
|
246 |
+
else:
|
247 |
+
context_len = 2048
|
248 |
+
|
249 |
+
return tokenizer, model, image_processor, context_len
|
longvu/cambrian_arch.py
ADDED
@@ -0,0 +1,1705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import math
|
17 |
+
import random
|
18 |
+
from abc import ABC, abstractmethod
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from longvu.constants import (
|
25 |
+
DEFAULT_IM_END_TOKEN,
|
26 |
+
DEFAULT_IM_START_TOKEN,
|
27 |
+
DEFAULT_IMAGE_PATCH_TOKEN,
|
28 |
+
IGNORE_INDEX,
|
29 |
+
IMAGE_TOKEN_INDEX,
|
30 |
+
)
|
31 |
+
|
32 |
+
from .multimodal_encoder.builder import build_vision_tower_aux_list
|
33 |
+
from .multimodal_projector.builder import build_vision_projector
|
34 |
+
from .vision_sampler import VisionTokenSampler
|
35 |
+
|
36 |
+
IS_XLA_AVAILABLE = False
|
37 |
+
|
38 |
+
|
39 |
+
class CambrianMetaModel:
|
40 |
+
|
41 |
+
def __init__(self, config):
|
42 |
+
super(CambrianMetaModel, self).__init__(config)
|
43 |
+
|
44 |
+
if hasattr(config, "mm_vision_tower_aux_list"):
|
45 |
+
|
46 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
47 |
+
if projector_type == "sva":
|
48 |
+
|
49 |
+
vision_hidden_size = config.vision_hidden_size
|
50 |
+
num_query_group = config.num_query_group
|
51 |
+
query_num_list = config.query_num_list
|
52 |
+
connector_only = config.connector_only
|
53 |
+
connector_depth = config.connector_depth
|
54 |
+
self.vision_tower_aux_list = build_vision_tower_aux_list(
|
55 |
+
config, delay_load=True
|
56 |
+
)
|
57 |
+
self.mm_projector = nn.Sequential(
|
58 |
+
nn.Linear(vision_hidden_size * num_query_group, config.hidden_size),
|
59 |
+
nn.GELU(),
|
60 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
61 |
+
)
|
62 |
+
|
63 |
+
image_token_len = config.image_token_len
|
64 |
+
vision_tower_aux_token_len_list = (
|
65 |
+
self.config.mm_vision_tower_aux_token_len_list
|
66 |
+
)
|
67 |
+
cross_att_token_len_list = [
|
68 |
+
int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5)
|
69 |
+
for vision_tower_aux_token_len in vision_tower_aux_token_len_list
|
70 |
+
]
|
71 |
+
|
72 |
+
for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list):
|
73 |
+
setattr(
|
74 |
+
self,
|
75 |
+
"mm_projector_aux_{}".format(aux_i),
|
76 |
+
nn.Sequential(
|
77 |
+
nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
|
78 |
+
nn.GELU(),
|
79 |
+
nn.Linear(vision_hidden_size, vision_hidden_size),
|
80 |
+
nn.LayerNorm(vision_hidden_size),
|
81 |
+
),
|
82 |
+
)
|
83 |
+
|
84 |
+
for query_group_i in range(num_query_group):
|
85 |
+
cross_att_token_len_list = [
|
86 |
+
int(vision_tower_aux_token_len**0.5)
|
87 |
+
// int(query_num_list[query_group_i] ** 0.5)
|
88 |
+
for vision_tower_aux_token_len in vision_tower_aux_token_len_list
|
89 |
+
]
|
90 |
+
setattr(
|
91 |
+
self,
|
92 |
+
"vision_sampler_{}".format(query_group_i),
|
93 |
+
VisionTokenSampler(
|
94 |
+
vision_hidden_size,
|
95 |
+
vision_hidden_size,
|
96 |
+
[vision_hidden_size] * len(self.vision_tower_aux_list),
|
97 |
+
cross_att_token_len_list,
|
98 |
+
vision_hidden_size,
|
99 |
+
connector_depth,
|
100 |
+
),
|
101 |
+
)
|
102 |
+
|
103 |
+
if not connector_only:
|
104 |
+
num_of_vision_sampler_layers = (
|
105 |
+
config.num_of_vision_sampler_layers
|
106 |
+
) = config.num_of_vision_sampler_layers
|
107 |
+
config.start_of_vision_sampler_layers = (
|
108 |
+
config.start_of_vision_sampler_layers
|
109 |
+
)
|
110 |
+
config.stride_of_vision_sampler_layers = (
|
111 |
+
config.stride_of_vision_sampler_layers
|
112 |
+
)
|
113 |
+
cross_att_token_len_list = [
|
114 |
+
int(vision_tower_aux_token_len**0.5)
|
115 |
+
// int(image_token_len**0.5)
|
116 |
+
for vision_tower_aux_token_len in vision_tower_aux_token_len_list
|
117 |
+
]
|
118 |
+
self.vision_sampler_layers = nn.ModuleList(
|
119 |
+
[
|
120 |
+
VisionTokenSampler(
|
121 |
+
config.hidden_size,
|
122 |
+
vision_hidden_size,
|
123 |
+
[vision_hidden_size] * len(self.vision_tower_aux_list),
|
124 |
+
cross_att_token_len_list,
|
125 |
+
vision_hidden_size,
|
126 |
+
1,
|
127 |
+
)
|
128 |
+
for layer_idx in range(0, num_of_vision_sampler_layers)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
|
132 |
+
self.vision_query = nn.Parameter(
|
133 |
+
torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
|
134 |
+
)
|
135 |
+
|
136 |
+
self.image_newline = nn.Parameter(
|
137 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
138 |
+
)
|
139 |
+
|
140 |
+
self.frame_pos = torch.stack(
|
141 |
+
[
|
142 |
+
1
|
143 |
+
/ torch.pow(
|
144 |
+
torch.tensor(10000),
|
145 |
+
torch.tensor(2 * (hid_j // 2) / config.hidden_size),
|
146 |
+
)
|
147 |
+
for hid_j in range(config.hidden_size)
|
148 |
+
]
|
149 |
+
)
|
150 |
+
|
151 |
+
else:
|
152 |
+
self.vision_tower_aux_list = build_vision_tower_aux_list(
|
153 |
+
config, delay_load=True
|
154 |
+
)
|
155 |
+
config.mm_hidden_size = sum(
|
156 |
+
[
|
157 |
+
vision_tower_aux.hidden_size
|
158 |
+
for vision_tower_aux in self.vision_tower_aux_list
|
159 |
+
]
|
160 |
+
)
|
161 |
+
self.mm_projector = build_vision_projector(config)
|
162 |
+
self.image_newline = nn.Parameter(
|
163 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
164 |
+
)
|
165 |
+
|
166 |
+
def get_frame_pos(self, time_range):
|
167 |
+
frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
|
168 |
+
self.frame_pos.device
|
169 |
+
)
|
170 |
+
frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
|
171 |
+
frame_pos[:, 1::2] = torch.cos(frame_pos[:, 0::2])
|
172 |
+
frame_pos = frame_pos.unsqueeze(1)
|
173 |
+
return frame_pos
|
174 |
+
|
175 |
+
# def get_vision_tower(self):
|
176 |
+
# vision_tower = getattr(self, 'vision_tower', None)
|
177 |
+
# if type(vision_tower) is list:
|
178 |
+
# vision_tower = vision_tower[0]
|
179 |
+
# return vision_tower
|
180 |
+
|
181 |
+
def get_vision_tower_aux_list(self):
|
182 |
+
vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
|
183 |
+
return vision_tower_aux_list
|
184 |
+
|
185 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
186 |
+
# vision_tower = model_args.vision_tower
|
187 |
+
num_query_group = model_args.num_query_group
|
188 |
+
query_num_list = model_args.query_num_list
|
189 |
+
vision_hidden_size = model_args.vision_hidden_size
|
190 |
+
vision_tower_aux_list = model_args.vision_tower_aux_list
|
191 |
+
vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
|
192 |
+
image_token_len = model_args.image_token_len
|
193 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
194 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
195 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
196 |
+
connector_only = model_args.connector_only
|
197 |
+
connector_depth = model_args.connector_depth
|
198 |
+
|
199 |
+
# self.config.mm_vision_tower = vision_tower
|
200 |
+
self.config.image_token_len = image_token_len
|
201 |
+
self.config.num_query_group = num_query_group
|
202 |
+
self.config.query_num_list = query_num_list
|
203 |
+
assert num_query_group == len(query_num_list)
|
204 |
+
self.config.connector_depth = connector_depth
|
205 |
+
self.config.mm_vision_tower_aux_list = vision_tower_aux_list
|
206 |
+
self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
|
207 |
+
self.config.connector_only = connector_only
|
208 |
+
self.config.highres_connect = model_args.highres_connect
|
209 |
+
self.config.highres = model_args.highres
|
210 |
+
self.config.frame_pos = model_args.frame_pos
|
211 |
+
self.config.lowres_token = model_args.lowres_token
|
212 |
+
self.config.connect_layer = model_args.connect_layer
|
213 |
+
self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83)
|
214 |
+
self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6)
|
215 |
+
self.config.is_image_newline = getattr(model_args, "is_image_newline", True)
|
216 |
+
|
217 |
+
if self.get_vision_tower_aux_list() is None:
|
218 |
+
vision_tower_aux_list = build_vision_tower_aux_list(model_args)
|
219 |
+
if model_args.unfreeze_mm_vision_tower:
|
220 |
+
self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
|
221 |
+
else:
|
222 |
+
self.vision_tower_aux_list = vision_tower_aux_list
|
223 |
+
else:
|
224 |
+
vision_tower_aux_list = self.vision_tower_aux_list
|
225 |
+
for vision_tower_aux in vision_tower_aux_list:
|
226 |
+
vision_tower_aux.load_model()
|
227 |
+
|
228 |
+
self.config.use_mm_proj = True
|
229 |
+
self.config.mm_projector_type = getattr(
|
230 |
+
model_args, "mm_projector_type", "linear"
|
231 |
+
)
|
232 |
+
self.config.vision_hidden_size = vision_hidden_size
|
233 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
234 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
235 |
+
|
236 |
+
if getattr(self, "mm_projector", None) is None:
|
237 |
+
|
238 |
+
if self.config.mm_projector_type == "sva":
|
239 |
+
self.mm_projector = nn.Sequential(
|
240 |
+
nn.Linear(
|
241 |
+
vision_hidden_size * num_query_group, self.config.hidden_size
|
242 |
+
),
|
243 |
+
nn.GELU(),
|
244 |
+
nn.Linear(self.config.hidden_size, self.config.hidden_size),
|
245 |
+
)
|
246 |
+
for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list):
|
247 |
+
setattr(
|
248 |
+
self,
|
249 |
+
"mm_projector_aux_{}".format(aux_i),
|
250 |
+
nn.Sequential(
|
251 |
+
nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
|
252 |
+
nn.GELU(),
|
253 |
+
nn.Linear(vision_hidden_size, vision_hidden_size),
|
254 |
+
nn.LayerNorm(vision_hidden_size),
|
255 |
+
),
|
256 |
+
)
|
257 |
+
|
258 |
+
# vision sampler for each group of query as the connector before the LLM
|
259 |
+
for query_group_i in range(num_query_group):
|
260 |
+
cross_att_token_len_list = [
|
261 |
+
int(vision_tower_aux_token_len**0.5)
|
262 |
+
// int(query_num_list[query_group_i] ** 0.5)
|
263 |
+
for vision_tower_aux_token_len in vision_tower_aux_token_len_list
|
264 |
+
]
|
265 |
+
setattr(
|
266 |
+
self,
|
267 |
+
"vision_sampler_{}".format(query_group_i),
|
268 |
+
VisionTokenSampler(
|
269 |
+
vision_hidden_size,
|
270 |
+
vision_hidden_size,
|
271 |
+
[vision_hidden_size] * len(vision_tower_aux_list),
|
272 |
+
cross_att_token_len_list,
|
273 |
+
vision_hidden_size,
|
274 |
+
connector_depth,
|
275 |
+
),
|
276 |
+
)
|
277 |
+
|
278 |
+
# sampler layers within LLM
|
279 |
+
if not connector_only:
|
280 |
+
num_of_vision_sampler_layers = (
|
281 |
+
self.config.num_of_vision_sampler_layers
|
282 |
+
) = model_args.num_of_vision_sampler_layers
|
283 |
+
self.config.start_of_vision_sampler_layers = (
|
284 |
+
model_args.start_of_vision_sampler_layers
|
285 |
+
)
|
286 |
+
self.config.stride_of_vision_sampler_layers = (
|
287 |
+
model_args.stride_of_vision_sampler_layers
|
288 |
+
)
|
289 |
+
cross_att_token_len_list = [
|
290 |
+
int(vision_tower_aux_token_len**0.5)
|
291 |
+
// int(image_token_len**0.5)
|
292 |
+
for vision_tower_aux_token_len in vision_tower_aux_token_len_list
|
293 |
+
]
|
294 |
+
self.vision_sampler_layers = nn.ModuleList(
|
295 |
+
[
|
296 |
+
VisionTokenSampler(
|
297 |
+
self.config.hidden_size,
|
298 |
+
vision_hidden_size,
|
299 |
+
[vision_hidden_size] * len(vision_tower_aux_list),
|
300 |
+
cross_att_token_len_list,
|
301 |
+
vision_hidden_size,
|
302 |
+
1,
|
303 |
+
)
|
304 |
+
for layer_idx in range(0, num_of_vision_sampler_layers)
|
305 |
+
]
|
306 |
+
)
|
307 |
+
vision_embed_std = 1 / torch.sqrt(
|
308 |
+
torch.tensor(vision_hidden_size, dtype=self.dtype)
|
309 |
+
)
|
310 |
+
self.vision_query = nn.Parameter(
|
311 |
+
torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
|
312 |
+
* vision_embed_std
|
313 |
+
)
|
314 |
+
|
315 |
+
embed_std = 1 / torch.sqrt(
|
316 |
+
torch.tensor(self.config.hidden_size, dtype=self.dtype)
|
317 |
+
)
|
318 |
+
self.image_newline = nn.Parameter(
|
319 |
+
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
320 |
+
)
|
321 |
+
|
322 |
+
else:
|
323 |
+
self.config.mm_hidden_size = sum(
|
324 |
+
[
|
325 |
+
vision_tower_aux.hidden_size
|
326 |
+
for vision_tower_aux in vision_tower_aux_list
|
327 |
+
]
|
328 |
+
)
|
329 |
+
self.mm_projector = build_vision_projector(self.config)
|
330 |
+
embed_std = 1 / torch.sqrt(
|
331 |
+
torch.tensor(self.config.hidden_size, dtype=self.dtype)
|
332 |
+
)
|
333 |
+
self.image_newline = nn.Parameter(
|
334 |
+
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
# In case it is frozen by LoRA
|
338 |
+
for p in self.mm_projector.parameters():
|
339 |
+
p.requires_grad = True
|
340 |
+
|
341 |
+
if pretrain_mm_mlp_adapter is not None:
|
342 |
+
mm_projector_weights = torch.load(
|
343 |
+
pretrain_mm_mlp_adapter, map_location="cpu"
|
344 |
+
)
|
345 |
+
|
346 |
+
def get_w(weights, keyword):
|
347 |
+
return {
|
348 |
+
k.split(keyword + ".")[1]: v
|
349 |
+
for k, v in weights.items()
|
350 |
+
if keyword + "." in k
|
351 |
+
}
|
352 |
+
|
353 |
+
self.mm_projector.load_state_dict(
|
354 |
+
get_w(mm_projector_weights, "mm_projector"), strict=True
|
355 |
+
)
|
356 |
+
|
357 |
+
if self.config.mm_projector_type == "sva":
|
358 |
+
for aux_i in range(len(vision_tower_aux_list)):
|
359 |
+
getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict(
|
360 |
+
get_w(
|
361 |
+
mm_projector_weights, "mm_projector_aux_{}".format(aux_i)
|
362 |
+
),
|
363 |
+
strict=True,
|
364 |
+
)
|
365 |
+
|
366 |
+
for query_group_i in range(num_query_group):
|
367 |
+
getattr(
|
368 |
+
self, "vision_sampler_{}".format(query_group_i)
|
369 |
+
).load_state_dict(
|
370 |
+
get_w(
|
371 |
+
mm_projector_weights,
|
372 |
+
"vision_sampler_{}".format(query_group_i),
|
373 |
+
),
|
374 |
+
strict=True,
|
375 |
+
)
|
376 |
+
|
377 |
+
if not connector_only:
|
378 |
+
self.vision_sampler_layers.load_state_dict(
|
379 |
+
get_w(mm_projector_weights, "vision_sampler_layers"),
|
380 |
+
strict=True,
|
381 |
+
)
|
382 |
+
self.vision_query.data = mm_projector_weights["model.vision_query"]
|
383 |
+
self.image_newline.data = mm_projector_weights["model.image_newline"]
|
384 |
+
|
385 |
+
|
386 |
+
def unmask_attention_mask(mask, original_size):
|
387 |
+
original_w, original_h = original_size
|
388 |
+
cur_h, cur_w = mask.shape[1:3]
|
389 |
+
|
390 |
+
original_aspect_ratio = original_w / original_h
|
391 |
+
current_aspect_ratio = cur_w / cur_h
|
392 |
+
|
393 |
+
if original_aspect_ratio > current_aspect_ratio:
|
394 |
+
scale_factor = cur_w / original_w
|
395 |
+
new_height = int(original_h * scale_factor)
|
396 |
+
padding = (cur_h - new_height) // 2
|
397 |
+
if padding > 0:
|
398 |
+
mask[:, :padding, :] = 0
|
399 |
+
mask[:, -padding:, :] = 0
|
400 |
+
return mask
|
401 |
+
else:
|
402 |
+
scale_factor = cur_h / original_h
|
403 |
+
new_width = int(original_w * scale_factor)
|
404 |
+
padding = (cur_w - new_width) // 2
|
405 |
+
if padding > 0:
|
406 |
+
mask[:, :, :padding] = 0
|
407 |
+
mask[:, :, -padding:] = 0
|
408 |
+
return mask
|
409 |
+
|
410 |
+
|
411 |
+
def unpad_image(tensor, original_size):
|
412 |
+
"""
|
413 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
417 |
+
original_size (tuple): The original size of the image (height, width).
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
torch.Tensor: The unpadded image tensor.
|
421 |
+
"""
|
422 |
+
original_width, original_height = original_size
|
423 |
+
current_height, current_width = tensor.shape[1:3]
|
424 |
+
|
425 |
+
original_aspect_ratio = original_width / original_height
|
426 |
+
current_aspect_ratio = current_width / current_height
|
427 |
+
|
428 |
+
if original_aspect_ratio > current_aspect_ratio:
|
429 |
+
scale_factor = current_width / original_width
|
430 |
+
new_height = int(original_height * scale_factor)
|
431 |
+
padding = (current_height - new_height) // 2
|
432 |
+
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
433 |
+
# if 0 in unpadded_tensor.shape:
|
434 |
+
# print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
|
435 |
+
else:
|
436 |
+
scale_factor = current_height / original_height
|
437 |
+
new_width = int(original_width * scale_factor)
|
438 |
+
padding = (current_width - new_width) // 2
|
439 |
+
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
440 |
+
# if 0 in unpadded_tensor.shape:
|
441 |
+
# print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
|
442 |
+
|
443 |
+
return unpadded_tensor
|
444 |
+
|
445 |
+
|
446 |
+
class CambrianMetaForCausalLM(ABC):
|
447 |
+
|
448 |
+
@abstractmethod
|
449 |
+
def get_model(self):
|
450 |
+
pass
|
451 |
+
|
452 |
+
# def get_vision_tower(self):
|
453 |
+
# return self.get_model().get_vision_tower()
|
454 |
+
|
455 |
+
def get_vision_tower_aux_list(self):
|
456 |
+
return self.get_model().get_vision_tower_aux_list()
|
457 |
+
|
458 |
+
def rearrange_vision_tower_features_train(
|
459 |
+
self,
|
460 |
+
vision_tower_aux_feature_list,
|
461 |
+
vision_tower_aux_attention_masks_list,
|
462 |
+
query_side_len,
|
463 |
+
):
|
464 |
+
vision_tower_aux_feature_rearranged_list = []
|
465 |
+
vision_tower_aux_attention_masks_rearranged_list = []
|
466 |
+
bs = vision_tower_aux_feature_list[0].shape[0]
|
467 |
+
for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip(
|
468 |
+
vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list
|
469 |
+
):
|
470 |
+
aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
|
471 |
+
assert (aux_height // query_side_len) * query_side_len == aux_height
|
472 |
+
|
473 |
+
reduce_factor = aux_height // query_side_len
|
474 |
+
vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view(
|
475 |
+
bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1
|
476 |
+
)
|
477 |
+
vision_tower_aux_feature_rearranged = (
|
478 |
+
vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5)
|
479 |
+
.contiguous()
|
480 |
+
.flatten(0, 2)
|
481 |
+
.flatten(1, 2)
|
482 |
+
)
|
483 |
+
|
484 |
+
vision_tower_aux_attention_masks_rearranged = (
|
485 |
+
vision_tower_aux_attention_masks.view(
|
486 |
+
bs * query_side_len * query_side_len, reduce_factor * reduce_factor
|
487 |
+
)
|
488 |
+
)
|
489 |
+
|
490 |
+
vision_tower_aux_feature_rearranged_list.append(
|
491 |
+
vision_tower_aux_feature_rearranged
|
492 |
+
)
|
493 |
+
vision_tower_aux_attention_masks_rearranged_list.append(
|
494 |
+
vision_tower_aux_attention_masks_rearranged
|
495 |
+
)
|
496 |
+
return (
|
497 |
+
vision_tower_aux_feature_rearranged_list,
|
498 |
+
vision_tower_aux_attention_masks_rearranged_list,
|
499 |
+
)
|
500 |
+
|
501 |
+
def rearrange_vision_tower_features_inference(
|
502 |
+
self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False
|
503 |
+
):
|
504 |
+
vision_tower_aux_feature_rearranged_list = []
|
505 |
+
vision_tower_aux_attention_masks_rearranged_list = []
|
506 |
+
bs = vision_tower_aux_feature_list[0].shape[0]
|
507 |
+
for vision_tower_aux_feature in vision_tower_aux_feature_list:
|
508 |
+
aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
|
509 |
+
assert (aux_height // query_side_len) * query_side_len == aux_height
|
510 |
+
|
511 |
+
reduce_factor = aux_height // query_side_len
|
512 |
+
|
513 |
+
vision_tower_aux_feature_rearranged = []
|
514 |
+
vision_tower_aux_attention_masks_rearranged = []
|
515 |
+
for batch_i in range(bs):
|
516 |
+
image_size = image_sizes[batch_i]
|
517 |
+
cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i]
|
518 |
+
|
519 |
+
cur_vision_tower_aux_attention_masks_rearranged = torch.ones(
|
520 |
+
(1, aux_height, aux_width),
|
521 |
+
dtype=torch.bool,
|
522 |
+
device=cur_vision_tower_aux_feature.device,
|
523 |
+
)
|
524 |
+
cur_vision_tower_aux_feature_rearranged = (
|
525 |
+
cur_vision_tower_aux_feature.view(
|
526 |
+
1,
|
527 |
+
query_side_len,
|
528 |
+
reduce_factor,
|
529 |
+
query_side_len,
|
530 |
+
reduce_factor,
|
531 |
+
-1,
|
532 |
+
)
|
533 |
+
)
|
534 |
+
cur_vision_tower_aux_feature_rearranged = (
|
535 |
+
cur_vision_tower_aux_feature_rearranged.permute(
|
536 |
+
0, 1, 3, 2, 4, 5
|
537 |
+
).contiguous()
|
538 |
+
)
|
539 |
+
if unpad:
|
540 |
+
cur_vision_tower_aux_feature_rearranged = unpad_image(
|
541 |
+
cur_vision_tower_aux_feature_rearranged, image_size
|
542 |
+
)
|
543 |
+
cur_vision_tower_aux_feature_rearranged = (
|
544 |
+
cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2)
|
545 |
+
) # query_side_len*query_side_len X reduce_factor*reduce_factor X C
|
546 |
+
|
547 |
+
cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask(
|
548 |
+
cur_vision_tower_aux_attention_masks_rearranged, image_size
|
549 |
+
)
|
550 |
+
cur_vision_tower_aux_attention_masks_rearranged = (
|
551 |
+
cur_vision_tower_aux_attention_masks_rearranged.view(
|
552 |
+
1, query_side_len, reduce_factor, query_side_len, reduce_factor
|
553 |
+
)
|
554 |
+
.permute(0, 1, 3, 2, 4)
|
555 |
+
.contiguous()
|
556 |
+
)
|
557 |
+
if unpad:
|
558 |
+
cur_vision_tower_aux_attention_masks_rearranged = unpad_image(
|
559 |
+
cur_vision_tower_aux_attention_masks_rearranged, image_size
|
560 |
+
)
|
561 |
+
cur_vision_tower_aux_attention_masks_rearranged = (
|
562 |
+
cur_vision_tower_aux_attention_masks_rearranged.flatten(
|
563 |
+
0, 2
|
564 |
+
).flatten(1, 2)
|
565 |
+
)
|
566 |
+
|
567 |
+
cur_vision_tower_aux_attention_masks_rearranged[
|
568 |
+
cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0
|
569 |
+
] = True
|
570 |
+
|
571 |
+
vision_tower_aux_feature_rearranged.append(
|
572 |
+
cur_vision_tower_aux_feature_rearranged
|
573 |
+
)
|
574 |
+
vision_tower_aux_attention_masks_rearranged.append(
|
575 |
+
cur_vision_tower_aux_attention_masks_rearranged
|
576 |
+
)
|
577 |
+
|
578 |
+
vision_tower_aux_feature_rearranged = torch.cat(
|
579 |
+
vision_tower_aux_feature_rearranged, 0
|
580 |
+
)
|
581 |
+
vision_tower_aux_attention_masks_rearranged = torch.cat(
|
582 |
+
vision_tower_aux_attention_masks_rearranged, 0
|
583 |
+
)
|
584 |
+
|
585 |
+
vision_tower_aux_feature_rearranged_list.append(
|
586 |
+
vision_tower_aux_feature_rearranged
|
587 |
+
)
|
588 |
+
vision_tower_aux_attention_masks_rearranged_list.append(
|
589 |
+
vision_tower_aux_attention_masks_rearranged
|
590 |
+
)
|
591 |
+
|
592 |
+
return (
|
593 |
+
vision_tower_aux_feature_rearranged_list,
|
594 |
+
vision_tower_aux_attention_masks_rearranged_list,
|
595 |
+
)
|
596 |
+
|
597 |
+
def encode_images(self, image_aux_list, encode_type=None):
|
598 |
+
vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
|
599 |
+
image_aux_features_list = []
|
600 |
+
chunk_size = 64
|
601 |
+
if encode_type == "dino":
|
602 |
+
image_aux = image_aux_list[-1]
|
603 |
+
vision_tower_aux = vision_tower_aux_list[-1]
|
604 |
+
if image_aux.shape[0] > chunk_size:
|
605 |
+
image_aux_features_chunks = []
|
606 |
+
for start_idx in range(0, image_aux.shape[0], chunk_size):
|
607 |
+
end_idx = min(start_idx + chunk_size, image_aux.shape[0])
|
608 |
+
chunk = image_aux[start_idx:end_idx]
|
609 |
+
image_aux_features_chunk = vision_tower_aux(chunk)
|
610 |
+
image_aux_features_chunks.append(image_aux_features_chunk)
|
611 |
+
image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
|
612 |
+
else:
|
613 |
+
image_aux_features = vision_tower_aux(image_aux)
|
614 |
+
return image_aux_features
|
615 |
+
elif encode_type == "siglip":
|
616 |
+
image_aux = image_aux_list[0]
|
617 |
+
vision_tower_aux = vision_tower_aux_list[0]
|
618 |
+
if image_aux.shape[0] > chunk_size:
|
619 |
+
image_aux_features_chunks = []
|
620 |
+
for start_idx in range(0, image_aux.shape[0], chunk_size):
|
621 |
+
end_idx = min(start_idx + chunk_size, image_aux.shape[0])
|
622 |
+
chunk = image_aux[start_idx:end_idx]
|
623 |
+
image_aux_features_chunk = vision_tower_aux(chunk)
|
624 |
+
image_aux_features_chunks.append(image_aux_features_chunk)
|
625 |
+
image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
|
626 |
+
else:
|
627 |
+
image_aux_features = vision_tower_aux(image_aux)
|
628 |
+
return image_aux_features
|
629 |
+
else:
|
630 |
+
for image_aux, vision_tower_aux in zip(
|
631 |
+
image_aux_list, vision_tower_aux_list
|
632 |
+
):
|
633 |
+
if image_aux.shape[0] > chunk_size:
|
634 |
+
image_aux_features_chunks = []
|
635 |
+
for start_idx in range(0, image_aux.shape[0], chunk_size):
|
636 |
+
end_idx = min(start_idx + chunk_size, image_aux.shape[0])
|
637 |
+
chunk = image_aux[start_idx:end_idx]
|
638 |
+
image_aux_features_chunk = vision_tower_aux(chunk)
|
639 |
+
image_aux_features_chunks.append(image_aux_features_chunk)
|
640 |
+
image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
|
641 |
+
else:
|
642 |
+
image_aux_features = vision_tower_aux(image_aux)
|
643 |
+
image_aux_features_list.append(image_aux_features)
|
644 |
+
return image_aux_features_list
|
645 |
+
|
646 |
+
def select_frame(
|
647 |
+
self,
|
648 |
+
feature_list,
|
649 |
+
split_sizes,
|
650 |
+
input_ids,
|
651 |
+
new_image_aux_list,
|
652 |
+
image_sizes,
|
653 |
+
window_size=16,
|
654 |
+
threshold=0.83,
|
655 |
+
):
|
656 |
+
dino_features_batch = torch.split(feature_list, split_sizes, dim=0)
|
657 |
+
new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0)
|
658 |
+
new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0)
|
659 |
+
new_split_sizes = []
|
660 |
+
selected_frames_all_0 = []
|
661 |
+
selected_frames_all_1 = []
|
662 |
+
selected_frames_feature_all = []
|
663 |
+
selected_frame_indices_all = []
|
664 |
+
for i_batch, frame_features in enumerate(dino_features_batch):
|
665 |
+
try:
|
666 |
+
if "llama" in self.get_model().config.model_type:
|
667 |
+
text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
|
668 |
+
else:
|
669 |
+
text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
|
670 |
+
except:
|
671 |
+
text_len = len(input_ids[i_batch])
|
672 |
+
original_width, original_height = image_sizes[i_batch]
|
673 |
+
if getattr(self.get_model().config, "highres", False):
|
674 |
+
token_per_frame = self.get_model().config.lowres_token ** 2
|
675 |
+
else:
|
676 |
+
token_per_frame = self.get_model().config.image_token_len
|
677 |
+
# current_height, current_width = token_per_side, token_per_side
|
678 |
+
# original_aspect_ratio = original_width / original_height
|
679 |
+
# current_aspect_ratio = current_width / current_height
|
680 |
+
# if original_aspect_ratio > current_aspect_ratio:
|
681 |
+
# scale_factor = current_width / original_width
|
682 |
+
# new_height = int(original_height * scale_factor)
|
683 |
+
# padding = math.ceil((current_height - new_height) / 2.0)
|
684 |
+
# token_per_frame = (
|
685 |
+
# current_height - padding * 2
|
686 |
+
# ) * token_per_side + token_per_side
|
687 |
+
# else:
|
688 |
+
# scale_factor = current_height / original_height
|
689 |
+
# new_width = int(original_width * scale_factor)
|
690 |
+
# padding = math.ceil((current_width - new_width) / 2.0)
|
691 |
+
# token_per_frame = (current_width - padding * 2) * token_per_side + (
|
692 |
+
# current_width - padding * 2
|
693 |
+
# )
|
694 |
+
# token_per_frame = (
|
695 |
+
# token_per_side**2 if token_per_frame < 1 else token_per_frame
|
696 |
+
# )
|
697 |
+
max_num_frames = max(
|
698 |
+
1,
|
699 |
+
(
|
700 |
+
self.get_model().config.tokenizer_model_max_length
|
701 |
+
- text_len
|
702 |
+
- getattr(self.get_model().config, "inference_max_length", 16)
|
703 |
+
)
|
704 |
+
// token_per_frame,
|
705 |
+
)
|
706 |
+
if len(frame_features) < max_num_frames:
|
707 |
+
selected_frames_all_0.append(new_image_aux_batch_0[i_batch])
|
708 |
+
selected_frames_all_1.append(new_image_aux_batch_1[i_batch])
|
709 |
+
selected_frames_feature_all.append(frame_features)
|
710 |
+
new_split_sizes.append(len(frame_features))
|
711 |
+
selected_frame_indices_all.append(torch.arange(len(frame_features)))
|
712 |
+
continue
|
713 |
+
|
714 |
+
num_segments = len(frame_features) // window_size
|
715 |
+
if num_segments == 0:
|
716 |
+
query_feature = frame_features.flatten(1, 2)
|
717 |
+
query_feature = query_feature / torch.norm(
|
718 |
+
(query_feature), dim=1, keepdim=True
|
719 |
+
)
|
720 |
+
similarities = torch.mean(query_feature @ query_feature.T, dim=1)
|
721 |
+
similarities[len(frame_features) // 2] = 0
|
722 |
+
indices = torch.where(similarities < threshold)[0]
|
723 |
+
selected_frame_indices_all.append(indices)
|
724 |
+
selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices])
|
725 |
+
selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices])
|
726 |
+
selected_frames_feature_all.append(frame_features[indices])
|
727 |
+
new_split_sizes.append(len(indices))
|
728 |
+
continue
|
729 |
+
segments_frames_0 = []
|
730 |
+
segments_frames_1 = []
|
731 |
+
segments_features = []
|
732 |
+
for start_idx in range(0, len(frame_features), window_size):
|
733 |
+
end_idx = min(start_idx + window_size, len(frame_features))
|
734 |
+
segments_frames_0.append(
|
735 |
+
new_image_aux_batch_0[i_batch][start_idx:end_idx]
|
736 |
+
)
|
737 |
+
segments_frames_1.append(
|
738 |
+
new_image_aux_batch_1[i_batch][start_idx:end_idx]
|
739 |
+
)
|
740 |
+
segments_features.append(frame_features[start_idx:end_idx])
|
741 |
+
selected_frames_0 = []
|
742 |
+
selected_frames_1 = []
|
743 |
+
selected_features = []
|
744 |
+
selected_frame_indices = []
|
745 |
+
for i, segment in enumerate(segments_features):
|
746 |
+
query_feature = segment.flatten(1, 2)
|
747 |
+
query_feature = query_feature / torch.norm(
|
748 |
+
(query_feature), dim=1, keepdim=True
|
749 |
+
)
|
750 |
+
similarities = torch.mean(query_feature @ query_feature.T, dim=1)
|
751 |
+
similarities[len(segment) // 2] = 0
|
752 |
+
indices = torch.where(similarities < threshold)[0]
|
753 |
+
selected_frames_0.append(segments_frames_0[i][indices])
|
754 |
+
selected_frames_1.append(segments_frames_1[i][indices])
|
755 |
+
selected_features.append(segment[indices])
|
756 |
+
selected_frame_indices.extend(indices + i * window_size)
|
757 |
+
selected_frames_0 = torch.cat(selected_frames_0, dim=0)
|
758 |
+
selected_frames_1 = torch.cat(selected_frames_1, dim=0)
|
759 |
+
selected_features = torch.cat(selected_features, dim=0)
|
760 |
+
selected_frame_indices = torch.tensor(selected_frame_indices)
|
761 |
+
# ablation
|
762 |
+
max_num_frames = 400 # in case of OOM
|
763 |
+
if len(selected_frames_0) > max_num_frames:
|
764 |
+
interval = len(selected_frames_0) / float(max_num_frames)
|
765 |
+
indices = [int(interval * i) for i in range(max_num_frames)]
|
766 |
+
new_split_sizes.append(len(indices))
|
767 |
+
selected_frames_all_0.append(selected_frames_0[indices])
|
768 |
+
selected_frames_all_1.append(selected_frames_1[indices])
|
769 |
+
selected_frames_feature_all.append(selected_features[indices])
|
770 |
+
selected_frame_indices = selected_frame_indices[indices]
|
771 |
+
else:
|
772 |
+
new_split_sizes.append(len(selected_frames_0))
|
773 |
+
selected_frames_all_0.append(selected_frames_0)
|
774 |
+
selected_frames_all_1.append(selected_frames_1)
|
775 |
+
selected_frames_feature_all.append(selected_features)
|
776 |
+
selected_frame_indices_all.append(selected_frame_indices)
|
777 |
+
selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0)
|
778 |
+
selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0)
|
779 |
+
selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0)
|
780 |
+
return (
|
781 |
+
selected_frames_feature_all,
|
782 |
+
new_split_sizes,
|
783 |
+
[selected_frames_all_0, selected_frames_all_1],
|
784 |
+
selected_frame_indices_all,
|
785 |
+
)
|
786 |
+
|
787 |
+
def prepare_inputs_labels_for_multimodal(
|
788 |
+
self,
|
789 |
+
input_ids,
|
790 |
+
position_ids,
|
791 |
+
attention_mask,
|
792 |
+
past_key_values,
|
793 |
+
labels,
|
794 |
+
images,
|
795 |
+
image_aux_attention_masks_list=None,
|
796 |
+
image_sizes=None,
|
797 |
+
):
|
798 |
+
# vision_tower = self.get_vision_tower()
|
799 |
+
vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
|
800 |
+
if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1:
|
801 |
+
return (
|
802 |
+
input_ids,
|
803 |
+
position_ids,
|
804 |
+
attention_mask,
|
805 |
+
past_key_values,
|
806 |
+
None,
|
807 |
+
labels,
|
808 |
+
None,
|
809 |
+
None,
|
810 |
+
None,
|
811 |
+
None,
|
812 |
+
)
|
813 |
+
|
814 |
+
image_aux_list = images
|
815 |
+
|
816 |
+
split_sizes = None
|
817 |
+
|
818 |
+
if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5:
|
819 |
+
split_sizes_ori = [
|
820 |
+
1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0]
|
821 |
+
]
|
822 |
+
new_image_aux_list = []
|
823 |
+
for image_aux in image_aux_list:
|
824 |
+
if type(image_aux) is list:
|
825 |
+
image_aux = [
|
826 |
+
x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux
|
827 |
+
]
|
828 |
+
concat_image_aux = torch.cat([image for image in image_aux], dim=0)
|
829 |
+
new_image_aux_list.append(concat_image_aux)
|
830 |
+
image_aux_features_dino = self.encode_images(
|
831 |
+
new_image_aux_list, encode_type="dino"
|
832 |
+
)
|
833 |
+
|
834 |
+
(
|
835 |
+
image_aux_features_dino,
|
836 |
+
split_sizes,
|
837 |
+
new_image_aux_list,
|
838 |
+
selected_frame_indices_all,
|
839 |
+
) = self.select_frame(
|
840 |
+
image_aux_features_dino,
|
841 |
+
split_sizes_ori,
|
842 |
+
input_ids,
|
843 |
+
new_image_aux_list,
|
844 |
+
image_sizes,
|
845 |
+
threshold=getattr(self.get_model().config, "dino_threshold", 0.83),
|
846 |
+
)
|
847 |
+
|
848 |
+
image_aux_features_siglip = self.encode_images(
|
849 |
+
new_image_aux_list, encode_type="siglip"
|
850 |
+
)
|
851 |
+
image_aux_features_list = [
|
852 |
+
image_aux_features_siglip,
|
853 |
+
image_aux_features_dino,
|
854 |
+
]
|
855 |
+
|
856 |
+
bs = image_aux_features_list[0].shape[0]
|
857 |
+
dtype = new_image_aux_list[0].dtype
|
858 |
+
|
859 |
+
frame_sizes = []
|
860 |
+
for i in range(len(image_sizes)):
|
861 |
+
for j in range(split_sizes[i]):
|
862 |
+
frame_sizes.append(image_sizes[i])
|
863 |
+
image_sizes = frame_sizes
|
864 |
+
else:
|
865 |
+
image_aux_features_list = self.encode_images(image_aux_list)
|
866 |
+
bs = image_aux_list[0].shape[0]
|
867 |
+
dtype = image_aux_list[0].dtype
|
868 |
+
|
869 |
+
image_token_len = self.get_model().config.image_token_len
|
870 |
+
query_num_list = self.get_model().config.query_num_list
|
871 |
+
|
872 |
+
final_height = final_width = int(image_token_len**0.5)
|
873 |
+
|
874 |
+
final_image_features_list = []
|
875 |
+
final_image_features_down_list = []
|
876 |
+
|
877 |
+
# only needed for sva
|
878 |
+
vision_tower_aux_feature_list_final = None
|
879 |
+
vision_tower_aux_attention_masks_list_final = None
|
880 |
+
global_context_feature_final = None
|
881 |
+
|
882 |
+
if self.get_model().config.mm_projector_type == "sva":
|
883 |
+
vision_tower_aux_feature_list = []
|
884 |
+
vision_tower_aux_attention_masks_list = []
|
885 |
+
# get vision tokens from each vision tower
|
886 |
+
for aux_i in range(len(vision_tower_aux_list)):
|
887 |
+
image_aux_features = image_aux_features_list[aux_i]
|
888 |
+
|
889 |
+
image_aux_features = getattr(
|
890 |
+
self.get_model(), "mm_projector_aux_{}".format(aux_i)
|
891 |
+
)(image_aux_features).to(dtype)
|
892 |
+
if aux_i == 0:
|
893 |
+
global_context_feature = image_aux_features.mean(1).view(
|
894 |
+
bs, 1, 1, -1
|
895 |
+
)
|
896 |
+
|
897 |
+
vision_tower_aux_feature_list.append(image_aux_features)
|
898 |
+
input_mix_res = True
|
899 |
+
input_high_res = True
|
900 |
+
# perform vision sampling for each query group
|
901 |
+
for query_group_i, query_num in enumerate(query_num_list):
|
902 |
+
query_features_i = (
|
903 |
+
self.get_model()
|
904 |
+
.vision_query[query_group_i, :]
|
905 |
+
.view(1, 1, 1, -1)
|
906 |
+
.expand(bs, query_num, -1, -1)
|
907 |
+
)
|
908 |
+
global_context_feature_i = global_context_feature.expand(
|
909 |
+
-1, query_num, 1, -1
|
910 |
+
).flatten(0, 1)
|
911 |
+
query_side_len = int(query_num**0.5)
|
912 |
+
if IS_XLA_AVAILABLE:
|
913 |
+
(
|
914 |
+
vision_tower_aux_feature_list_i,
|
915 |
+
vision_tower_aux_attention_masks_list_i,
|
916 |
+
) = self.rearrange_vision_tower_features_train(
|
917 |
+
vision_tower_aux_feature_list,
|
918 |
+
image_aux_attention_masks_list,
|
919 |
+
query_side_len,
|
920 |
+
)
|
921 |
+
else:
|
922 |
+
(
|
923 |
+
vision_tower_aux_feature_list_i,
|
924 |
+
vision_tower_aux_attention_masks_list_i,
|
925 |
+
) = self.rearrange_vision_tower_features_inference(
|
926 |
+
vision_tower_aux_feature_list, query_side_len, image_sizes
|
927 |
+
)
|
928 |
+
|
929 |
+
query_features_i = getattr(
|
930 |
+
self.get_model(), "vision_sampler_{}".format(query_group_i)
|
931 |
+
)(
|
932 |
+
query_features_i.flatten(0, 1),
|
933 |
+
global_context_feature_i,
|
934 |
+
*vision_tower_aux_feature_list_i,
|
935 |
+
*vision_tower_aux_attention_masks_list_i,
|
936 |
+
)
|
937 |
+
query_features_i = query_features_i.view(bs, query_num, -1)
|
938 |
+
|
939 |
+
if split_sizes is not None:
|
940 |
+
try:
|
941 |
+
if "llama" in self.get_model().config.model_type:
|
942 |
+
text_len = torch.where(input_ids[0] == 128002)[-1][0]
|
943 |
+
else:
|
944 |
+
text_len = torch.where(input_ids[0] == 151643)[-1][0]
|
945 |
+
except:
|
946 |
+
text_len = len(input_ids[0])
|
947 |
+
max_visual_len = (
|
948 |
+
self.get_model().config.tokenizer_model_max_length
|
949 |
+
- text_len
|
950 |
+
- getattr(self.get_model().config, "inference_max_length", 16)
|
951 |
+
)
|
952 |
+
max_num_frames = max(
|
953 |
+
1,
|
954 |
+
math.floor(max_visual_len // (final_height * final_width)),
|
955 |
+
)
|
956 |
+
max_num_frames_low = max(
|
957 |
+
1,
|
958 |
+
math.floor(
|
959 |
+
max_visual_len
|
960 |
+
// (self.get_model().config.lowres_token ** 2)
|
961 |
+
),
|
962 |
+
)
|
963 |
+
if split_sizes[0] < max_num_frames:
|
964 |
+
input_mix_res = False
|
965 |
+
elif split_sizes[0] > max_num_frames_low:
|
966 |
+
input_mix_res = False
|
967 |
+
input_high_res = False
|
968 |
+
|
969 |
+
# input_mix_res = False # ablation
|
970 |
+
|
971 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
972 |
+
_query_features_i = (
|
973 |
+
query_features_i.permute(0, 2, 1)
|
974 |
+
.contiguous()
|
975 |
+
.view(bs, -1, query_side_len, query_side_len)
|
976 |
+
)
|
977 |
+
_query_features_i = F.interpolate(
|
978 |
+
_query_features_i.float(),
|
979 |
+
size=(
|
980 |
+
self.get_model().config.lowres_token,
|
981 |
+
self.get_model().config.lowres_token,
|
982 |
+
),
|
983 |
+
mode="bilinear",
|
984 |
+
align_corners=False,
|
985 |
+
).to(dtype=query_features_i.dtype)
|
986 |
+
_query_features_i = (
|
987 |
+
_query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
988 |
+
)
|
989 |
+
final_image_features_down_list.append(_query_features_i)
|
990 |
+
|
991 |
+
# interpolate to the final target size
|
992 |
+
if query_side_len != final_height:
|
993 |
+
query_features_i = (
|
994 |
+
query_features_i.permute(0, 2, 1)
|
995 |
+
.contiguous()
|
996 |
+
.view(bs, -1, query_side_len, query_side_len)
|
997 |
+
)
|
998 |
+
if input_high_res:
|
999 |
+
query_features_i = F.interpolate(
|
1000 |
+
query_features_i.float(),
|
1001 |
+
size=(final_height, final_width),
|
1002 |
+
mode="bilinear",
|
1003 |
+
align_corners=False,
|
1004 |
+
).to(dtype=query_features_i.dtype)
|
1005 |
+
else:
|
1006 |
+
query_features_i = F.interpolate(
|
1007 |
+
query_features_i.float(),
|
1008 |
+
size=(8, 8),
|
1009 |
+
mode="bilinear",
|
1010 |
+
align_corners=False,
|
1011 |
+
).to(dtype=query_features_i.dtype)
|
1012 |
+
query_features_i = (
|
1013 |
+
query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
1014 |
+
)
|
1015 |
+
final_image_features_list.append(query_features_i)
|
1016 |
+
|
1017 |
+
if IS_XLA_AVAILABLE:
|
1018 |
+
(
|
1019 |
+
vision_tower_aux_feature_list_final,
|
1020 |
+
vision_tower_aux_attention_masks_list_final,
|
1021 |
+
) = self.rearrange_vision_tower_features_train(
|
1022 |
+
vision_tower_aux_feature_list,
|
1023 |
+
image_aux_attention_masks_list,
|
1024 |
+
final_height,
|
1025 |
+
)
|
1026 |
+
global_context_feature_final = global_context_feature.expand(
|
1027 |
+
-1, final_height * final_width, 1, -1
|
1028 |
+
).flatten(0, 1)
|
1029 |
+
else:
|
1030 |
+
final_image_features_list = image_aux_features_list
|
1031 |
+
|
1032 |
+
image_features = torch.cat(final_image_features_list, -1)
|
1033 |
+
image_features = self.get_model().mm_projector(image_features).to(dtype)
|
1034 |
+
|
1035 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1036 |
+
image_features_down = torch.cat(final_image_features_down_list, -1)
|
1037 |
+
image_features_down = (
|
1038 |
+
self.get_model().mm_projector(image_features_down).to(dtype)
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
if IS_XLA_AVAILABLE:
|
1042 |
+
image_features = image_features.view(
|
1043 |
+
image_features.shape[0], final_height, final_width, -1
|
1044 |
+
)
|
1045 |
+
image_features = torch.cat(
|
1046 |
+
(
|
1047 |
+
image_features,
|
1048 |
+
self.model.image_newline[None, None, None, :].expand(
|
1049 |
+
image_features.shape[0], final_height, 1, -1
|
1050 |
+
),
|
1051 |
+
),
|
1052 |
+
dim=2,
|
1053 |
+
)
|
1054 |
+
image_features = image_features.flatten(1, 2)
|
1055 |
+
final_size = [(final_height, final_width)] * bs
|
1056 |
+
|
1057 |
+
else:
|
1058 |
+
image_features = image_features.view(bs, final_height, final_width, -1)
|
1059 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1060 |
+
image_features_down = image_features_down.view(
|
1061 |
+
bs,
|
1062 |
+
self.get_model().config.lowres_token,
|
1063 |
+
self.get_model().config.lowres_token,
|
1064 |
+
-1,
|
1065 |
+
)
|
1066 |
+
image_features_unpadded = []
|
1067 |
+
image_features_downsample = []
|
1068 |
+
final_size = []
|
1069 |
+
if self.get_model().config.mm_projector_type == "sva":
|
1070 |
+
(
|
1071 |
+
vision_tower_aux_feature_list_final,
|
1072 |
+
vision_tower_aux_attention_masks_list_final,
|
1073 |
+
) = self.rearrange_vision_tower_features_inference(
|
1074 |
+
vision_tower_aux_feature_list, final_height, image_sizes, unpad=True
|
1075 |
+
)
|
1076 |
+
global_context_feature_final = []
|
1077 |
+
for batch_i in range(bs):
|
1078 |
+
cur_image_feature = image_features[batch_i]
|
1079 |
+
image_size = image_sizes[batch_i]
|
1080 |
+
|
1081 |
+
cur_image_feature = unpad_image(
|
1082 |
+
cur_image_feature.unsqueeze(0), image_size
|
1083 |
+
)
|
1084 |
+
|
1085 |
+
cur_h, cur_w = cur_image_feature.shape[1:3]
|
1086 |
+
try: # fix bug for some invalid image
|
1087 |
+
cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
|
1088 |
+
final_size.append((cur_h, cur_w))
|
1089 |
+
except:
|
1090 |
+
# print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True)
|
1091 |
+
cur_image_feature = image_features[batch_i].unsqueeze(0)
|
1092 |
+
image_size = image_sizes[batch_i]
|
1093 |
+
cur_h, cur_w = cur_image_feature.shape[1:3]
|
1094 |
+
cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
|
1095 |
+
final_size.append((cur_h, cur_w))
|
1096 |
+
|
1097 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1098 |
+
cur_image_feature_down = unpad_image(
|
1099 |
+
image_features_down[batch_i].unsqueeze(0),
|
1100 |
+
(
|
1101 |
+
int(
|
1102 |
+
image_size[0]
|
1103 |
+
/ (
|
1104 |
+
image_token_len**0.5
|
1105 |
+
/ self.get_model().config.lowres_token
|
1106 |
+
)
|
1107 |
+
),
|
1108 |
+
int(
|
1109 |
+
image_size[1]
|
1110 |
+
/ (
|
1111 |
+
image_token_len**0.5
|
1112 |
+
/ self.get_model().config.lowres_token
|
1113 |
+
)
|
1114 |
+
),
|
1115 |
+
),
|
1116 |
+
)
|
1117 |
+
_cur_h, _cur_w = cur_image_feature_down.shape[1:3]
|
1118 |
+
|
1119 |
+
try: # fix bug for some invalid image
|
1120 |
+
cur_image_feature_down = cur_image_feature_down.view(
|
1121 |
+
1, _cur_h, _cur_w, -1
|
1122 |
+
)
|
1123 |
+
except:
|
1124 |
+
print("invalid after unpad", flush=True)
|
1125 |
+
cur_image_feature_down = image_features_down[batch_i].unsqueeze(
|
1126 |
+
0
|
1127 |
+
)
|
1128 |
+
_cur_h, _cur_w = cur_image_feature_down.shape[1:3]
|
1129 |
+
cur_image_feature_down = cur_image_feature_down.view(
|
1130 |
+
1, _cur_h, _cur_w, -1
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
cur_image_feature_down = torch.cat(
|
1134 |
+
(
|
1135 |
+
cur_image_feature_down,
|
1136 |
+
self.model.image_newline.view(1, 1, 1, -1)
|
1137 |
+
.expand(1, _cur_h, 1, -1)
|
1138 |
+
.to(cur_image_feature_down.device),
|
1139 |
+
),
|
1140 |
+
dim=2,
|
1141 |
+
).flatten(1, 2)
|
1142 |
+
|
1143 |
+
if split_sizes is None and getattr(self.config, "frame_pos", False):
|
1144 |
+
frame_pos = (
|
1145 |
+
self.get_model()
|
1146 |
+
.get_frame_pos(torch.arange(1))
|
1147 |
+
.to(cur_image_feature_down.device)
|
1148 |
+
.to(cur_image_feature_down.dtype)
|
1149 |
+
)
|
1150 |
+
cur_image_feature_down += frame_pos
|
1151 |
+
|
1152 |
+
image_features_downsample.append(cur_image_feature_down.squeeze(0))
|
1153 |
+
|
1154 |
+
cur_image_feature = torch.cat(
|
1155 |
+
(
|
1156 |
+
cur_image_feature,
|
1157 |
+
self.model.image_newline.view(1, 1, 1, -1)
|
1158 |
+
.expand(1, cur_h, 1, -1)
|
1159 |
+
.to(cur_image_feature.device),
|
1160 |
+
),
|
1161 |
+
dim=2,
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
if split_sizes is None and getattr(self.config, "frame_pos", False):
|
1165 |
+
frame_pos = (
|
1166 |
+
self.get_model()
|
1167 |
+
.get_frame_pos(torch.arange(1))
|
1168 |
+
.to(cur_image_feature.device)
|
1169 |
+
.to(cur_image_feature.dtype)
|
1170 |
+
)
|
1171 |
+
cur_image_feature += frame_pos
|
1172 |
+
|
1173 |
+
cur_image_feature = cur_image_feature.flatten(1, 2)
|
1174 |
+
image_features_unpadded.append(cur_image_feature.squeeze(0))
|
1175 |
+
|
1176 |
+
if self.get_model().config.mm_projector_type == "sva":
|
1177 |
+
cur_global_context_feature = global_context_feature[batch_i].expand(
|
1178 |
+
cur_h * cur_w, 1, -1
|
1179 |
+
)
|
1180 |
+
global_context_feature_final.append(cur_global_context_feature)
|
1181 |
+
if self.get_model().config.mm_projector_type == "sva":
|
1182 |
+
global_context_feature_final = torch.cat(
|
1183 |
+
global_context_feature_final, 0
|
1184 |
+
)
|
1185 |
+
|
1186 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1187 |
+
image_features = image_features_downsample
|
1188 |
+
else:
|
1189 |
+
image_features = image_features_unpadded
|
1190 |
+
|
1191 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
1192 |
+
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
|
1193 |
+
self.config, "mm_use_im_start_end", False
|
1194 |
+
):
|
1195 |
+
raise NotImplementedError
|
1196 |
+
|
1197 |
+
split_image_features_unpadded = None
|
1198 |
+
frame_split_sizes = None
|
1199 |
+
|
1200 |
+
if split_sizes is not None:
|
1201 |
+
split_image_features = []
|
1202 |
+
split_image_features_unpadded = (
|
1203 |
+
[]
|
1204 |
+
if (getattr(self.config, "highres", False)) and input_mix_res
|
1205 |
+
else None
|
1206 |
+
)
|
1207 |
+
start_idx = 0
|
1208 |
+
for split_batch_idx, split_size in enumerate(split_sizes):
|
1209 |
+
if isinstance(image_features[start_idx : start_idx + split_size], list):
|
1210 |
+
if getattr(self.config, "frame_pos", False):
|
1211 |
+
frame_feature = torch.cat(
|
1212 |
+
image_features[start_idx : start_idx + split_size], dim=0
|
1213 |
+
).reshape(split_size, -1, image_features[0].shape[-1])
|
1214 |
+
frame_pos = (
|
1215 |
+
self.get_model()
|
1216 |
+
.get_frame_pos(selected_frame_indices_all[split_batch_idx])
|
1217 |
+
.to(frame_feature.device)
|
1218 |
+
.to(frame_feature.dtype)
|
1219 |
+
)
|
1220 |
+
frame_feature += frame_pos
|
1221 |
+
split_image_features.append(
|
1222 |
+
frame_feature.reshape(-1, image_features[0].shape[-1])
|
1223 |
+
)
|
1224 |
+
else:
|
1225 |
+
split_image_features.append(
|
1226 |
+
torch.cat(
|
1227 |
+
image_features[start_idx : start_idx + split_size],
|
1228 |
+
dim=0,
|
1229 |
+
)
|
1230 |
+
)
|
1231 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1232 |
+
if getattr(self.config, "frame_pos", False):
|
1233 |
+
frame_feature = torch.cat(
|
1234 |
+
image_features_unpadded[
|
1235 |
+
start_idx : start_idx + split_size
|
1236 |
+
],
|
1237 |
+
dim=0,
|
1238 |
+
).reshape(split_size, -1, image_features[0].shape[-1])
|
1239 |
+
frame_pos = (
|
1240 |
+
self.get_model()
|
1241 |
+
.get_frame_pos(
|
1242 |
+
selected_frame_indices_all[split_batch_idx]
|
1243 |
+
)
|
1244 |
+
.to(frame_feature.device)
|
1245 |
+
.to(frame_feature.dtype)
|
1246 |
+
)
|
1247 |
+
frame_feature += frame_pos
|
1248 |
+
split_image_features_unpadded.append(
|
1249 |
+
frame_feature.reshape(-1, image_features[0].shape[-1])
|
1250 |
+
)
|
1251 |
+
else:
|
1252 |
+
split_image_features_unpadded.append(
|
1253 |
+
torch.cat(
|
1254 |
+
image_features_unpadded[
|
1255 |
+
start_idx : start_idx + split_size
|
1256 |
+
],
|
1257 |
+
dim=0,
|
1258 |
+
)
|
1259 |
+
)
|
1260 |
+
else:
|
1261 |
+
if getattr(self.config, "frame_pos", False):
|
1262 |
+
frame_feature = image_features[
|
1263 |
+
start_idx : start_idx + split_size
|
1264 |
+
].reshape(split_size, -1, image_features[0].shape[-1])
|
1265 |
+
frame_pos = (
|
1266 |
+
self.get_model()
|
1267 |
+
.get_frame_pos(selected_frame_indices_all[split_batch_idx])
|
1268 |
+
.to(frame_feature.device)
|
1269 |
+
.to(frame_feature.dtype)
|
1270 |
+
)
|
1271 |
+
frame_feature += frame_pos
|
1272 |
+
split_image_features.append(
|
1273 |
+
frame_feature.reshape(-1, image_features[0].shape[-1])
|
1274 |
+
)
|
1275 |
+
else:
|
1276 |
+
split_image_features.append(
|
1277 |
+
image_features[start_idx : start_idx + split_size]
|
1278 |
+
)
|
1279 |
+
if (getattr(self.config, "highres", False)) and input_mix_res:
|
1280 |
+
if getattr(self.config, "frame_pos", False):
|
1281 |
+
frame_feature = image_features_unpadded[
|
1282 |
+
start_idx : start_idx + split_size
|
1283 |
+
]
|
1284 |
+
frame_pos = (
|
1285 |
+
self.get_model()
|
1286 |
+
.get_frame_pos(
|
1287 |
+
selected_frame_indices_all[split_batch_idx]
|
1288 |
+
)
|
1289 |
+
.to(frame_feature.device)
|
1290 |
+
.to(frame_feature.dtype)
|
1291 |
+
)
|
1292 |
+
frame_feature += frame_pos
|
1293 |
+
split_image_features_unpadded.append(
|
1294 |
+
frame_feature.reshape(-1, image_features[0].shape[-1])
|
1295 |
+
)
|
1296 |
+
else:
|
1297 |
+
split_image_features_unpadded.append(
|
1298 |
+
image_features_unpadded[
|
1299 |
+
start_idx : start_idx + split_size
|
1300 |
+
]
|
1301 |
+
)
|
1302 |
+
start_idx += split_size
|
1303 |
+
image_features = split_image_features
|
1304 |
+
frame_split_sizes = split_sizes
|
1305 |
+
|
1306 |
+
_labels = labels
|
1307 |
+
_position_ids = position_ids
|
1308 |
+
_attention_mask = attention_mask
|
1309 |
+
if attention_mask is None:
|
1310 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
1311 |
+
else:
|
1312 |
+
attention_mask = attention_mask.bool()
|
1313 |
+
if position_ids is None:
|
1314 |
+
position_ids = torch.arange(
|
1315 |
+
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
|
1316 |
+
)
|
1317 |
+
if labels is None:
|
1318 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
1319 |
+
|
1320 |
+
# remove the padding using attention_mask -- FIXME
|
1321 |
+
_input_ids = input_ids
|
1322 |
+
|
1323 |
+
attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
|
1324 |
+
|
1325 |
+
input_ids = [
|
1326 |
+
cur_input_ids[cur_attention_mask]
|
1327 |
+
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
|
1328 |
+
]
|
1329 |
+
labels = [
|
1330 |
+
cur_labels[cur_attention_mask]
|
1331 |
+
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
|
1332 |
+
]
|
1333 |
+
|
1334 |
+
new_input_embeds = []
|
1335 |
+
new_labels = []
|
1336 |
+
image_token_indices_batch = []
|
1337 |
+
cur_image_idx = 0
|
1338 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
1339 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
1340 |
+
if num_images == 0:
|
1341 |
+
cur_image_features = image_features[cur_image_idx]
|
1342 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
1343 |
+
cur_input_embeds = torch.cat(
|
1344 |
+
[cur_input_embeds_1, cur_image_features[0:0]], dim=0
|
1345 |
+
)
|
1346 |
+
new_input_embeds.append(cur_input_embeds)
|
1347 |
+
new_labels.append(labels[batch_idx])
|
1348 |
+
cur_image_idx += 1
|
1349 |
+
continue
|
1350 |
+
|
1351 |
+
image_token_indices = (
|
1352 |
+
[-1]
|
1353 |
+
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
|
1354 |
+
+ [cur_input_ids.shape[0]]
|
1355 |
+
)
|
1356 |
+
image_token_indices_batch.append(
|
1357 |
+
torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0]
|
1358 |
+
)
|
1359 |
+
cur_input_ids_noim = []
|
1360 |
+
cur_labels = labels[batch_idx]
|
1361 |
+
cur_labels_noim = []
|
1362 |
+
for i in range(len(image_token_indices) - 1):
|
1363 |
+
cur_input_ids_noim.append(
|
1364 |
+
cur_input_ids[
|
1365 |
+
image_token_indices[i] + 1 : image_token_indices[i + 1]
|
1366 |
+
]
|
1367 |
+
)
|
1368 |
+
cur_labels_noim.append(
|
1369 |
+
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
|
1370 |
+
)
|
1371 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
1372 |
+
cur_input_embeds = self.get_model().embed_tokens(
|
1373 |
+
torch.cat(cur_input_ids_noim)
|
1374 |
+
)
|
1375 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
1376 |
+
cur_new_input_embeds = []
|
1377 |
+
cur_new_labels = []
|
1378 |
+
|
1379 |
+
text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
|
1380 |
+
visual_len = len(image_features[cur_image_idx])
|
1381 |
+
max_visual_len = (
|
1382 |
+
self.get_model().config.tokenizer_model_max_length
|
1383 |
+
- getattr(self.get_model().config, "inference_max_length", 16)
|
1384 |
+
- text_len
|
1385 |
+
)
|
1386 |
+
mix_token = False
|
1387 |
+
|
1388 |
+
# ablation mix
|
1389 |
+
if (
|
1390 |
+
input_mix_res
|
1391 |
+
and (
|
1392 |
+
self.get_model().config.image_token_len
|
1393 |
+
> getattr(self.get_model().config, "lowres_token", 8) ** 2
|
1394 |
+
)
|
1395 |
+
and frame_split_sizes is not None
|
1396 |
+
and getattr(self.config, "highres", False)
|
1397 |
+
):
|
1398 |
+
if max_visual_len > visual_len:
|
1399 |
+
visual_emb = image_features[cur_image_idx]
|
1400 |
+
text_emb = cur_input_embeds_no_im[-1]
|
1401 |
+
highres_num = math.floor(
|
1402 |
+
(max_visual_len - visual_len)
|
1403 |
+
/ (
|
1404 |
+
split_image_features_unpadded[cur_image_idx].shape[0]
|
1405 |
+
// frame_split_sizes[cur_image_idx]
|
1406 |
+
- visual_emb.shape[0] // frame_split_sizes[cur_image_idx]
|
1407 |
+
)
|
1408 |
+
)
|
1409 |
+
if highres_num >= 1:
|
1410 |
+
mix_token = True
|
1411 |
+
sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean(
|
1412 |
+
dim=-1
|
1413 |
+
)
|
1414 |
+
sim_frame = sim.reshape(
|
1415 |
+
frame_split_sizes[cur_image_idx], -1
|
1416 |
+
).mean(dim=-1)
|
1417 |
+
highres_num = min(highres_num, sim_frame.shape[0])
|
1418 |
+
top_values, top_indices = torch.topk(sim_frame, highres_num)
|
1419 |
+
if len(top_indices) > 0:
|
1420 |
+
sorted_indices = torch.sort(top_indices)[1]
|
1421 |
+
top_indices = top_indices[sorted_indices]
|
1422 |
+
visual_emb_frame = image_features[cur_image_idx].reshape(
|
1423 |
+
frame_split_sizes[cur_image_idx],
|
1424 |
+
-1,
|
1425 |
+
image_features[cur_image_idx].shape[-1],
|
1426 |
+
)
|
1427 |
+
visual_emb_frame_highres = split_image_features_unpadded[
|
1428 |
+
cur_image_idx
|
1429 |
+
].reshape(
|
1430 |
+
frame_split_sizes[cur_image_idx],
|
1431 |
+
-1,
|
1432 |
+
split_image_features_unpadded[cur_image_idx].shape[-1],
|
1433 |
+
)
|
1434 |
+
current_point = 0
|
1435 |
+
mix_visual_emb_frame = []
|
1436 |
+
for frame_i in range(len(visual_emb_frame)):
|
1437 |
+
if current_point > len(top_indices) - 1:
|
1438 |
+
mix_visual_emb_frame.append(
|
1439 |
+
visual_emb_frame[frame_i]
|
1440 |
+
)
|
1441 |
+
continue
|
1442 |
+
if frame_i == top_indices[current_point]:
|
1443 |
+
mix_visual_emb_frame.append(
|
1444 |
+
visual_emb_frame_highres[frame_i]
|
1445 |
+
)
|
1446 |
+
current_point += 1
|
1447 |
+
else:
|
1448 |
+
mix_visual_emb_frame.append(
|
1449 |
+
visual_emb_frame[frame_i]
|
1450 |
+
)
|
1451 |
+
image_features[cur_image_idx] = torch.cat(
|
1452 |
+
mix_visual_emb_frame, dim=0
|
1453 |
+
)
|
1454 |
+
# ablation drop
|
1455 |
+
|
1456 |
+
if (
|
1457 |
+
max_visual_len < visual_len
|
1458 |
+
and frame_split_sizes is not None
|
1459 |
+
and not mix_token
|
1460 |
+
):
|
1461 |
+
visual_emb_frame = image_features[cur_image_idx].reshape(
|
1462 |
+
frame_split_sizes[cur_image_idx],
|
1463 |
+
-1,
|
1464 |
+
image_features[cur_image_idx].shape[-1],
|
1465 |
+
)
|
1466 |
+
|
1467 |
+
sim = F.cosine_similarity(
|
1468 |
+
visual_emb_frame[:-1],
|
1469 |
+
visual_emb_frame[1:],
|
1470 |
+
dim=-1,
|
1471 |
+
)
|
1472 |
+
|
1473 |
+
new_visual_emb_frames = []
|
1474 |
+
for start_idx in range(0, len(visual_emb_frame), 8):
|
1475 |
+
end_idx = min(start_idx + 8, len(visual_emb_frame))
|
1476 |
+
chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C
|
1477 |
+
if len(chunk_feature) == 1:
|
1478 |
+
new_visual_emb_frames.append(chunk_feature[0])
|
1479 |
+
continue
|
1480 |
+
sim = F.cosine_similarity(
|
1481 |
+
chunk_feature[0]
|
1482 |
+
.unsqueeze(0)
|
1483 |
+
.repeat_interleave(len(chunk_feature[1:]), dim=0),
|
1484 |
+
chunk_feature[1:],
|
1485 |
+
dim=-1,
|
1486 |
+
)
|
1487 |
+
new_visual_emb_frame = torch.cat(
|
1488 |
+
[
|
1489 |
+
chunk_feature[0],
|
1490 |
+
chunk_feature[1:].flatten(0, 1)[
|
1491 |
+
sim.flatten(0, 1)
|
1492 |
+
< getattr(
|
1493 |
+
self.get_model().config, "drop_threshold", 0.7
|
1494 |
+
)
|
1495 |
+
],
|
1496 |
+
],
|
1497 |
+
dim=0,
|
1498 |
+
)
|
1499 |
+
new_visual_emb_frames.append(new_visual_emb_frame)
|
1500 |
+
|
1501 |
+
reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames])
|
1502 |
+
|
1503 |
+
if reduced_visual_len > max_visual_len:
|
1504 |
+
force_remove = math.ceil(
|
1505 |
+
(reduced_visual_len - max_visual_len)
|
1506 |
+
/ len(new_visual_emb_frames)
|
1507 |
+
)
|
1508 |
+
for chunk_i in range(len(new_visual_emb_frames)):
|
1509 |
+
new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][
|
1510 |
+
:-force_remove
|
1511 |
+
]
|
1512 |
+
new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
|
1513 |
+
else:
|
1514 |
+
new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
|
1515 |
+
|
1516 |
+
image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len]
|
1517 |
+
|
1518 |
+
for i in range(num_images + 1):
|
1519 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
1520 |
+
cur_new_labels.append(cur_labels_noim[i])
|
1521 |
+
if i < num_images:
|
1522 |
+
cur_image_features = image_features[cur_image_idx]
|
1523 |
+
cur_image_idx += 1
|
1524 |
+
cur_new_input_embeds.append(cur_image_features)
|
1525 |
+
cur_new_labels.append(
|
1526 |
+
torch.full(
|
1527 |
+
(cur_image_features.shape[0],),
|
1528 |
+
IGNORE_INDEX,
|
1529 |
+
device=cur_labels.device,
|
1530 |
+
dtype=cur_labels.dtype,
|
1531 |
+
)
|
1532 |
+
)
|
1533 |
+
|
1534 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
1535 |
+
|
1536 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
1537 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
1538 |
+
|
1539 |
+
new_input_embeds.append(cur_new_input_embeds)
|
1540 |
+
new_labels.append(cur_new_labels)
|
1541 |
+
|
1542 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
1543 |
+
tokenizer_model_max_length = getattr(
|
1544 |
+
self.config, "tokenizer_model_max_length", None
|
1545 |
+
)
|
1546 |
+
if tokenizer_model_max_length is not None:
|
1547 |
+
new_input_embeds = [
|
1548 |
+
x[:tokenizer_model_max_length] for x in new_input_embeds
|
1549 |
+
]
|
1550 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
1551 |
+
|
1552 |
+
# Combine them
|
1553 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
1554 |
+
batch_size = len(new_input_embeds)
|
1555 |
+
|
1556 |
+
new_input_embeds_padded = []
|
1557 |
+
new_labels_padded = torch.full(
|
1558 |
+
(batch_size, max_len),
|
1559 |
+
IGNORE_INDEX,
|
1560 |
+
dtype=new_labels[0].dtype,
|
1561 |
+
device=new_labels[0].device,
|
1562 |
+
)
|
1563 |
+
attention_mask = torch.zeros(
|
1564 |
+
(batch_size, max_len),
|
1565 |
+
dtype=attention_mask.dtype,
|
1566 |
+
device=attention_mask.device,
|
1567 |
+
)
|
1568 |
+
position_ids = torch.zeros(
|
1569 |
+
(batch_size, max_len),
|
1570 |
+
dtype=position_ids.dtype,
|
1571 |
+
device=position_ids.device,
|
1572 |
+
)
|
1573 |
+
|
1574 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(
|
1575 |
+
zip(new_input_embeds, new_labels)
|
1576 |
+
):
|
1577 |
+
cur_len = cur_new_embed.shape[0]
|
1578 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
1579 |
+
new_input_embeds_padded.append(
|
1580 |
+
torch.cat(
|
1581 |
+
(
|
1582 |
+
torch.zeros(
|
1583 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
1584 |
+
dtype=cur_new_embed.dtype,
|
1585 |
+
device=cur_new_embed.device,
|
1586 |
+
),
|
1587 |
+
cur_new_embed,
|
1588 |
+
),
|
1589 |
+
dim=0,
|
1590 |
+
)
|
1591 |
+
)
|
1592 |
+
if cur_len > 0:
|
1593 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
1594 |
+
attention_mask[i, -cur_len:] = True
|
1595 |
+
position_ids[i, -cur_len:] = torch.arange(
|
1596 |
+
0,
|
1597 |
+
cur_len,
|
1598 |
+
dtype=position_ids.dtype,
|
1599 |
+
device=position_ids.device,
|
1600 |
+
)
|
1601 |
+
else:
|
1602 |
+
new_input_embeds_padded.append(
|
1603 |
+
torch.cat(
|
1604 |
+
(
|
1605 |
+
cur_new_embed,
|
1606 |
+
torch.zeros(
|
1607 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
1608 |
+
dtype=cur_new_embed.dtype,
|
1609 |
+
device=cur_new_embed.device,
|
1610 |
+
),
|
1611 |
+
),
|
1612 |
+
dim=0,
|
1613 |
+
)
|
1614 |
+
)
|
1615 |
+
if cur_len > 0:
|
1616 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
1617 |
+
attention_mask[i, :cur_len] = True
|
1618 |
+
position_ids[i, :cur_len] = torch.arange(
|
1619 |
+
0,
|
1620 |
+
cur_len,
|
1621 |
+
dtype=position_ids.dtype,
|
1622 |
+
device=position_ids.device,
|
1623 |
+
)
|
1624 |
+
|
1625 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
1626 |
+
|
1627 |
+
if _labels is None:
|
1628 |
+
new_labels = None
|
1629 |
+
else:
|
1630 |
+
new_labels = new_labels_padded
|
1631 |
+
|
1632 |
+
if _attention_mask is None:
|
1633 |
+
attention_mask = None
|
1634 |
+
else:
|
1635 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
1636 |
+
|
1637 |
+
if _position_ids is None:
|
1638 |
+
position_ids = None
|
1639 |
+
|
1640 |
+
return (
|
1641 |
+
None,
|
1642 |
+
position_ids,
|
1643 |
+
attention_mask,
|
1644 |
+
past_key_values,
|
1645 |
+
new_input_embeds,
|
1646 |
+
new_labels,
|
1647 |
+
vision_tower_aux_feature_list_final,
|
1648 |
+
vision_tower_aux_attention_masks_list_final,
|
1649 |
+
final_size,
|
1650 |
+
global_context_feature_final,
|
1651 |
+
)
|
1652 |
+
|
1653 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
1654 |
+
if model_args.mm_use_im_patch_token:
|
1655 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
1656 |
+
self.resize_token_embeddings(len(tokenizer))
|
1657 |
+
|
1658 |
+
if model_args.mm_use_im_start_end:
|
1659 |
+
num_new_tokens = tokenizer.add_tokens(
|
1660 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
1661 |
+
)
|
1662 |
+
self.resize_token_embeddings(len(tokenizer))
|
1663 |
+
|
1664 |
+
if num_new_tokens > 0:
|
1665 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
1666 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
1667 |
+
|
1668 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
1669 |
+
dim=0, keepdim=True
|
1670 |
+
)
|
1671 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
1672 |
+
dim=0, keepdim=True
|
1673 |
+
)
|
1674 |
+
|
1675 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
1676 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
1677 |
+
|
1678 |
+
if model_args.tune_mm_mlp_adapter:
|
1679 |
+
for p in self.get_input_embeddings().parameters():
|
1680 |
+
p.requires_grad = True
|
1681 |
+
for p in self.get_output_embeddings().parameters():
|
1682 |
+
p.requires_grad = False
|
1683 |
+
|
1684 |
+
if model_args.pretrain_mm_mlp_adapter:
|
1685 |
+
mm_projector_weights = torch.load(
|
1686 |
+
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
|
1687 |
+
)
|
1688 |
+
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
|
1689 |
+
assert num_new_tokens == 2
|
1690 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
1691 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[
|
1692 |
+
-num_new_tokens:
|
1693 |
+
]
|
1694 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
1695 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
1696 |
+
else:
|
1697 |
+
raise ValueError(
|
1698 |
+
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
|
1699 |
+
)
|
1700 |
+
elif model_args.mm_use_im_patch_token:
|
1701 |
+
if model_args.tune_mm_mlp_adapter:
|
1702 |
+
for p in self.get_input_embeddings().parameters():
|
1703 |
+
p.requires_grad = False
|
1704 |
+
for p in self.get_output_embeddings().parameters():
|
1705 |
+
p.requires_grad = False
|
longvu/consolidate.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
"""
|
3 |
+
Usage:
|
4 |
+
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
11 |
+
from longvu import * # noqa
|
12 |
+
from .utils import auto_upgrade
|
13 |
+
|
14 |
+
|
15 |
+
def consolidate_ckpt(src_path, dst_path):
|
16 |
+
print("Loading model")
|
17 |
+
auto_upgrade(src_path)
|
18 |
+
src_model = AutoModelForCausalLM.from_pretrained(
|
19 |
+
src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
20 |
+
)
|
21 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
22 |
+
src_model.save_pretrained(dst_path)
|
23 |
+
src_tokenizer.save_pretrained(dst_path)
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
parser.add_argument("--src", type=str, required=True)
|
29 |
+
parser.add_argument("--dst", type=str, required=True)
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
|
33 |
+
consolidate_ckpt(args.src, args.dst)
|
longvu/constants.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
longvu/conversation.py
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import dataclasses
|
3 |
+
from enum import auto, Enum
|
4 |
+
from io import BytesIO
|
5 |
+
from typing import Any, Dict, List, Tuple, Union
|
6 |
+
|
7 |
+
from longvu.file_io import PathManager
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from transformers import AutoTokenizer
|
11 |
+
|
12 |
+
|
13 |
+
class SeparatorStyle(Enum):
|
14 |
+
"""Different separator style."""
|
15 |
+
|
16 |
+
SINGLE = auto()
|
17 |
+
TWO = auto()
|
18 |
+
MPT = auto()
|
19 |
+
PLAIN = auto()
|
20 |
+
LLAMA_2 = auto()
|
21 |
+
LLAMA_3 = auto()
|
22 |
+
LLAMA_3_1 = auto()
|
23 |
+
LLAMA_3_2 = auto()
|
24 |
+
QWEN = auto()
|
25 |
+
CHATML = auto()
|
26 |
+
|
27 |
+
|
28 |
+
@dataclasses.dataclass
|
29 |
+
class Conversation:
|
30 |
+
"""A class that keeps all conversation history."""
|
31 |
+
|
32 |
+
system: str
|
33 |
+
roles: List[str]
|
34 |
+
messages: List[List[str]]
|
35 |
+
offset: int
|
36 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
37 |
+
sep: str = "###"
|
38 |
+
# pyre-fixme[8]: Attribute has type `str`; used as `None`.
|
39 |
+
sep2: str = None
|
40 |
+
version: str = "Unknown"
|
41 |
+
|
42 |
+
tokenizer: Any = None
|
43 |
+
# Stop criteria (the default one is EOS token)
|
44 |
+
# pyre-fixme[8]: Attribute has type `Union[List[str], str]`; used as `None`.
|
45 |
+
stop_str: Union[str, List[str]] = None
|
46 |
+
# Stops generation if meeting any token in this list
|
47 |
+
# pyre-fixme[8]: Attribute has type `List[int]`; used as `None`.
|
48 |
+
stop_token_ids: List[int] = None
|
49 |
+
|
50 |
+
skip_next: bool = False
|
51 |
+
|
52 |
+
def get_prompt(self):
|
53 |
+
messages = self.messages
|
54 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
55 |
+
messages = self.messages.copy()
|
56 |
+
init_role, init_msg = messages[0].copy()
|
57 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
58 |
+
if "mmtag" in self.version:
|
59 |
+
messages[0] = (init_role, init_msg)
|
60 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
61 |
+
messages.insert(1, (self.roles[1], "Received."))
|
62 |
+
else:
|
63 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
64 |
+
|
65 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
66 |
+
ret = self.system + self.sep
|
67 |
+
for role, message in messages:
|
68 |
+
if message:
|
69 |
+
if type(message) is tuple:
|
70 |
+
message, _, _ = message
|
71 |
+
ret += role + ": " + message + self.sep
|
72 |
+
else:
|
73 |
+
ret += role + ":"
|
74 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
75 |
+
seps = [self.sep, self.sep2]
|
76 |
+
ret = self.system + seps[0]
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if message:
|
79 |
+
if type(message) is tuple:
|
80 |
+
message, _, _ = message
|
81 |
+
ret += role + ": " + message + seps[i % 2]
|
82 |
+
else:
|
83 |
+
ret += role + ":"
|
84 |
+
|
85 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
86 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
87 |
+
for role, message in messages:
|
88 |
+
if message:
|
89 |
+
if type(message) is tuple:
|
90 |
+
message, images, _ = message
|
91 |
+
message = "<image>" * len(images) + message
|
92 |
+
ret += role + "\n" + message + self.sep + "\n"
|
93 |
+
else:
|
94 |
+
ret += role + "\n"
|
95 |
+
return ret
|
96 |
+
|
97 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
98 |
+
ret = self.system + self.sep
|
99 |
+
for role, message in messages:
|
100 |
+
if message:
|
101 |
+
if type(message) is tuple:
|
102 |
+
message, _, _ = message
|
103 |
+
ret += role + message + self.sep
|
104 |
+
else:
|
105 |
+
ret += role
|
106 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
107 |
+
wrap_sys = lambda msg: (
|
108 |
+
f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
109 |
+
)
|
110 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
111 |
+
ret = ""
|
112 |
+
|
113 |
+
for i, (role, message) in enumerate(messages):
|
114 |
+
if i == 0:
|
115 |
+
assert message, "first message should not be none"
|
116 |
+
assert role == self.roles[0], "first message should come from user"
|
117 |
+
if message:
|
118 |
+
if type(message) is tuple:
|
119 |
+
message, _, _ = message
|
120 |
+
if i == 0:
|
121 |
+
message = wrap_sys(self.system) + message
|
122 |
+
if i % 2 == 0:
|
123 |
+
message = wrap_inst(message)
|
124 |
+
ret += self.sep + message
|
125 |
+
else:
|
126 |
+
ret += " " + message + " " + self.sep2
|
127 |
+
else:
|
128 |
+
ret += ""
|
129 |
+
ret = ret.lstrip(self.sep)
|
130 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
131 |
+
if self.tokenizer is None:
|
132 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
133 |
+
PathManager.get_local_path(
|
134 |
+
"manifold://xr_core_ai_asl_llm/tree/users/shenx/models/Cambrian-Llama3_1-8b-t576/"
|
135 |
+
)
|
136 |
+
)
|
137 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
138 |
+
for role, message in messages:
|
139 |
+
if message:
|
140 |
+
if type(message) is tuple:
|
141 |
+
message, images = message
|
142 |
+
message = "<image>" * len(images) + message
|
143 |
+
chat_template_messages.append({"role": role, "content": message})
|
144 |
+
|
145 |
+
# print("chat", chat_template_messages, flush=True)
|
146 |
+
return self.tokenizer.apply_chat_template(
|
147 |
+
chat_template_messages, tokenize=False, add_generation_prompt=True
|
148 |
+
)
|
149 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3_1:
|
150 |
+
if self.tokenizer is None:
|
151 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
152 |
+
PathManager.get_local_path(
|
153 |
+
"manifold://xr_core_ai_asl_llm/tree/users/shenx/models/Cambrian-Llama3_1-8b-t576/"
|
154 |
+
)
|
155 |
+
)
|
156 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
157 |
+
for role, message in messages:
|
158 |
+
if message:
|
159 |
+
if type(message) is tuple:
|
160 |
+
message, images = message
|
161 |
+
message = "<image>" * len(images) + message
|
162 |
+
chat_template_messages.append({"role": role, "content": message})
|
163 |
+
|
164 |
+
return self.tokenizer.apply_chat_template(
|
165 |
+
chat_template_messages, tokenize=False, add_generation_prompt=False
|
166 |
+
)
|
167 |
+
elif (
|
168 |
+
# self.sep_style == SeparatorStyle.LLAMA_3 or
|
169 |
+
self.sep_style
|
170 |
+
== SeparatorStyle.LLAMA_3_2
|
171 |
+
):
|
172 |
+
wrap_sys = lambda msg: (
|
173 |
+
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>"
|
174 |
+
if len(msg) > 0
|
175 |
+
else msg
|
176 |
+
)
|
177 |
+
wrap_inst_user = (
|
178 |
+
lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
|
179 |
+
)
|
180 |
+
wrap_inst_assistant = (
|
181 |
+
lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
|
182 |
+
)
|
183 |
+
ret = ""
|
184 |
+
|
185 |
+
for i, (role, message) in enumerate(messages):
|
186 |
+
if i == 0:
|
187 |
+
assert message, "first message should not be none"
|
188 |
+
assert role == self.roles[0], "first message should come from user"
|
189 |
+
if message:
|
190 |
+
if type(message) is tuple:
|
191 |
+
message, _, _ = message
|
192 |
+
if i == 0:
|
193 |
+
ret += wrap_sys(self.system)
|
194 |
+
|
195 |
+
if i % 2 == 0:
|
196 |
+
message = wrap_inst_user(message)
|
197 |
+
ret += message
|
198 |
+
else:
|
199 |
+
message = wrap_inst_assistant(message)
|
200 |
+
ret += message
|
201 |
+
else:
|
202 |
+
ret += ""
|
203 |
+
ret += "<|start_header_id|>assistant<|end_header_id|>"
|
204 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
205 |
+
seps = [self.sep, self.sep2]
|
206 |
+
ret = self.system
|
207 |
+
for i, (role, message) in enumerate(messages):
|
208 |
+
if message:
|
209 |
+
if type(message) is tuple:
|
210 |
+
message, _, _ = message
|
211 |
+
ret += message + seps[i % 2]
|
212 |
+
else:
|
213 |
+
ret += ""
|
214 |
+
else:
|
215 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
216 |
+
|
217 |
+
return ret
|
218 |
+
|
219 |
+
def append_message(self, role, message):
|
220 |
+
self.messages.append([role, message])
|
221 |
+
|
222 |
+
def process_image(
|
223 |
+
self,
|
224 |
+
image,
|
225 |
+
image_process_mode,
|
226 |
+
return_pil=False,
|
227 |
+
image_format="PNG",
|
228 |
+
max_len=1344,
|
229 |
+
min_len=672,
|
230 |
+
):
|
231 |
+
if image_process_mode == "Pad":
|
232 |
+
|
233 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
234 |
+
width, height = pil_img.size
|
235 |
+
if width == height:
|
236 |
+
return pil_img
|
237 |
+
elif width > height:
|
238 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
239 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
240 |
+
return result
|
241 |
+
else:
|
242 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
243 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
244 |
+
return result
|
245 |
+
|
246 |
+
image = expand2square(image)
|
247 |
+
elif image_process_mode in ["Default", "Crop"]:
|
248 |
+
pass
|
249 |
+
elif image_process_mode == "Resize":
|
250 |
+
image = image.resize((336, 336))
|
251 |
+
else:
|
252 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
253 |
+
if max(image.size) > max_len:
|
254 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
255 |
+
aspect_ratio = max_hw / min_hw
|
256 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
257 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
258 |
+
W, H = image.size
|
259 |
+
if H > W:
|
260 |
+
H, W = longest_edge, shortest_edge
|
261 |
+
else:
|
262 |
+
H, W = shortest_edge, longest_edge
|
263 |
+
image = image.resize((W, H))
|
264 |
+
if return_pil:
|
265 |
+
return image
|
266 |
+
else:
|
267 |
+
buffered = BytesIO()
|
268 |
+
image.save(buffered, format=image_format)
|
269 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
270 |
+
return img_b64_str
|
271 |
+
|
272 |
+
def get_images(self, return_pil=False):
|
273 |
+
images = []
|
274 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
275 |
+
if i % 2 == 0:
|
276 |
+
if type(msg) is tuple:
|
277 |
+
msg, image, image_process_mode = msg
|
278 |
+
image = self.process_image(
|
279 |
+
image, image_process_mode, return_pil=return_pil
|
280 |
+
)
|
281 |
+
images.append(image)
|
282 |
+
return images
|
283 |
+
|
284 |
+
def to_gradio_chatbot(self):
|
285 |
+
ret = []
|
286 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
287 |
+
if i % 2 == 0:
|
288 |
+
if type(msg) is tuple:
|
289 |
+
msg, image, image_process_mode = msg
|
290 |
+
img_b64_str = self.process_image(
|
291 |
+
image, "Default", return_pil=False, image_format="JPEG"
|
292 |
+
)
|
293 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
294 |
+
msg = img_str + msg.replace("<image>", "").strip()
|
295 |
+
ret.append([msg, None])
|
296 |
+
else:
|
297 |
+
ret.append([msg, None])
|
298 |
+
else:
|
299 |
+
ret[-1][-1] = msg
|
300 |
+
return ret
|
301 |
+
|
302 |
+
def copy(self):
|
303 |
+
return Conversation(
|
304 |
+
system=self.system,
|
305 |
+
roles=self.roles,
|
306 |
+
messages=[[x, y] for x, y in self.messages],
|
307 |
+
offset=self.offset,
|
308 |
+
sep_style=self.sep_style,
|
309 |
+
sep=self.sep,
|
310 |
+
sep2=self.sep2,
|
311 |
+
version=self.version,
|
312 |
+
)
|
313 |
+
|
314 |
+
def dict(self):
|
315 |
+
if len(self.get_images()) > 0:
|
316 |
+
return {
|
317 |
+
"system": self.system,
|
318 |
+
"roles": self.roles,
|
319 |
+
"messages": [
|
320 |
+
[x, y[0] if type(y) is tuple else y] for x, y in self.messages
|
321 |
+
],
|
322 |
+
"offset": self.offset,
|
323 |
+
"sep": self.sep,
|
324 |
+
"sep2": self.sep2,
|
325 |
+
}
|
326 |
+
return {
|
327 |
+
"system": self.system,
|
328 |
+
"roles": self.roles,
|
329 |
+
"messages": self.messages,
|
330 |
+
"offset": self.offset,
|
331 |
+
"sep": self.sep,
|
332 |
+
"sep2": self.sep2,
|
333 |
+
}
|
334 |
+
|
335 |
+
|
336 |
+
conv_vicuna_v0 = Conversation(
|
337 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
338 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
339 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
340 |
+
roles=("Human", "Assistant"),
|
341 |
+
# pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got
|
342 |
+
# `Tuple[Tuple[str, str], Tuple[str, str]]`.
|
343 |
+
messages=(
|
344 |
+
(
|
345 |
+
"Human",
|
346 |
+
"What are the key differences between renewable and non-renewable energy sources?",
|
347 |
+
),
|
348 |
+
(
|
349 |
+
"Assistant",
|
350 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
351 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
352 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
353 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
354 |
+
"renewable and non-renewable energy sources:\n"
|
355 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
356 |
+
"energy sources are finite and will eventually run out.\n"
|
357 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
358 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
359 |
+
"and other negative effects.\n"
|
360 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
361 |
+
"have lower operational costs than non-renewable sources.\n"
|
362 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
363 |
+
"locations than non-renewable sources.\n"
|
364 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
365 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
366 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
367 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
368 |
+
),
|
369 |
+
),
|
370 |
+
offset=2,
|
371 |
+
sep_style=SeparatorStyle.SINGLE,
|
372 |
+
sep="###",
|
373 |
+
)
|
374 |
+
|
375 |
+
conv_vicuna_v1 = Conversation(
|
376 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
377 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
378 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
379 |
+
roles=("USER", "ASSISTANT"),
|
380 |
+
version="v1",
|
381 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
382 |
+
messages=(),
|
383 |
+
offset=0,
|
384 |
+
sep_style=SeparatorStyle.TWO,
|
385 |
+
sep=" ",
|
386 |
+
sep2="</s>",
|
387 |
+
)
|
388 |
+
|
389 |
+
conv_llama_2 = Conversation(
|
390 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
391 |
+
|
392 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
393 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
394 |
+
roles=("USER", "ASSISTANT"),
|
395 |
+
version="llama_v2",
|
396 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
397 |
+
messages=(),
|
398 |
+
offset=0,
|
399 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
400 |
+
sep="<s>",
|
401 |
+
sep2="</s>",
|
402 |
+
)
|
403 |
+
|
404 |
+
conv_llava_llama_2 = Conversation(
|
405 |
+
system="You are a helpful language and vision assistant. "
|
406 |
+
"You are able to understand the visual content that the user provides, "
|
407 |
+
"and assist the user with a variety of tasks using natural language.",
|
408 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
409 |
+
roles=("USER", "ASSISTANT"),
|
410 |
+
version="llama_v2",
|
411 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
412 |
+
messages=(),
|
413 |
+
offset=0,
|
414 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
415 |
+
sep="<s>",
|
416 |
+
sep2="</s>",
|
417 |
+
)
|
418 |
+
|
419 |
+
conv_mpt = Conversation(
|
420 |
+
system="""<|im_start|>system
|
421 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
422 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
423 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
424 |
+
version="mpt",
|
425 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
426 |
+
messages=(),
|
427 |
+
offset=0,
|
428 |
+
sep_style=SeparatorStyle.MPT,
|
429 |
+
sep="<|im_end|>",
|
430 |
+
)
|
431 |
+
|
432 |
+
conv_llava_plain = Conversation(
|
433 |
+
system="",
|
434 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
435 |
+
roles=("", ""),
|
436 |
+
# pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
|
437 |
+
messages=(),
|
438 |
+
offset=0,
|
439 |
+
sep_style=SeparatorStyle.PLAIN,
|
440 |
+
sep="\n",
|
441 |
+
version="plain",
|
442 |
+
)
|
443 |
+
|
444 |
+
conv_llava_v0 = Conversation(
|
445 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
446 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
447 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
448 |
+
roles=("Human", "Assistant"),
|
449 |
+
# pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
|
450 |
+
messages=(),
|
451 |
+
offset=0,
|
452 |
+
sep_style=SeparatorStyle.SINGLE,
|
453 |
+
sep="###",
|
454 |
+
)
|
455 |
+
|
456 |
+
conv_llava_v0_mmtag = Conversation(
|
457 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
458 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
459 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
460 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
461 |
+
roles=("Human", "Assistant"),
|
462 |
+
# pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
|
463 |
+
messages=(),
|
464 |
+
offset=0,
|
465 |
+
sep_style=SeparatorStyle.SINGLE,
|
466 |
+
sep="###",
|
467 |
+
version="v0_mmtag",
|
468 |
+
)
|
469 |
+
|
470 |
+
conv_llava_v1 = Conversation(
|
471 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
472 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
473 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
474 |
+
roles=("USER", "ASSISTANT"),
|
475 |
+
version="v1",
|
476 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
477 |
+
messages=(),
|
478 |
+
offset=0,
|
479 |
+
sep_style=SeparatorStyle.TWO,
|
480 |
+
sep=" ",
|
481 |
+
sep2="</s>",
|
482 |
+
)
|
483 |
+
|
484 |
+
conv_llava_v1_mmtag = Conversation(
|
485 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
486 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
487 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
488 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
489 |
+
roles=("USER", "ASSISTANT"),
|
490 |
+
# pyre-fixme[6]: For 3rd argument expected `List[List[str]]` but got `Tuple[]`.
|
491 |
+
messages=(),
|
492 |
+
offset=0,
|
493 |
+
sep_style=SeparatorStyle.TWO,
|
494 |
+
sep=" ",
|
495 |
+
sep2="</s>",
|
496 |
+
version="v1_mmtag",
|
497 |
+
)
|
498 |
+
|
499 |
+
conv_mistral_instruct = Conversation(
|
500 |
+
system="",
|
501 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
502 |
+
roles=("USER", "ASSISTANT"),
|
503 |
+
version="llama_v2",
|
504 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
505 |
+
messages=(),
|
506 |
+
offset=0,
|
507 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
508 |
+
sep="",
|
509 |
+
sep2="</s>",
|
510 |
+
)
|
511 |
+
|
512 |
+
conv_chatml_direct = Conversation(
|
513 |
+
system="""<|im_start|>system
|
514 |
+
Answer the questions.""",
|
515 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
516 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
517 |
+
version="mpt",
|
518 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
519 |
+
messages=(),
|
520 |
+
offset=0,
|
521 |
+
sep_style=SeparatorStyle.MPT,
|
522 |
+
sep="<|im_end|>",
|
523 |
+
)
|
524 |
+
|
525 |
+
# llama3_tokenizer = AutoTokenizer.from_pretrained(
|
526 |
+
# PathManager.get_local_path(
|
527 |
+
# "./checkpoint/"
|
528 |
+
# )
|
529 |
+
# )
|
530 |
+
|
531 |
+
conv_llama3 = Conversation(
|
532 |
+
system="""As a multimodal AI, you have the ability to process and analyze images. Whenever an image is present in the conversation, very carefully examine it and consider its content when formulating your response. You should give concise responses to very simple questions, but provide thorough responses to more complex and open-ended questions.""",
|
533 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
534 |
+
roles=("user", "assistant"),
|
535 |
+
version="llama3",
|
536 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
537 |
+
messages=(),
|
538 |
+
offset=0,
|
539 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
540 |
+
# tokenizer=llama3_tokenizer,
|
541 |
+
sep="<|eot_id|>",
|
542 |
+
)
|
543 |
+
|
544 |
+
conv_llama3_2 = Conversation(
|
545 |
+
system="""You are a helpful assistant.""",
|
546 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
547 |
+
roles=("user", "assistant"),
|
548 |
+
version="llama3_2",
|
549 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
550 |
+
messages=(),
|
551 |
+
offset=0,
|
552 |
+
sep_style=SeparatorStyle.LLAMA_3_2,
|
553 |
+
sep="<|eot_id|>",
|
554 |
+
)
|
555 |
+
|
556 |
+
conv_phi3_instruct = Conversation(
|
557 |
+
system="""<|system|>\nYou are a helpful AI assistant.""",
|
558 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
559 |
+
roles=("\n<|user|>\n", "\n<|assistant|>\n"),
|
560 |
+
version="phi3",
|
561 |
+
# pyre-fixme[6]: For 4th argument expected `List[List[str]]` but got `Tuple[]`.
|
562 |
+
messages=(),
|
563 |
+
offset=0,
|
564 |
+
sep_style=SeparatorStyle.MPT,
|
565 |
+
sep="<|end|>",
|
566 |
+
)
|
567 |
+
|
568 |
+
conv_qwen = Conversation(
|
569 |
+
system="""<|im_start|>system
|
570 |
+
You are a helpful assistant.""",
|
571 |
+
# pyre-fixme[6]: For 2nd argument expected `List[str]` but got `Tuple[str, str]`.
|
572 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
573 |
+
version="qwen",
|
574 |
+
messages=[],
|
575 |
+
offset=0,
|
576 |
+
sep_style=SeparatorStyle.CHATML,
|
577 |
+
sep="<|im_end|>",
|
578 |
+
)
|
579 |
+
|
580 |
+
default_conversation = conv_vicuna_v1
|
581 |
+
conv_templates = {
|
582 |
+
"default": conv_vicuna_v0,
|
583 |
+
"v0": conv_vicuna_v0,
|
584 |
+
"v1": conv_vicuna_v1,
|
585 |
+
"vicuna_v1": conv_vicuna_v1,
|
586 |
+
"llama_2": conv_llama_2,
|
587 |
+
"mistral_instruct": conv_mistral_instruct,
|
588 |
+
"chatml_direct": conv_chatml_direct,
|
589 |
+
"mistral_direct": conv_chatml_direct,
|
590 |
+
"plain": conv_llava_plain,
|
591 |
+
"v0_plain": conv_llava_plain,
|
592 |
+
"llava_v0": conv_llava_v0,
|
593 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
594 |
+
"llava_v1": conv_llava_v1,
|
595 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
596 |
+
"llava_llama_2": conv_llava_llama_2,
|
597 |
+
"mpt": conv_mpt,
|
598 |
+
"llama3": conv_llama3,
|
599 |
+
"llama3_2": conv_llama3_2,
|
600 |
+
"phi3": conv_phi3_instruct,
|
601 |
+
"qwen": conv_qwen,
|
602 |
+
}
|
603 |
+
|
604 |
+
|
605 |
+
if __name__ == "__main__":
|
606 |
+
print(default_conversation.get_prompt())
|
longvu/file_io.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
4 |
+
|
5 |
+
from iopath.common.file_io import HTTPURLHandler, PathManager as PathManagerBase
|
6 |
+
|
7 |
+
__all__ = ["PathManager"]
|
8 |
+
|
9 |
+
|
10 |
+
PathManager = PathManagerBase()
|
11 |
+
PathManager.register_handler(HTTPURLHandler())
|
longvu/language_model/__pycache__/cambrian_llama.cpython-310.pyc
ADDED
Binary file (8.51 kB). View file
|
|
longvu/language_model/__pycache__/cambrian_qwen.cpython-310.pyc
ADDED
Binary file (7.98 kB). View file
|
|
longvu/language_model/cambrian_llama.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
|
23 |
+
from transformers import (
|
24 |
+
AutoConfig,
|
25 |
+
AutoModelForCausalLM,
|
26 |
+
LlamaConfig,
|
27 |
+
LlamaForCausalLM,
|
28 |
+
LlamaModel,
|
29 |
+
)
|
30 |
+
from transformers.cache_utils import Cache, DynamicCache
|
31 |
+
from transformers.generation.utils import GenerateOutput
|
32 |
+
|
33 |
+
from transformers.modeling_attn_mask_utils import (
|
34 |
+
_prepare_4d_causal_attention_mask,
|
35 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
36 |
+
)
|
37 |
+
|
38 |
+
from transformers.modeling_outputs import (
|
39 |
+
BaseModelOutputWithPast,
|
40 |
+
CausalLMOutputWithPast,
|
41 |
+
)
|
42 |
+
from transformers.utils import logging
|
43 |
+
|
44 |
+
from ..cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
|
45 |
+
|
46 |
+
IS_XLA_AVAILABLE = False
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
class CambrianConfig(LlamaConfig):
|
52 |
+
model_type = "cambrian_llama"
|
53 |
+
|
54 |
+
debug = "debug"
|
55 |
+
|
56 |
+
|
57 |
+
class CambrianLlamaModel(CambrianMetaModel, LlamaModel):
|
58 |
+
config_class = CambrianConfig
|
59 |
+
|
60 |
+
def __init__(self, config: LlamaConfig):
|
61 |
+
super(CambrianLlamaModel, self).__init__(config)
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
# pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
|
66 |
+
input_ids: torch.LongTensor = None,
|
67 |
+
attention_mask: Optional[torch.Tensor] = None,
|
68 |
+
position_ids: Optional[torch.LongTensor] = None,
|
69 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
70 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
71 |
+
use_cache: Optional[bool] = None,
|
72 |
+
output_attentions: Optional[bool] = None,
|
73 |
+
output_hidden_states: Optional[bool] = None,
|
74 |
+
return_dict: Optional[bool] = None,
|
75 |
+
vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
|
76 |
+
vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
|
77 |
+
final_vision_feature_size: Optional[List[tuple]] = None,
|
78 |
+
global_context_feature: Optional[torch.Tensor] = None,
|
79 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
80 |
+
|
81 |
+
output_attentions = (
|
82 |
+
output_attentions
|
83 |
+
if output_attentions is not None
|
84 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `config`.
|
85 |
+
else self.config.output_attentions
|
86 |
+
)
|
87 |
+
|
88 |
+
output_hidden_states = (
|
89 |
+
output_hidden_states
|
90 |
+
if output_hidden_states is not None
|
91 |
+
else self.config.output_hidden_states
|
92 |
+
)
|
93 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
94 |
+
|
95 |
+
return_dict = (
|
96 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
97 |
+
)
|
98 |
+
|
99 |
+
# retrieve input_ids and inputs_embeds
|
100 |
+
if input_ids is not None and inputs_embeds is not None:
|
101 |
+
raise ValueError(
|
102 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
103 |
+
)
|
104 |
+
elif input_ids is not None:
|
105 |
+
batch_size, seq_length = input_ids.shape[:2]
|
106 |
+
elif inputs_embeds is not None:
|
107 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
108 |
+
else:
|
109 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
110 |
+
|
111 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute
|
112 |
+
# `gradient_checkpointing`.
|
113 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `training`.
|
114 |
+
if self.gradient_checkpointing and self.training:
|
115 |
+
if use_cache:
|
116 |
+
logger.warning_once(
|
117 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
118 |
+
)
|
119 |
+
use_cache = False
|
120 |
+
|
121 |
+
past_key_values_length = 0
|
122 |
+
if use_cache:
|
123 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
124 |
+
if use_legacy_cache:
|
125 |
+
# pyre-fixme[9]: past_key_values has type
|
126 |
+
# `Optional[List[FloatTensor]]`; used as `DynamicCache`.
|
127 |
+
# pyre-fixme[6]: For 1st argument expected
|
128 |
+
# `Optional[Tuple[Tuple[FloatTensor]]]` but got
|
129 |
+
# `Optional[List[FloatTensor]]`.
|
130 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
131 |
+
# pyre-fixme[16]: `Optional` has no attribute `get_usable_length`.
|
132 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
133 |
+
|
134 |
+
if position_ids is None:
|
135 |
+
# pyre-fixme[16]: `Optional` has no attribute `device`.
|
136 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
137 |
+
position_ids = torch.arange(
|
138 |
+
past_key_values_length,
|
139 |
+
seq_length + past_key_values_length,
|
140 |
+
dtype=torch.long,
|
141 |
+
device=device,
|
142 |
+
)
|
143 |
+
position_ids = position_ids.unsqueeze(0)
|
144 |
+
|
145 |
+
if inputs_embeds is None:
|
146 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `embed_tokens`.
|
147 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
148 |
+
|
149 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute
|
150 |
+
# `_use_flash_attention_2`.
|
151 |
+
self._use_flash_attention_2 = getattr(self, "_use_flash_attention_2", False)
|
152 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `_use_sdpa`.
|
153 |
+
self._use_sdpa = getattr(self, "_use_sdpa", True)
|
154 |
+
if self._use_flash_attention_2:
|
155 |
+
# 2d mask is passed through the layers
|
156 |
+
attention_mask = (
|
157 |
+
attention_mask
|
158 |
+
if (attention_mask is not None and 0 in attention_mask)
|
159 |
+
else None
|
160 |
+
)
|
161 |
+
elif self._use_sdpa and not output_attentions:
|
162 |
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
163 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
164 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
165 |
+
attention_mask,
|
166 |
+
(batch_size, seq_length),
|
167 |
+
inputs_embeds,
|
168 |
+
past_key_values_length,
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
# 4d mask is passed through the layers
|
172 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
173 |
+
attention_mask,
|
174 |
+
(batch_size, seq_length),
|
175 |
+
inputs_embeds,
|
176 |
+
past_key_values_length,
|
177 |
+
)
|
178 |
+
|
179 |
+
# embed positions
|
180 |
+
hidden_states = inputs_embeds
|
181 |
+
# decoder layers
|
182 |
+
all_hidden_states = () if output_hidden_states else None
|
183 |
+
all_self_attns = () if output_attentions else None
|
184 |
+
next_decoder_cache = None
|
185 |
+
|
186 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `layers`.
|
187 |
+
for i, decoder_layer in enumerate(self.layers):
|
188 |
+
if output_hidden_states:
|
189 |
+
all_hidden_states += (hidden_states,)
|
190 |
+
|
191 |
+
if self.gradient_checkpointing and self.training:
|
192 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute
|
193 |
+
# `_gradient_checkpointing_func`.
|
194 |
+
layer_outputs = self._gradient_checkpointing_func(
|
195 |
+
decoder_layer.__call__,
|
196 |
+
hidden_states,
|
197 |
+
attention_mask,
|
198 |
+
position_ids,
|
199 |
+
past_key_values,
|
200 |
+
output_attentions,
|
201 |
+
use_cache,
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
layer_outputs = decoder_layer(
|
205 |
+
hidden_states,
|
206 |
+
attention_mask=attention_mask,
|
207 |
+
position_ids=position_ids,
|
208 |
+
past_key_value=past_key_values,
|
209 |
+
output_attentions=output_attentions,
|
210 |
+
use_cache=use_cache,
|
211 |
+
)
|
212 |
+
|
213 |
+
hidden_states = layer_outputs[0]
|
214 |
+
|
215 |
+
if use_cache:
|
216 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
217 |
+
|
218 |
+
if output_attentions:
|
219 |
+
all_self_attns += (layer_outputs[1],)
|
220 |
+
|
221 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute `norm`.
|
222 |
+
hidden_states = self.norm(hidden_states)
|
223 |
+
|
224 |
+
# add hidden states from the last decoder layer
|
225 |
+
if output_hidden_states:
|
226 |
+
all_hidden_states += (hidden_states,)
|
227 |
+
|
228 |
+
next_cache = None
|
229 |
+
if use_cache:
|
230 |
+
next_cache = (
|
231 |
+
next_decoder_cache.to_legacy_cache()
|
232 |
+
# pyre-fixme[61]: `use_legacy_cache` is undefined, or not always
|
233 |
+
# defined.
|
234 |
+
if use_legacy_cache
|
235 |
+
else next_decoder_cache
|
236 |
+
)
|
237 |
+
if not return_dict:
|
238 |
+
return tuple(
|
239 |
+
v
|
240 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
241 |
+
if v is not None
|
242 |
+
)
|
243 |
+
return BaseModelOutputWithPast(
|
244 |
+
last_hidden_state=hidden_states,
|
245 |
+
past_key_values=next_cache,
|
246 |
+
hidden_states=all_hidden_states,
|
247 |
+
attentions=all_self_attns,
|
248 |
+
)
|
249 |
+
|
250 |
+
|
251 |
+
class CambrianLlamaForCausalLM(LlamaForCausalLM, CambrianMetaForCausalLM):
|
252 |
+
config_class = CambrianConfig
|
253 |
+
|
254 |
+
def __init__(self, config):
|
255 |
+
super(LlamaForCausalLM, self).__init__(config)
|
256 |
+
|
257 |
+
self.model = CambrianLlamaModel(config)
|
258 |
+
self.pretraining_tp = config.pretraining_tp
|
259 |
+
self.vocab_size = config.vocab_size
|
260 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
261 |
+
|
262 |
+
# Initialize weights and apply final processing
|
263 |
+
self.post_init()
|
264 |
+
|
265 |
+
def get_model(self):
|
266 |
+
return self.model
|
267 |
+
|
268 |
+
def forward(
|
269 |
+
self,
|
270 |
+
# pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
|
271 |
+
input_ids: torch.LongTensor = None,
|
272 |
+
attention_mask: Optional[torch.Tensor] = None,
|
273 |
+
position_ids: Optional[torch.LongTensor] = None,
|
274 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
275 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
276 |
+
labels: Optional[torch.LongTensor] = None,
|
277 |
+
use_cache: Optional[bool] = None,
|
278 |
+
output_attentions: Optional[bool] = None,
|
279 |
+
output_hidden_states: Optional[bool] = None,
|
280 |
+
images: Optional[torch.FloatTensor] = None,
|
281 |
+
image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
|
282 |
+
image_sizes: Optional[List[List[int]]] = None,
|
283 |
+
return_dict: Optional[bool] = None,
|
284 |
+
cache_position=None,
|
285 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
286 |
+
|
287 |
+
final_vision_feature_size = None
|
288 |
+
|
289 |
+
if inputs_embeds is None:
|
290 |
+
(
|
291 |
+
input_ids,
|
292 |
+
position_ids,
|
293 |
+
attention_mask,
|
294 |
+
past_key_values,
|
295 |
+
inputs_embeds,
|
296 |
+
labels,
|
297 |
+
vision_tower_aux_feature_list,
|
298 |
+
vision_tower_aux_attention_masks_list,
|
299 |
+
final_vision_feature_size,
|
300 |
+
global_context_feature,
|
301 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
302 |
+
input_ids,
|
303 |
+
position_ids,
|
304 |
+
attention_mask,
|
305 |
+
past_key_values,
|
306 |
+
labels,
|
307 |
+
images,
|
308 |
+
image_aux_attention_masks_list,
|
309 |
+
image_sizes,
|
310 |
+
)
|
311 |
+
if IS_XLA_AVAILABLE:
|
312 |
+
# Very Important for TorchXLA
|
313 |
+
# self.model.gradient_checkpointing = False
|
314 |
+
|
315 |
+
# pyre-fixme[21]: Could not find module `torch_xla.utils.checkpoint`.
|
316 |
+
from torch_xla.utils.checkpoint import checkpoint
|
317 |
+
|
318 |
+
# self.model.gradient_checkpointing = True
|
319 |
+
# pyre-fixme[16]: `CambrianLlamaModel` has no attribute
|
320 |
+
# `_gradient_checkpointing_func`.
|
321 |
+
self.model._gradient_checkpointing_func = checkpoint
|
322 |
+
|
323 |
+
output_attentions = (
|
324 |
+
output_attentions
|
325 |
+
if output_attentions is not None
|
326 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute `config`.
|
327 |
+
else self.config.output_attentions
|
328 |
+
)
|
329 |
+
output_hidden_states = (
|
330 |
+
output_hidden_states
|
331 |
+
if output_hidden_states is not None
|
332 |
+
else self.config.output_hidden_states
|
333 |
+
)
|
334 |
+
return_dict = (
|
335 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
336 |
+
)
|
337 |
+
|
338 |
+
# training
|
339 |
+
if IS_XLA_AVAILABLE:
|
340 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
341 |
+
# pyre-fixme[29]: `CambrianLlamaModel` is not a function.
|
342 |
+
outputs = self.model(
|
343 |
+
input_ids=input_ids,
|
344 |
+
attention_mask=attention_mask,
|
345 |
+
position_ids=position_ids,
|
346 |
+
past_key_values=past_key_values,
|
347 |
+
inputs_embeds=inputs_embeds,
|
348 |
+
use_cache=use_cache,
|
349 |
+
output_attentions=output_attentions,
|
350 |
+
output_hidden_states=output_hidden_states,
|
351 |
+
return_dict=return_dict,
|
352 |
+
# pyre-fixme[61]: `vision_tower_aux_feature_list` is undefined, or
|
353 |
+
# not always defined.
|
354 |
+
vision_tower_aux_feature_list=vision_tower_aux_feature_list,
|
355 |
+
# pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
|
356 |
+
# undefined, or not always defined.
|
357 |
+
vision_tower_aux_attention_masks_list=vision_tower_aux_attention_masks_list,
|
358 |
+
final_vision_feature_size=final_vision_feature_size,
|
359 |
+
# pyre-fixme[61]: `global_context_feature` is undefined, or not
|
360 |
+
# always defined.
|
361 |
+
global_context_feature=global_context_feature,
|
362 |
+
)
|
363 |
+
|
364 |
+
# inference
|
365 |
+
else:
|
366 |
+
if hasattr(self, "vision_tower_aux_feature_list"):
|
367 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
368 |
+
# pyre-fixme[29]: `CambrianLlamaModel` is not a function.
|
369 |
+
outputs = self.model(
|
370 |
+
input_ids=input_ids,
|
371 |
+
attention_mask=attention_mask,
|
372 |
+
position_ids=position_ids,
|
373 |
+
past_key_values=past_key_values,
|
374 |
+
inputs_embeds=inputs_embeds,
|
375 |
+
use_cache=use_cache,
|
376 |
+
output_attentions=output_attentions,
|
377 |
+
output_hidden_states=output_hidden_states,
|
378 |
+
return_dict=return_dict,
|
379 |
+
vision_tower_aux_feature_list=(
|
380 |
+
# pyre-fixme[61]: `vision_tower_aux_feature_list` is
|
381 |
+
# undefined, or not always defined.
|
382 |
+
vision_tower_aux_feature_list
|
383 |
+
if inputs_embeds is None
|
384 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
|
385 |
+
# attribute `vision_tower_aux_feature_list`.
|
386 |
+
else self.vision_tower_aux_feature_list
|
387 |
+
),
|
388 |
+
vision_tower_aux_attention_masks_list=(
|
389 |
+
# pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
|
390 |
+
# undefined, or not always defined.
|
391 |
+
vision_tower_aux_attention_masks_list
|
392 |
+
if inputs_embeds is None
|
393 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
|
394 |
+
# attribute `vision_tower_aux_attention_masks_list`.
|
395 |
+
else self.vision_tower_aux_attention_masks_list
|
396 |
+
),
|
397 |
+
final_vision_feature_size=(
|
398 |
+
final_vision_feature_size
|
399 |
+
if inputs_embeds is None
|
400 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
|
401 |
+
# attribute `final_vision_feature_size`.
|
402 |
+
else self.final_vision_feature_size
|
403 |
+
),
|
404 |
+
global_context_feature=(
|
405 |
+
# pyre-fixme[61]: `global_context_feature` is undefined, or
|
406 |
+
# not always defined.
|
407 |
+
global_context_feature
|
408 |
+
if inputs_embeds is None
|
409 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
|
410 |
+
# attribute `global_context_feature`.
|
411 |
+
else self.global_context_feature
|
412 |
+
),
|
413 |
+
)
|
414 |
+
else:
|
415 |
+
# pyre-fixme[29]: `CambrianLlamaModel` is not a function.
|
416 |
+
outputs = self.model(
|
417 |
+
input_ids=input_ids,
|
418 |
+
attention_mask=attention_mask,
|
419 |
+
position_ids=position_ids,
|
420 |
+
past_key_values=past_key_values,
|
421 |
+
inputs_embeds=inputs_embeds,
|
422 |
+
use_cache=use_cache,
|
423 |
+
output_attentions=output_attentions,
|
424 |
+
output_hidden_states=output_hidden_states,
|
425 |
+
return_dict=return_dict,
|
426 |
+
# final_vision_feature_size=final_vision_feature_size,
|
427 |
+
)
|
428 |
+
|
429 |
+
hidden_states = outputs[0]
|
430 |
+
if self.config.pretraining_tp > 1:
|
431 |
+
lm_head_slices = self.lm_head.weight.split(
|
432 |
+
self.vocab_size // self.config.pretraining_tp, dim=0
|
433 |
+
)
|
434 |
+
logits = [
|
435 |
+
F.linear(hidden_states, lm_head_slices[i])
|
436 |
+
for i in range(self.config.pretraining_tp)
|
437 |
+
]
|
438 |
+
logits = torch.cat(logits, dim=-1)
|
439 |
+
else:
|
440 |
+
logits = self.lm_head(hidden_states)
|
441 |
+
logits = logits.float()
|
442 |
+
|
443 |
+
loss = None
|
444 |
+
if labels is not None:
|
445 |
+
# Shift so that tokens < n predict n
|
446 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
447 |
+
shift_labels = labels[..., 1:].contiguous()
|
448 |
+
# Flatten the tokens
|
449 |
+
loss_fct = CrossEntropyLoss()
|
450 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
451 |
+
shift_labels = shift_labels.view(-1)
|
452 |
+
# Enable model parallelism
|
453 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
454 |
+
loss = loss_fct(shift_logits, shift_labels)
|
455 |
+
|
456 |
+
if not return_dict:
|
457 |
+
output = (logits,) + outputs[1:]
|
458 |
+
return (loss,) + output if loss is not None else output
|
459 |
+
|
460 |
+
return CausalLMOutputWithPast(
|
461 |
+
loss=loss,
|
462 |
+
logits=logits,
|
463 |
+
past_key_values=outputs.past_key_values,
|
464 |
+
hidden_states=outputs.hidden_states,
|
465 |
+
attentions=outputs.attentions,
|
466 |
+
)
|
467 |
+
|
468 |
+
@torch.no_grad()
|
469 |
+
def generate(
|
470 |
+
self,
|
471 |
+
inputs: Optional[torch.Tensor] = None,
|
472 |
+
images: Optional[torch.Tensor] = None,
|
473 |
+
image_sizes: Optional[torch.Tensor] = None,
|
474 |
+
**kwargs,
|
475 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
476 |
+
position_ids = kwargs.pop("position_ids", None)
|
477 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
478 |
+
if "inputs_embeds" in kwargs:
|
479 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
480 |
+
|
481 |
+
if images is not None:
|
482 |
+
(
|
483 |
+
inputs,
|
484 |
+
position_ids,
|
485 |
+
attention_mask,
|
486 |
+
_,
|
487 |
+
inputs_embeds,
|
488 |
+
_,
|
489 |
+
vision_tower_aux_feature_list,
|
490 |
+
vision_tower_aux_attention_masks_list,
|
491 |
+
final_vision_feature_size,
|
492 |
+
global_context_feature,
|
493 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
494 |
+
inputs,
|
495 |
+
position_ids,
|
496 |
+
attention_mask,
|
497 |
+
None,
|
498 |
+
None,
|
499 |
+
images,
|
500 |
+
image_sizes=image_sizes,
|
501 |
+
)
|
502 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
|
503 |
+
# `vision_tower_aux_feature_list`.
|
504 |
+
self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
|
505 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
|
506 |
+
# `vision_tower_aux_attention_masks_list`.
|
507 |
+
self.vision_tower_aux_attention_masks_list = (
|
508 |
+
vision_tower_aux_attention_masks_list
|
509 |
+
)
|
510 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
|
511 |
+
# `final_vision_feature_size`.
|
512 |
+
self.final_vision_feature_size = final_vision_feature_size
|
513 |
+
# pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
|
514 |
+
# `global_context_feature`.
|
515 |
+
self.global_context_feature = global_context_feature
|
516 |
+
else:
|
517 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
518 |
+
|
519 |
+
# pyre-fixme[16]: `LlamaForCausalLM` has no attribute `generate`.
|
520 |
+
return super().generate(
|
521 |
+
position_ids=position_ids,
|
522 |
+
attention_mask=attention_mask,
|
523 |
+
inputs_embeds=inputs_embeds,
|
524 |
+
**kwargs,
|
525 |
+
)
|
526 |
+
|
527 |
+
def prepare_inputs_for_generation(
|
528 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
529 |
+
):
|
530 |
+
images = kwargs.pop("images", None)
|
531 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
532 |
+
inputs = super().prepare_inputs_for_generation(
|
533 |
+
input_ids,
|
534 |
+
past_key_values=past_key_values,
|
535 |
+
inputs_embeds=inputs_embeds,
|
536 |
+
**kwargs,
|
537 |
+
)
|
538 |
+
if images is not None:
|
539 |
+
inputs["images"] = images
|
540 |
+
if image_sizes is not None:
|
541 |
+
inputs["image_sizes"] = image_sizes
|
542 |
+
return inputs
|
543 |
+
|
544 |
+
|
545 |
+
AutoConfig.register("cambrian_llama", CambrianConfig)
|
546 |
+
AutoModelForCausalLM.register(CambrianConfig, CambrianLlamaForCausalLM)
|
longvu/language_model/cambrian_qwen.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
|
23 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
24 |
+
from transformers.cache_utils import Cache, DynamicCache
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPast,
|
29 |
+
CausalLMOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.utils import logging
|
32 |
+
|
33 |
+
from ..cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
|
34 |
+
|
35 |
+
IS_XLA_AVAILABLE = False
|
36 |
+
|
37 |
+
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
|
42 |
+
class CambrianConfig(Qwen2Config):
|
43 |
+
model_type = "cambrian_qwen"
|
44 |
+
|
45 |
+
debug = "debug"
|
46 |
+
|
47 |
+
|
48 |
+
class CambrianQwenModel(CambrianMetaModel, Qwen2Model):
|
49 |
+
config_class = CambrianConfig
|
50 |
+
|
51 |
+
def __init__(self, config: Qwen2Config):
|
52 |
+
super(CambrianQwenModel, self).__init__(config)
|
53 |
+
|
54 |
+
def forward(
|
55 |
+
self,
|
56 |
+
# pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
|
57 |
+
input_ids: torch.LongTensor = None,
|
58 |
+
attention_mask: Optional[torch.Tensor] = None,
|
59 |
+
position_ids: Optional[torch.LongTensor] = None,
|
60 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
61 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
62 |
+
use_cache: Optional[bool] = None,
|
63 |
+
output_attentions: Optional[bool] = None,
|
64 |
+
output_hidden_states: Optional[bool] = None,
|
65 |
+
return_dict: Optional[bool] = None,
|
66 |
+
cache_position: Optional[torch.LongTensor] = None,
|
67 |
+
vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
|
68 |
+
vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
|
69 |
+
final_vision_feature_size: Optional[List[tuple]] = None,
|
70 |
+
global_context_feature: Optional[torch.Tensor] = None,
|
71 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
72 |
+
output_attentions = (
|
73 |
+
output_attentions
|
74 |
+
if output_attentions is not None
|
75 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `config`.
|
76 |
+
else self.config.output_attentions
|
77 |
+
)
|
78 |
+
output_hidden_states = (
|
79 |
+
output_hidden_states
|
80 |
+
if output_hidden_states is not None
|
81 |
+
else self.config.output_hidden_states
|
82 |
+
)
|
83 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
84 |
+
|
85 |
+
return_dict = (
|
86 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
87 |
+
)
|
88 |
+
|
89 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
90 |
+
raise ValueError(
|
91 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
92 |
+
)
|
93 |
+
|
94 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `gradient_checkpointing`.
|
95 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `training`.
|
96 |
+
if self.gradient_checkpointing and self.training:
|
97 |
+
if use_cache:
|
98 |
+
logger.warning_once(
|
99 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
100 |
+
)
|
101 |
+
use_cache = False
|
102 |
+
|
103 |
+
use_legacy_cache = False
|
104 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
105 |
+
use_legacy_cache = True
|
106 |
+
# pyre-fixme[6]: For 1st argument expected
|
107 |
+
# `Optional[Tuple[Tuple[FloatTensor]]]` but got
|
108 |
+
# `Optional[List[FloatTensor]]`.
|
109 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
110 |
+
logger.warning_once(
|
111 |
+
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
|
112 |
+
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
113 |
+
)
|
114 |
+
|
115 |
+
if inputs_embeds is None:
|
116 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `embed_tokens`.
|
117 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
118 |
+
|
119 |
+
if cache_position is None:
|
120 |
+
past_seen_tokens = (
|
121 |
+
# pyre-fixme[16]: Item `List` of `Union[List[torch._C.FloatTensor],
|
122 |
+
# DynamicCache]` has no attribute `get_seq_length`.
|
123 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
124 |
+
)
|
125 |
+
cache_position = torch.arange(
|
126 |
+
past_seen_tokens,
|
127 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
128 |
+
device=inputs_embeds.device,
|
129 |
+
)
|
130 |
+
if position_ids is None:
|
131 |
+
position_ids = cache_position.unsqueeze(0)
|
132 |
+
|
133 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `_update_causal_mask`.
|
134 |
+
causal_mask = self._update_causal_mask(
|
135 |
+
attention_mask,
|
136 |
+
inputs_embeds,
|
137 |
+
cache_position,
|
138 |
+
past_key_values,
|
139 |
+
output_attentions,
|
140 |
+
)
|
141 |
+
|
142 |
+
hidden_states = inputs_embeds
|
143 |
+
|
144 |
+
# decoder layers
|
145 |
+
all_hidden_states = () if output_hidden_states else None
|
146 |
+
all_self_attns = () if output_attentions else None
|
147 |
+
next_decoder_cache = None
|
148 |
+
|
149 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `layers`.
|
150 |
+
for i, decoder_layer in enumerate(self.layers):
|
151 |
+
if output_hidden_states:
|
152 |
+
all_hidden_states += (hidden_states,)
|
153 |
+
|
154 |
+
if self.gradient_checkpointing and self.training:
|
155 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute
|
156 |
+
# `_gradient_checkpointing_func`.
|
157 |
+
layer_outputs = self._gradient_checkpointing_func(
|
158 |
+
decoder_layer.__call__,
|
159 |
+
hidden_states,
|
160 |
+
causal_mask,
|
161 |
+
position_ids,
|
162 |
+
past_key_values,
|
163 |
+
output_attentions,
|
164 |
+
use_cache,
|
165 |
+
cache_position,
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
layer_outputs = decoder_layer(
|
169 |
+
hidden_states,
|
170 |
+
attention_mask=causal_mask,
|
171 |
+
position_ids=position_ids,
|
172 |
+
past_key_value=past_key_values,
|
173 |
+
output_attentions=output_attentions,
|
174 |
+
use_cache=use_cache,
|
175 |
+
cache_position=cache_position,
|
176 |
+
)
|
177 |
+
|
178 |
+
hidden_states = layer_outputs[0]
|
179 |
+
|
180 |
+
if use_cache:
|
181 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
182 |
+
|
183 |
+
if output_attentions:
|
184 |
+
all_self_attns += (layer_outputs[1],)
|
185 |
+
|
186 |
+
# pyre-fixme[16]: `CambrianQwenModel` has no attribute `norm`.
|
187 |
+
hidden_states = self.norm(hidden_states)
|
188 |
+
|
189 |
+
# add hidden states from the last decoder layer
|
190 |
+
if output_hidden_states:
|
191 |
+
all_hidden_states += (hidden_states,)
|
192 |
+
|
193 |
+
next_cache = None
|
194 |
+
if use_cache:
|
195 |
+
next_cache = (
|
196 |
+
next_decoder_cache.to_legacy_cache()
|
197 |
+
if use_legacy_cache
|
198 |
+
else next_decoder_cache
|
199 |
+
)
|
200 |
+
|
201 |
+
if not return_dict:
|
202 |
+
return tuple(
|
203 |
+
v
|
204 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
205 |
+
if v is not None
|
206 |
+
)
|
207 |
+
return BaseModelOutputWithPast(
|
208 |
+
last_hidden_state=hidden_states,
|
209 |
+
past_key_values=next_cache,
|
210 |
+
hidden_states=all_hidden_states,
|
211 |
+
attentions=all_self_attns,
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
class CambrianQwenForCausalLM(Qwen2ForCausalLM, CambrianMetaForCausalLM):
|
216 |
+
config_class = CambrianConfig
|
217 |
+
|
218 |
+
def __init__(self, config):
|
219 |
+
# super(Qwen2ForCausalLM, self).__init__(config)
|
220 |
+
Qwen2ForCausalLM.__init__(self, config)
|
221 |
+
config.model_type = "cambrian_qwen"
|
222 |
+
config.rope_scaling = None
|
223 |
+
|
224 |
+
self.model = CambrianQwenModel(config)
|
225 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
226 |
+
# Initialize weights and apply final processing
|
227 |
+
self.post_init()
|
228 |
+
|
229 |
+
def get_model(self):
|
230 |
+
return self.model
|
231 |
+
|
232 |
+
def forward(
|
233 |
+
self,
|
234 |
+
# pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
|
235 |
+
input_ids: torch.LongTensor = None,
|
236 |
+
attention_mask: Optional[torch.Tensor] = None,
|
237 |
+
position_ids: Optional[torch.LongTensor] = None,
|
238 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
239 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
240 |
+
labels: Optional[torch.LongTensor] = None,
|
241 |
+
use_cache: Optional[bool] = None,
|
242 |
+
output_attentions: Optional[bool] = None,
|
243 |
+
output_hidden_states: Optional[bool] = None,
|
244 |
+
images: Optional[torch.FloatTensor] = None,
|
245 |
+
image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
|
246 |
+
image_sizes: Optional[List[List[int]]] = None,
|
247 |
+
return_dict: Optional[bool] = None,
|
248 |
+
modalities: Optional[List[str]] = ["image"],
|
249 |
+
dpo_forward: Optional[bool] = False,
|
250 |
+
cache_position=None,
|
251 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
252 |
+
|
253 |
+
input_image_features = None
|
254 |
+
highres_image_features = None
|
255 |
+
frame_split_sizes = None
|
256 |
+
|
257 |
+
if inputs_embeds is None:
|
258 |
+
(
|
259 |
+
input_ids,
|
260 |
+
position_ids,
|
261 |
+
attention_mask,
|
262 |
+
past_key_values,
|
263 |
+
inputs_embeds,
|
264 |
+
labels,
|
265 |
+
vision_tower_aux_feature_list,
|
266 |
+
vision_tower_aux_attention_masks_list,
|
267 |
+
final_vision_feature_size,
|
268 |
+
global_context_feature,
|
269 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
270 |
+
input_ids,
|
271 |
+
position_ids,
|
272 |
+
attention_mask,
|
273 |
+
past_key_values,
|
274 |
+
labels,
|
275 |
+
images,
|
276 |
+
image_aux_attention_masks_list,
|
277 |
+
image_sizes,
|
278 |
+
)
|
279 |
+
|
280 |
+
if dpo_forward:
|
281 |
+
# pyre-fixme[29]: `CambrianQwenModel` is not a function.
|
282 |
+
outputs = self.model(
|
283 |
+
input_ids=input_ids,
|
284 |
+
attention_mask=attention_mask,
|
285 |
+
position_ids=position_ids,
|
286 |
+
past_key_values=past_key_values,
|
287 |
+
inputs_embeds=inputs_embeds,
|
288 |
+
use_cache=use_cache,
|
289 |
+
output_attentions=output_attentions,
|
290 |
+
output_hidden_states=output_hidden_states,
|
291 |
+
return_dict=return_dict,
|
292 |
+
)
|
293 |
+
|
294 |
+
hidden_states = outputs[0]
|
295 |
+
logits = self.lm_head(hidden_states)
|
296 |
+
return logits, labels
|
297 |
+
|
298 |
+
else:
|
299 |
+
if hasattr(self, "vision_tower_aux_feature_list"):
|
300 |
+
# pyre-fixme[29]: `CambrianQwenModel` is not a function.
|
301 |
+
outputs = self.model(
|
302 |
+
input_ids=input_ids,
|
303 |
+
attention_mask=attention_mask,
|
304 |
+
position_ids=position_ids,
|
305 |
+
past_key_values=past_key_values,
|
306 |
+
inputs_embeds=inputs_embeds,
|
307 |
+
use_cache=use_cache,
|
308 |
+
output_attentions=output_attentions,
|
309 |
+
output_hidden_states=output_hidden_states,
|
310 |
+
return_dict=return_dict,
|
311 |
+
vision_tower_aux_feature_list=(
|
312 |
+
# pyre-fixme[61]: `vision_tower_aux_feature_list` is
|
313 |
+
# undefined, or not always defined.
|
314 |
+
vision_tower_aux_feature_list
|
315 |
+
if inputs_embeds is None
|
316 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
317 |
+
# `vision_tower_aux_feature_list`.
|
318 |
+
else self.vision_tower_aux_feature_list
|
319 |
+
),
|
320 |
+
vision_tower_aux_attention_masks_list=(
|
321 |
+
# pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
|
322 |
+
# undefined, or not always defined.
|
323 |
+
vision_tower_aux_attention_masks_list
|
324 |
+
if inputs_embeds is None
|
325 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
326 |
+
# `vision_tower_aux_attention_masks_list`.
|
327 |
+
else self.vision_tower_aux_attention_masks_list
|
328 |
+
),
|
329 |
+
final_vision_feature_size=(
|
330 |
+
# pyre-fixme[61]: `final_vision_feature_size` is undefined,
|
331 |
+
# or not always defined.
|
332 |
+
final_vision_feature_size
|
333 |
+
if inputs_embeds is None
|
334 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
335 |
+
# `final_vision_feature_size`.
|
336 |
+
else self.final_vision_feature_size
|
337 |
+
),
|
338 |
+
global_context_feature=(
|
339 |
+
# pyre-fixme[61]: `global_context_feature` is undefined, or
|
340 |
+
# not always defined.
|
341 |
+
global_context_feature
|
342 |
+
if inputs_embeds is None
|
343 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
344 |
+
# `global_context_feature`.
|
345 |
+
else self.global_context_feature
|
346 |
+
),
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
# pyre-fixme[29]: `CambrianQwenModel` is not a function.
|
350 |
+
outputs = self.model(
|
351 |
+
input_ids=input_ids,
|
352 |
+
attention_mask=attention_mask,
|
353 |
+
position_ids=position_ids,
|
354 |
+
past_key_values=past_key_values,
|
355 |
+
inputs_embeds=inputs_embeds,
|
356 |
+
use_cache=use_cache,
|
357 |
+
output_attentions=output_attentions,
|
358 |
+
output_hidden_states=output_hidden_states,
|
359 |
+
return_dict=return_dict,
|
360 |
+
# final_vision_feature_size=final_vision_feature_size,
|
361 |
+
)
|
362 |
+
|
363 |
+
hidden_states = outputs[0]
|
364 |
+
logits = self.lm_head(hidden_states)
|
365 |
+
logits = logits.float()
|
366 |
+
|
367 |
+
loss = None
|
368 |
+
if labels is not None:
|
369 |
+
# Shift so that tokens < n predict n
|
370 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
371 |
+
shift_labels = labels[..., 1:].contiguous()
|
372 |
+
# Flatten the tokens
|
373 |
+
loss_fct = CrossEntropyLoss()
|
374 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute `config`.
|
375 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
376 |
+
shift_labels = shift_labels.view(-1)
|
377 |
+
# Enable model parallelism
|
378 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
379 |
+
loss = loss_fct(shift_logits, shift_labels)
|
380 |
+
|
381 |
+
if not return_dict:
|
382 |
+
output = (logits,) + outputs[1:]
|
383 |
+
return (loss,) + output if loss is not None else output
|
384 |
+
|
385 |
+
return CausalLMOutputWithPast(
|
386 |
+
loss=loss,
|
387 |
+
logits=logits,
|
388 |
+
past_key_values=outputs.past_key_values,
|
389 |
+
hidden_states=outputs.hidden_states,
|
390 |
+
attentions=outputs.attentions,
|
391 |
+
)
|
392 |
+
|
393 |
+
@torch.no_grad()
|
394 |
+
def generate(
|
395 |
+
self,
|
396 |
+
inputs: Optional[torch.Tensor] = None,
|
397 |
+
images: Optional[torch.Tensor] = None,
|
398 |
+
image_sizes: Optional[torch.Tensor] = None,
|
399 |
+
**kwargs,
|
400 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
401 |
+
position_ids = kwargs.pop("position_ids", None)
|
402 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
403 |
+
if "inputs_embeds" in kwargs:
|
404 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
405 |
+
|
406 |
+
if images is not None:
|
407 |
+
(
|
408 |
+
inputs,
|
409 |
+
position_ids,
|
410 |
+
attention_mask,
|
411 |
+
_,
|
412 |
+
inputs_embeds,
|
413 |
+
_,
|
414 |
+
vision_tower_aux_feature_list,
|
415 |
+
vision_tower_aux_attention_masks_list,
|
416 |
+
final_vision_feature_size,
|
417 |
+
global_context_feature,
|
418 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
419 |
+
inputs,
|
420 |
+
position_ids,
|
421 |
+
attention_mask,
|
422 |
+
None,
|
423 |
+
None,
|
424 |
+
images,
|
425 |
+
image_sizes=image_sizes,
|
426 |
+
)
|
427 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
428 |
+
# `vision_tower_aux_feature_list`.
|
429 |
+
self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
|
430 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
431 |
+
# `vision_tower_aux_attention_masks_list`.
|
432 |
+
self.vision_tower_aux_attention_masks_list = (
|
433 |
+
vision_tower_aux_attention_masks_list
|
434 |
+
)
|
435 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
436 |
+
# `final_vision_feature_size`.
|
437 |
+
self.final_vision_feature_size = final_vision_feature_size
|
438 |
+
# pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
|
439 |
+
# `global_context_feature`.
|
440 |
+
self.global_context_feature = global_context_feature
|
441 |
+
else:
|
442 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
443 |
+
|
444 |
+
# pyre-fixme[16]: `Qwen2ForCausalLM` has no attribute `generate`.
|
445 |
+
return super().generate(
|
446 |
+
position_ids=position_ids,
|
447 |
+
attention_mask=attention_mask,
|
448 |
+
inputs_embeds=inputs_embeds,
|
449 |
+
**kwargs,
|
450 |
+
)
|
451 |
+
|
452 |
+
def prepare_inputs_for_generation(
|
453 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
454 |
+
):
|
455 |
+
images = kwargs.pop("images", None)
|
456 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
457 |
+
inputs = super().prepare_inputs_for_generation(
|
458 |
+
input_ids,
|
459 |
+
past_key_values=past_key_values,
|
460 |
+
inputs_embeds=inputs_embeds,
|
461 |
+
**kwargs,
|
462 |
+
)
|
463 |
+
if images is not None:
|
464 |
+
inputs["images"] = images
|
465 |
+
if image_sizes is not None:
|
466 |
+
inputs["image_sizes"] = image_sizes
|
467 |
+
return inputs
|
468 |
+
|
469 |
+
|
470 |
+
AutoConfig.register("cambrian_qwen", CambrianConfig)
|
471 |
+
AutoModelForCausalLM.register(CambrianConfig, CambrianQwenForCausalLM)
|
longvu/make_delta.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
"""
|
3 |
+
Usage:
|
4 |
+
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
+
|
13 |
+
from .utils import auto_upgrade
|
14 |
+
|
15 |
+
|
16 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
17 |
+
print("Loading base model")
|
18 |
+
base = AutoModelForCausalLM.from_pretrained(
|
19 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
20 |
+
)
|
21 |
+
|
22 |
+
print("Loading target model")
|
23 |
+
auto_upgrade(target_model_path)
|
24 |
+
target = AutoModelForCausalLM.from_pretrained(
|
25 |
+
target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
26 |
+
)
|
27 |
+
|
28 |
+
print("Calculating delta")
|
29 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
30 |
+
if name not in base.state_dict():
|
31 |
+
assert name in [
|
32 |
+
"model.mm_projector.weight",
|
33 |
+
"model.mm_projector.bias",
|
34 |
+
], f"{name} not in base model"
|
35 |
+
continue
|
36 |
+
if param.data.shape == base.state_dict()[name].shape:
|
37 |
+
param.data -= base.state_dict()[name]
|
38 |
+
else:
|
39 |
+
assert name in [
|
40 |
+
"model.embed_tokens.weight",
|
41 |
+
"lm_head.weight",
|
42 |
+
], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
|
43 |
+
bparam = base.state_dict()[name]
|
44 |
+
param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
|
45 |
+
|
46 |
+
print("Saving delta")
|
47 |
+
if hub_repo_id:
|
48 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
49 |
+
else:
|
50 |
+
kwargs = {}
|
51 |
+
target.save_pretrained(delta_path, **kwargs)
|
52 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
53 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
parser = argparse.ArgumentParser()
|
58 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
59 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
60 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
61 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
62 |
+
args = parser.parse_args()
|
63 |
+
|
64 |
+
make_delta(
|
65 |
+
args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
|
66 |
+
)
|
longvu/mm_datautils.py
ADDED
@@ -0,0 +1,1688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-strict
|
2 |
+
import copy
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Dict, List, Sequence
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import tokenizers
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
import transformers
|
15 |
+
|
16 |
+
from longvu import conversation as conversation_lib
|
17 |
+
|
18 |
+
from longvu.constants import (
|
19 |
+
DEFAULT_IM_END_TOKEN,
|
20 |
+
DEFAULT_IM_START_TOKEN,
|
21 |
+
DEFAULT_IMAGE_TOKEN,
|
22 |
+
IGNORE_INDEX,
|
23 |
+
IMAGE_TOKEN_INDEX,
|
24 |
+
)
|
25 |
+
|
26 |
+
# pyre-fixme[21]: Could not find module `decord`.
|
27 |
+
from decord import cpu, VideoReader # @manual=fbsource//third-party/pypi/decord:decord
|
28 |
+
|
29 |
+
from packaging import version
|
30 |
+
from PIL import Image
|
31 |
+
from torch import distributed as dist
|
32 |
+
from torch.distributed.fsdp import (
|
33 |
+
FullStateDictConfig,
|
34 |
+
FullyShardedDataParallel as FSDP,
|
35 |
+
StateDictType,
|
36 |
+
)
|
37 |
+
from torch.utils.data import Dataset
|
38 |
+
|
39 |
+
# pyre-fixme
|
40 |
+
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
|
41 |
+
"0.14"
|
42 |
+
)
|
43 |
+
from transformers import StoppingCriteria
|
44 |
+
|
45 |
+
from longvu.mm_utils import KeywordsStoppingCriteria
|
46 |
+
|
47 |
+
|
48 |
+
# pyre-fixme[3]: Return type must be annotated.
|
49 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
50 |
+
def maybe_zero_3(param, ignore_status: bool = False, name=None):
|
51 |
+
# NO deepspeed
|
52 |
+
|
53 |
+
# from deepspeed import zero
|
54 |
+
# from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
55 |
+
# if hasattr(param, "ds_id"):
|
56 |
+
# if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
57 |
+
# if not ignore_status:
|
58 |
+
# print(name, 'no ignore status')
|
59 |
+
# with zero.GatheredParameters([param]):
|
60 |
+
# param = param.data.detach().cpu().clone()
|
61 |
+
# else:
|
62 |
+
# param = param.detach().cpu().clone()
|
63 |
+
return param.detach().cpu().clone()
|
64 |
+
|
65 |
+
|
66 |
+
# pyre-fixme[3]: Return type must be annotated.
|
67 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
68 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
69 |
+
to_return = {
|
70 |
+
k: t
|
71 |
+
for k, t in named_params
|
72 |
+
if any(key_match in k for key_match in keys_to_match)
|
73 |
+
}
|
74 |
+
to_return = {
|
75 |
+
k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
|
76 |
+
for k, v in to_return.items()
|
77 |
+
}
|
78 |
+
return to_return
|
79 |
+
|
80 |
+
|
81 |
+
# pyre-fixme[3]: Return type must be annotated.
|
82 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
83 |
+
def find_all_linear_names(model):
|
84 |
+
cls = torch.nn.Linear
|
85 |
+
lora_module_names = set()
|
86 |
+
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
|
87 |
+
for name, module in model.named_modules():
|
88 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
89 |
+
continue
|
90 |
+
if isinstance(module, cls):
|
91 |
+
names = name.split(".")
|
92 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
93 |
+
|
94 |
+
if "lm_head" in lora_module_names: # needed for 16-bit
|
95 |
+
lora_module_names.remove("lm_head")
|
96 |
+
return list(lora_module_names)
|
97 |
+
|
98 |
+
|
99 |
+
def safe_save_model_for_hf_trainer(
|
100 |
+
trainer: transformers.Trainer, output_dir: str
|
101 |
+
) -> None:
|
102 |
+
"""Collects the state dict and dump to disk."""
|
103 |
+
global_rank = dist.get_rank()
|
104 |
+
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
105 |
+
# pyre-fixme[16]: `Trainer` has no attribute `args`.
|
106 |
+
if len(trainer.args.fsdp) == 0:
|
107 |
+
# pyre-fixme[16]: `Trainer` has no attribute `model`.
|
108 |
+
cpu_state_dict = trainer.model.state_dict()
|
109 |
+
else:
|
110 |
+
with FSDP.state_dict_type(
|
111 |
+
trainer.model, StateDictType.FULL_STATE_DICT, save_policy
|
112 |
+
):
|
113 |
+
cpu_state_dict = trainer.model.state_dict()
|
114 |
+
|
115 |
+
for key in cpu_state_dict.keys():
|
116 |
+
cpu_state_dict[key] = cpu_state_dict[key].to(torch.bfloat16)
|
117 |
+
|
118 |
+
if global_rank == 0:
|
119 |
+
trainer.model.config.save_pretrained(output_dir)
|
120 |
+
current_folder = output_dir.split("/")[-1]
|
121 |
+
parent_folder = os.path.dirname(output_dir)
|
122 |
+
save_path = os.path.join(output_dir, "pytorch_model.bin")
|
123 |
+
if getattr(trainer.args, "tune_mm_mlp_adapter", False) and not getattr(
|
124 |
+
trainer.args, "tune_text_decoder", False
|
125 |
+
):
|
126 |
+
# Only save Adapter
|
127 |
+
keys_to_match = ["mm_projector"]
|
128 |
+
if getattr(trainer.args, "use_im_start_end", False):
|
129 |
+
keys_to_match.extend(["embed_tokens", "embed_in"])
|
130 |
+
|
131 |
+
freeze_layer_remove = []
|
132 |
+
for key in cpu_state_dict.keys():
|
133 |
+
remove = True
|
134 |
+
for key_match in keys_to_match:
|
135 |
+
if key_match in key:
|
136 |
+
remove = False
|
137 |
+
break
|
138 |
+
if remove:
|
139 |
+
freeze_layer_remove.append(key)
|
140 |
+
for key in freeze_layer_remove:
|
141 |
+
del cpu_state_dict[key]
|
142 |
+
|
143 |
+
if current_folder.startswith("checkpoint-"):
|
144 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
145 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
146 |
+
save_path = os.path.join(mm_projector_folder, f"{current_folder}.bin")
|
147 |
+
else:
|
148 |
+
save_path = os.path.join(output_dir, f"mm_projector.bin")
|
149 |
+
torch.save(cpu_state_dict, save_path)
|
150 |
+
|
151 |
+
|
152 |
+
def smart_tokenizer_and_embedding_resize(
|
153 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
154 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
155 |
+
special_tokens_dict: Dict,
|
156 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
157 |
+
model: transformers.PreTrainedModel,
|
158 |
+
) -> None:
|
159 |
+
"""Resize tokenizer and embedding.
|
160 |
+
|
161 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
162 |
+
"""
|
163 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
164 |
+
# pyre-fixme[16]: `PreTrainedModel` has no attribute `resize_token_embeddings`.
|
165 |
+
model.resize_token_embeddings(len(tokenizer))
|
166 |
+
|
167 |
+
if num_new_tokens > 0:
|
168 |
+
# pyre-fixme[16]: `PreTrainedModel` has no attribute `get_input_embeddings`.
|
169 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
170 |
+
# pyre-fixme[16]: `PreTrainedModel` has no attribute `get_output_embeddings`.
|
171 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
172 |
+
|
173 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
174 |
+
dim=0, keepdim=True
|
175 |
+
)
|
176 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
177 |
+
dim=0, keepdim=True
|
178 |
+
)
|
179 |
+
|
180 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
181 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
182 |
+
|
183 |
+
|
184 |
+
def _tokenize_fn(
|
185 |
+
strings: Sequence[str],
|
186 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
187 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
188 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
189 |
+
) -> Dict:
|
190 |
+
"""Tokenize a list of strings."""
|
191 |
+
tokenized_list = [
|
192 |
+
tokenizer(
|
193 |
+
text,
|
194 |
+
return_tensors="pt",
|
195 |
+
padding="longest",
|
196 |
+
max_length=tokenizer.model_max_length,
|
197 |
+
truncation=True,
|
198 |
+
)
|
199 |
+
for text in strings
|
200 |
+
]
|
201 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
202 |
+
input_ids_lens = labels_lens = [
|
203 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
204 |
+
for tokenized in tokenized_list
|
205 |
+
]
|
206 |
+
return dict(
|
207 |
+
input_ids=input_ids,
|
208 |
+
labels=labels,
|
209 |
+
input_ids_lens=input_ids_lens,
|
210 |
+
labels_lens=labels_lens,
|
211 |
+
)
|
212 |
+
|
213 |
+
|
214 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
215 |
+
def _mask_targets(target, tokenized_lens, speakers) -> None:
|
216 |
+
# cur_idx = 0
|
217 |
+
cur_idx = tokenized_lens[0]
|
218 |
+
tokenized_lens = tokenized_lens[1:]
|
219 |
+
target[:cur_idx] = IGNORE_INDEX
|
220 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
221 |
+
if speaker == "human":
|
222 |
+
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
|
223 |
+
cur_idx += tokenized_len
|
224 |
+
|
225 |
+
|
226 |
+
# pyre-fixme[3]: Return type must be annotated.
|
227 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
228 |
+
def _add_speaker_and_signal(header, source, get_conversation: bool = True):
|
229 |
+
"""Add speaker and start/end signal on each round."""
|
230 |
+
BEGIN_SIGNAL = "### "
|
231 |
+
END_SIGNAL = "\n"
|
232 |
+
conversation = header
|
233 |
+
for sentence in source:
|
234 |
+
from_str = sentence["from"]
|
235 |
+
if from_str.lower() == "human":
|
236 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
237 |
+
elif from_str.lower() == "gpt":
|
238 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
239 |
+
else:
|
240 |
+
from_str = "unknown"
|
241 |
+
sentence["value"] = (
|
242 |
+
BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
|
243 |
+
)
|
244 |
+
if get_conversation:
|
245 |
+
conversation += sentence["value"]
|
246 |
+
conversation += BEGIN_SIGNAL
|
247 |
+
return conversation
|
248 |
+
|
249 |
+
|
250 |
+
# pyre-fixme[3]: Return type must be annotated.
|
251 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
252 |
+
def expand2square(pil_img, background_color):
|
253 |
+
width, height = pil_img.size
|
254 |
+
if width == height:
|
255 |
+
return pil_img
|
256 |
+
elif width > height:
|
257 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
258 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
259 |
+
return result
|
260 |
+
else:
|
261 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
262 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
263 |
+
return result
|
264 |
+
|
265 |
+
|
266 |
+
# pyre-fixme[3]: Return type must be annotated.
|
267 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
268 |
+
def process_images(images, image_processor, model_cfg):
|
269 |
+
if isinstance(image_processor, list):
|
270 |
+
processor_aux_list = image_processor
|
271 |
+
new_images_aux_list = []
|
272 |
+
for image in images:
|
273 |
+
if isinstance(image, np.ndarray):
|
274 |
+
image = Image.fromarray(image)
|
275 |
+
image_aux_list = []
|
276 |
+
for processor_aux in processor_aux_list:
|
277 |
+
image_aux = image
|
278 |
+
if hasattr(processor_aux, "image_mean"):
|
279 |
+
try:
|
280 |
+
target_resolution = processor_aux.crop_size["height"]
|
281 |
+
except:
|
282 |
+
target_resolution = processor_aux.size["height"]
|
283 |
+
image_aux = expand2square(
|
284 |
+
image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
|
285 |
+
).resize((target_resolution, target_resolution))
|
286 |
+
image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
|
287 |
+
"pixel_values"
|
288 |
+
][0]
|
289 |
+
image_aux_list.append(image_aux)
|
290 |
+
new_images_aux_list.append(image_aux_list)
|
291 |
+
new_images_aux_list = [
|
292 |
+
list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
|
293 |
+
]
|
294 |
+
new_images_aux_list = [
|
295 |
+
torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
|
296 |
+
]
|
297 |
+
return new_images_aux_list
|
298 |
+
else:
|
299 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
300 |
+
new_images = []
|
301 |
+
if image_aspect_ratio == "pad":
|
302 |
+
for image in images:
|
303 |
+
image = expand2square(
|
304 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
305 |
+
)
|
306 |
+
image = image_processor.preprocess(image, return_tensors="pt")[
|
307 |
+
"pixel_values"
|
308 |
+
][0]
|
309 |
+
new_images.append(image)
|
310 |
+
else:
|
311 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
312 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
313 |
+
new_images = torch.stack(new_images, dim=0)
|
314 |
+
return new_images
|
315 |
+
|
316 |
+
|
317 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
318 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
319 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
320 |
+
def preprocess_multimodal(sources: Sequence[str], data_args) -> Dict:
|
321 |
+
is_multimodal = data_args.is_multimodal
|
322 |
+
if not is_multimodal:
|
323 |
+
# pyre-fixme[7]: Expected `Dict[typing.Any, typing.Any]` but got
|
324 |
+
# `Sequence[str]`.
|
325 |
+
return sources
|
326 |
+
|
327 |
+
for source in sources:
|
328 |
+
for sentence in source:
|
329 |
+
if (
|
330 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
331 |
+
# but got `str`.
|
332 |
+
DEFAULT_IMAGE_TOKEN in sentence["value"]
|
333 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
334 |
+
# but got `str`.
|
335 |
+
or "<video>" in sentence["value"]
|
336 |
+
):
|
337 |
+
# pyre-fixme[16]: `str` has no attribute `__setitem__`.
|
338 |
+
sentence["value"] = (
|
339 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice,
|
340 |
+
# SupportsIndex]` but got `str`.
|
341 |
+
sentence["value"]
|
342 |
+
.replace(DEFAULT_IMAGE_TOKEN, "")
|
343 |
+
.replace("<video>", "")
|
344 |
+
.strip()
|
345 |
+
)
|
346 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice,
|
347 |
+
# SupportsIndex]` but got `str`.
|
348 |
+
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
|
349 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice,
|
350 |
+
# SupportsIndex]` but got `str`.
|
351 |
+
sentence["value"] = sentence["value"].strip()
|
352 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
353 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice,
|
354 |
+
# SupportsIndex]` but got `str`.
|
355 |
+
sentence["value"] = sentence["value"].replace(
|
356 |
+
DEFAULT_IMAGE_TOKEN,
|
357 |
+
"<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
|
358 |
+
)
|
359 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
360 |
+
if data_args.mm_use_im_start_end:
|
361 |
+
replace_token = (
|
362 |
+
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
363 |
+
)
|
364 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
365 |
+
# but got `str`.
|
366 |
+
sentence["value"] = sentence["value"].replace(
|
367 |
+
DEFAULT_IMAGE_TOKEN, replace_token
|
368 |
+
)
|
369 |
+
|
370 |
+
# pyre-fixme[7]: Expected `Dict[typing.Any, typing.Any]` but got `Sequence[str]`.
|
371 |
+
return sources
|
372 |
+
|
373 |
+
|
374 |
+
def preprocess_llama_2(
|
375 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
376 |
+
sources,
|
377 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
378 |
+
has_image: bool = False,
|
379 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
380 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
381 |
+
) -> Dict:
|
382 |
+
conv = conversation_lib.default_conversation.copy()
|
383 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
384 |
+
|
385 |
+
# Apply prompt templates
|
386 |
+
conversations = []
|
387 |
+
for i, source in enumerate(sources):
|
388 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
389 |
+
# Skip the first one if it is not from human
|
390 |
+
source = source[1:]
|
391 |
+
|
392 |
+
conv.messages = []
|
393 |
+
for j, sentence in enumerate(source):
|
394 |
+
role = roles[sentence["from"]]
|
395 |
+
assert role == conv.roles[j % 2], f"{i}"
|
396 |
+
conv.append_message(role, sentence["value"])
|
397 |
+
conversations.append(conv.get_prompt())
|
398 |
+
|
399 |
+
# Tokenize conversations
|
400 |
+
|
401 |
+
if has_image:
|
402 |
+
input_ids = torch.stack(
|
403 |
+
[
|
404 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
405 |
+
for prompt in conversations
|
406 |
+
],
|
407 |
+
dim=0,
|
408 |
+
)
|
409 |
+
else:
|
410 |
+
input_ids = tokenizer(
|
411 |
+
conversations,
|
412 |
+
return_tensors="pt",
|
413 |
+
padding="longest",
|
414 |
+
max_length=tokenizer.model_max_length,
|
415 |
+
truncation=True,
|
416 |
+
).input_ids
|
417 |
+
|
418 |
+
targets = input_ids.clone()
|
419 |
+
|
420 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
421 |
+
|
422 |
+
# Mask targets
|
423 |
+
sep = "[/INST] "
|
424 |
+
for conversation, target in zip(conversations, targets):
|
425 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
426 |
+
|
427 |
+
rounds = conversation.split(conv.sep2)
|
428 |
+
cur_len = 1
|
429 |
+
target[:cur_len] = IGNORE_INDEX
|
430 |
+
for i, rou in enumerate(rounds):
|
431 |
+
if rou == "":
|
432 |
+
break
|
433 |
+
|
434 |
+
parts = rou.split(sep)
|
435 |
+
if len(parts) != 2:
|
436 |
+
break
|
437 |
+
parts[0] += sep
|
438 |
+
|
439 |
+
if has_image:
|
440 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
441 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
442 |
+
else:
|
443 |
+
round_len = len(tokenizer(rou).input_ids)
|
444 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
445 |
+
|
446 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
447 |
+
|
448 |
+
cur_len += round_len
|
449 |
+
target[cur_len:] = IGNORE_INDEX
|
450 |
+
|
451 |
+
if cur_len < tokenizer.model_max_length:
|
452 |
+
if cur_len != total_len:
|
453 |
+
target[:] = IGNORE_INDEX
|
454 |
+
print(
|
455 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
456 |
+
f" (ignored)"
|
457 |
+
)
|
458 |
+
|
459 |
+
return dict(
|
460 |
+
input_ids=input_ids,
|
461 |
+
labels=targets,
|
462 |
+
)
|
463 |
+
|
464 |
+
|
465 |
+
def preprocess_v1(
|
466 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
467 |
+
sources,
|
468 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
469 |
+
has_image: bool = False,
|
470 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
471 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
472 |
+
) -> Dict:
|
473 |
+
conv = conversation_lib.default_conversation.copy()
|
474 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
475 |
+
|
476 |
+
# Apply prompt templates
|
477 |
+
conversations = []
|
478 |
+
for i, source in enumerate(sources):
|
479 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
480 |
+
# Skip the first one if it is not from human
|
481 |
+
source = source[1:]
|
482 |
+
|
483 |
+
conv.messages = []
|
484 |
+
for j, sentence in enumerate(source):
|
485 |
+
role = roles[sentence["from"]]
|
486 |
+
assert role == conv.roles[j % 2], f"{i}"
|
487 |
+
conv.append_message(role, sentence["value"])
|
488 |
+
conversations.append(conv.get_prompt())
|
489 |
+
|
490 |
+
# Tokenize conversations
|
491 |
+
|
492 |
+
if has_image:
|
493 |
+
input_ids = torch.stack(
|
494 |
+
[
|
495 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
496 |
+
for prompt in conversations
|
497 |
+
],
|
498 |
+
dim=0,
|
499 |
+
)
|
500 |
+
else:
|
501 |
+
input_ids = tokenizer(
|
502 |
+
conversations,
|
503 |
+
return_tensors="pt",
|
504 |
+
padding="longest",
|
505 |
+
max_length=tokenizer.model_max_length,
|
506 |
+
truncation=True,
|
507 |
+
).input_ids
|
508 |
+
|
509 |
+
targets = input_ids.clone()
|
510 |
+
|
511 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
512 |
+
|
513 |
+
# Mask targets
|
514 |
+
sep = conv.sep + conv.roles[1] + ": "
|
515 |
+
for conversation, target in zip(conversations, targets):
|
516 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
517 |
+
|
518 |
+
rounds = conversation.split(conv.sep2)
|
519 |
+
cur_len = 1
|
520 |
+
target[:cur_len] = IGNORE_INDEX
|
521 |
+
for i, rou in enumerate(rounds):
|
522 |
+
if rou == "":
|
523 |
+
break
|
524 |
+
|
525 |
+
parts = rou.split(sep)
|
526 |
+
if len(parts) != 2:
|
527 |
+
break
|
528 |
+
parts[0] += sep
|
529 |
+
|
530 |
+
if has_image:
|
531 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
532 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
533 |
+
else:
|
534 |
+
round_len = len(tokenizer(rou).input_ids)
|
535 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
536 |
+
# pyre-fixme
|
537 |
+
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
538 |
+
round_len -= 1
|
539 |
+
instruction_len -= 1
|
540 |
+
|
541 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
542 |
+
|
543 |
+
cur_len += round_len
|
544 |
+
target[cur_len:] = IGNORE_INDEX
|
545 |
+
|
546 |
+
if cur_len < tokenizer.model_max_length:
|
547 |
+
if cur_len != total_len:
|
548 |
+
target[:] = IGNORE_INDEX
|
549 |
+
print(
|
550 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
551 |
+
f" (ignored)"
|
552 |
+
)
|
553 |
+
|
554 |
+
return dict(
|
555 |
+
input_ids=input_ids,
|
556 |
+
labels=targets,
|
557 |
+
)
|
558 |
+
|
559 |
+
|
560 |
+
# pyre-fixme[3]: Return type must be annotated.
|
561 |
+
def tokenizer_image_token(
|
562 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
563 |
+
prompt,
|
564 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
565 |
+
tokenizer,
|
566 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
567 |
+
image_token_index=IMAGE_TOKEN_INDEX,
|
568 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
569 |
+
return_tensors=None,
|
570 |
+
):
|
571 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
572 |
+
|
573 |
+
# pyre-fixme[3]: Return type must be annotated.
|
574 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
575 |
+
def insert_separator(X, sep):
|
576 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
577 |
+
|
578 |
+
input_ids = []
|
579 |
+
offset = 0
|
580 |
+
if (
|
581 |
+
len(prompt_chunks) > 0
|
582 |
+
and len(prompt_chunks[0]) > 0
|
583 |
+
and prompt_chunks[0][0] == tokenizer.bos_token_id
|
584 |
+
):
|
585 |
+
offset = 1
|
586 |
+
input_ids.append(prompt_chunks[0][0])
|
587 |
+
|
588 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
589 |
+
input_ids.extend(x[offset:])
|
590 |
+
|
591 |
+
if return_tensors is not None:
|
592 |
+
if return_tensors == "pt":
|
593 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
594 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
595 |
+
return input_ids
|
596 |
+
|
597 |
+
|
598 |
+
# pyre-fixme[3]: Return type must be annotated.
|
599 |
+
def tokenizer_image_token_llama3(
|
600 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
601 |
+
prompt,
|
602 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
603 |
+
tokenizer,
|
604 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
605 |
+
image_token_index=IMAGE_TOKEN_INDEX,
|
606 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
607 |
+
return_tensors=None,
|
608 |
+
):
|
609 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
610 |
+
|
611 |
+
# pyre-fixme[3]: Return type must be annotated.
|
612 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
613 |
+
def insert_separator(X, sep):
|
614 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
615 |
+
|
616 |
+
input_ids = []
|
617 |
+
for x in insert_separator(prompt_chunks, [image_token_index]):
|
618 |
+
input_ids.extend(x)
|
619 |
+
|
620 |
+
if return_tensors is not None:
|
621 |
+
if return_tensors == "pt":
|
622 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
623 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
624 |
+
return input_ids
|
625 |
+
|
626 |
+
|
627 |
+
def preprocess_qwen(
|
628 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
629 |
+
sources,
|
630 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
631 |
+
has_image: bool = False,
|
632 |
+
system_message: str = "You are a helpful assistant.",
|
633 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
634 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
635 |
+
) -> Dict:
|
636 |
+
# roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
|
637 |
+
roles = {"human": "user", "gpt": "assistant"}
|
638 |
+
|
639 |
+
# Add image tokens to tokenizer as a special tokens
|
640 |
+
# Use a deepcopy of tokenizer so that we don't modify on the tokenizer
|
641 |
+
tokenizer = copy.deepcopy(tokenizer)
|
642 |
+
# When there is actually an image, we add the image tokens as a special token
|
643 |
+
if has_image:
|
644 |
+
tokenizer.add_tokens(["<image>"], special_tokens=True)
|
645 |
+
|
646 |
+
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
|
647 |
+
im_start, im_end = tokenizer.additional_special_tokens_ids
|
648 |
+
# unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
|
649 |
+
unmask_tokens_idx = [198, im_start, im_end]
|
650 |
+
nl_tokens = tokenizer("\n").input_ids
|
651 |
+
|
652 |
+
# Reset Qwen chat templates so that it won't include system message every time we apply
|
653 |
+
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
654 |
+
tokenizer.chat_template = chat_template
|
655 |
+
|
656 |
+
# _system = tokenizer("system").input_ids + nl_tokens
|
657 |
+
# _user = tokenizer("user").input_ids + nl_tokens
|
658 |
+
# _assistant = tokenizer("assistant").input_ids + nl_tokens
|
659 |
+
|
660 |
+
# Apply prompt templates
|
661 |
+
input_ids, targets = [], []
|
662 |
+
for i, source in enumerate(sources):
|
663 |
+
if roles[source[0]["from"]] != roles["human"]:
|
664 |
+
source = source[1:]
|
665 |
+
|
666 |
+
input_id, target = [], []
|
667 |
+
|
668 |
+
# New version, use apply chat template
|
669 |
+
# Build system message for each sentence
|
670 |
+
input_id += tokenizer.apply_chat_template(
|
671 |
+
[{"role": "system", "content": system_message}]
|
672 |
+
)
|
673 |
+
target += [IGNORE_INDEX] * len(input_id)
|
674 |
+
|
675 |
+
for conv in source:
|
676 |
+
# Make sure llava data can load
|
677 |
+
try:
|
678 |
+
role = conv["role"]
|
679 |
+
content = conv["content"]
|
680 |
+
except:
|
681 |
+
role = conv["from"]
|
682 |
+
content = conv["value"]
|
683 |
+
|
684 |
+
role = roles.get(role, role)
|
685 |
+
|
686 |
+
conv = [{"role": role, "content": content}]
|
687 |
+
encode_id = tokenizer.apply_chat_template(conv)
|
688 |
+
input_id += encode_id
|
689 |
+
if role in ["user", "system"]:
|
690 |
+
target += [IGNORE_INDEX] * len(encode_id)
|
691 |
+
else:
|
692 |
+
target += encode_id
|
693 |
+
|
694 |
+
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
|
695 |
+
for idx, encode_id in enumerate(input_id):
|
696 |
+
if encode_id in unmask_tokens_idx:
|
697 |
+
target[idx] = encode_id
|
698 |
+
if encode_id == image_token_index:
|
699 |
+
input_id[idx] = IMAGE_TOKEN_INDEX
|
700 |
+
input_ids.append(input_id)
|
701 |
+
targets.append(target)
|
702 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
703 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
704 |
+
|
705 |
+
return dict(
|
706 |
+
input_ids=input_ids, # tensor(bs x seq_len)
|
707 |
+
labels=targets, # tensor(bs x seq_len)
|
708 |
+
)
|
709 |
+
|
710 |
+
|
711 |
+
def preprocess_llama3(
|
712 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
713 |
+
sources,
|
714 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
715 |
+
has_image: bool = False,
|
716 |
+
system_message: str = "You are a helpful assistant.",
|
717 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
718 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
719 |
+
) -> Dict:
|
720 |
+
# roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
|
721 |
+
roles = {"human": "user", "gpt": "assistant"}
|
722 |
+
|
723 |
+
# Add image tokens to tokenizer as a special tokens
|
724 |
+
# Use a deepcopy of tokenizer so that we don't modify on the tokenizer
|
725 |
+
tokenizer = copy.deepcopy(tokenizer)
|
726 |
+
# When there is actually an image, we add the image tokens as a special token
|
727 |
+
if has_image:
|
728 |
+
tokenizer.add_tokens(["<image>"], special_tokens=True)
|
729 |
+
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
|
730 |
+
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
|
731 |
+
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
|
732 |
+
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
|
733 |
+
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
734 |
+
|
735 |
+
unmask_tokens = [
|
736 |
+
"<|begin_of_text|>",
|
737 |
+
"<|start_header_id|>",
|
738 |
+
"<|end_header_id|>",
|
739 |
+
"<|eot_id|>",
|
740 |
+
"\n\n",
|
741 |
+
]
|
742 |
+
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
|
743 |
+
|
744 |
+
# After update, calling tokenizer of llama3 will
|
745 |
+
# auto add bos id for the tokens. ヽ(`⌒´)ノ
|
746 |
+
# pyre-fixme[53]: Captured variable `bos_token_id` is not annotated.
|
747 |
+
# pyre-fixme[3]: Return type must be annotated.
|
748 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
749 |
+
def safe_tokenizer_llama3(text):
|
750 |
+
input_ids = tokenizer(text).input_ids
|
751 |
+
if input_ids[0] == bos_token_id:
|
752 |
+
input_ids = input_ids[1:]
|
753 |
+
return input_ids
|
754 |
+
|
755 |
+
nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
|
756 |
+
|
757 |
+
# chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{%- endif %}"
|
758 |
+
chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
|
759 |
+
tokenizer.chat_template = chat_template
|
760 |
+
|
761 |
+
# Apply prompt templates
|
762 |
+
input_ids, targets = [], []
|
763 |
+
for i, source in enumerate(sources):
|
764 |
+
if roles[source[0]["from"]] != roles["human"]:
|
765 |
+
source = source[1:]
|
766 |
+
|
767 |
+
input_id, target = [], []
|
768 |
+
|
769 |
+
# New version, use apply chat template
|
770 |
+
# Build system message for each sentence
|
771 |
+
input_id += tokenizer.apply_chat_template(
|
772 |
+
[{"role": "system", "content": system_message}]
|
773 |
+
# pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got `slice`.
|
774 |
+
)[:-4]
|
775 |
+
|
776 |
+
target += [IGNORE_INDEX] * len(input_id)
|
777 |
+
|
778 |
+
for conv in source:
|
779 |
+
# Make sure llava data can load
|
780 |
+
try:
|
781 |
+
role = conv["role"]
|
782 |
+
content = conv["content"]
|
783 |
+
except:
|
784 |
+
role = conv["from"]
|
785 |
+
content = conv["value"]
|
786 |
+
|
787 |
+
role = roles.get(role, role)
|
788 |
+
|
789 |
+
conv = [{"role": role, "content": content}]
|
790 |
+
# First is bos token we don't need here
|
791 |
+
# pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got
|
792 |
+
# `slice`.
|
793 |
+
encode_id = tokenizer.apply_chat_template(conv)[1:-4]
|
794 |
+
input_id += encode_id
|
795 |
+
if role in ["user", "system"]:
|
796 |
+
target += [IGNORE_INDEX] * len(encode_id)
|
797 |
+
else:
|
798 |
+
target += encode_id
|
799 |
+
|
800 |
+
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
|
801 |
+
for idx, encode_id in enumerate(input_id):
|
802 |
+
if encode_id in unmask_tokens_idx:
|
803 |
+
target[idx] = encode_id
|
804 |
+
if encode_id == image_token_index:
|
805 |
+
input_id[idx] = IMAGE_TOKEN_INDEX
|
806 |
+
input_ids.append(input_id)
|
807 |
+
targets.append(target)
|
808 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
809 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
810 |
+
|
811 |
+
print("input_ids", input_ids, flush=True)
|
812 |
+
print("targets", targets, flush=True)
|
813 |
+
|
814 |
+
return dict(
|
815 |
+
input_ids=input_ids, # tensor(bs x seq_len)
|
816 |
+
labels=targets, # tensor(bs x seq_len)
|
817 |
+
)
|
818 |
+
|
819 |
+
|
820 |
+
def preprocess_llama_3_1(
|
821 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
822 |
+
sources,
|
823 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
824 |
+
has_image: bool = False,
|
825 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
826 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
827 |
+
) -> Dict:
|
828 |
+
conv = conversation_lib.default_conversation.copy()
|
829 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
830 |
+
|
831 |
+
# Apply prompt templates
|
832 |
+
conversations = []
|
833 |
+
for i, source in enumerate(sources):
|
834 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
835 |
+
# Skip the first one if it is not from human
|
836 |
+
source = source[1:]
|
837 |
+
|
838 |
+
conv.messages = []
|
839 |
+
for j, sentence in enumerate(source):
|
840 |
+
if sentence["from"] == "Answer":
|
841 |
+
sentence["from"] = "gpt" # data bug
|
842 |
+
role = roles[sentence["from"]]
|
843 |
+
# assert role == conv.roles[j % 2], f"{i}"
|
844 |
+
conv.append_message(role, sentence["value"])
|
845 |
+
conversations.append(conv.get_prompt())
|
846 |
+
|
847 |
+
# Tokenize conversations
|
848 |
+
|
849 |
+
if has_image:
|
850 |
+
input_ids = torch.stack(
|
851 |
+
[
|
852 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
853 |
+
for prompt in conversations
|
854 |
+
],
|
855 |
+
dim=0,
|
856 |
+
)
|
857 |
+
else:
|
858 |
+
input_ids = tokenizer(
|
859 |
+
conversations,
|
860 |
+
return_tensors="pt",
|
861 |
+
padding="longest",
|
862 |
+
max_length=tokenizer.model_max_length,
|
863 |
+
truncation=True,
|
864 |
+
).input_ids
|
865 |
+
|
866 |
+
# remove the first bos token
|
867 |
+
if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
|
868 |
+
input_ids = input_ids[:, 1:]
|
869 |
+
targets = input_ids.clone()
|
870 |
+
|
871 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_1
|
872 |
+
|
873 |
+
# Mask targets
|
874 |
+
sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
|
875 |
+
# sep = conv.sep + conv.roles[1] + ": "
|
876 |
+
for conversation, target in zip(conversations, targets):
|
877 |
+
total_len = int(target.shape[0])
|
878 |
+
|
879 |
+
rounds = conversation.split(conv.tokenizer.eos_token)
|
880 |
+
rounds = [rounds[0]] + [
|
881 |
+
rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
|
882 |
+
]
|
883 |
+
|
884 |
+
cur_len = 1
|
885 |
+
target[:cur_len] = IGNORE_INDEX
|
886 |
+
for i, rou in enumerate(rounds):
|
887 |
+
if rou == "":
|
888 |
+
break
|
889 |
+
|
890 |
+
parts = rou.split(sep)
|
891 |
+
if len(parts) != 2 and i != 0:
|
892 |
+
break
|
893 |
+
|
894 |
+
if i == 0:
|
895 |
+
round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
|
896 |
+
instruction_len = len(
|
897 |
+
tokenizer(rou, add_special_tokens=False).input_ids
|
898 |
+
)
|
899 |
+
|
900 |
+
else:
|
901 |
+
parts[0] += sep
|
902 |
+
if has_image:
|
903 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
|
904 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
905 |
+
else:
|
906 |
+
round_len = len(tokenizer(rou).input_ids) + 1
|
907 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
908 |
+
|
909 |
+
# if i > 0: round_len += 1
|
910 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
911 |
+
cur_len += round_len
|
912 |
+
|
913 |
+
target[cur_len:] = IGNORE_INDEX
|
914 |
+
cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
|
915 |
+
|
916 |
+
# if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
|
917 |
+
if cur_len < tokenizer.model_max_length:
|
918 |
+
if cur_len != total_len:
|
919 |
+
target[:] = IGNORE_INDEX
|
920 |
+
print(
|
921 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
922 |
+
f" (ignored)"
|
923 |
+
)
|
924 |
+
|
925 |
+
return dict(
|
926 |
+
input_ids=input_ids,
|
927 |
+
labels=targets,
|
928 |
+
)
|
929 |
+
|
930 |
+
|
931 |
+
def preprocess_llama_3_2(
|
932 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
933 |
+
sources,
|
934 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
935 |
+
has_image: bool = False,
|
936 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
937 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
938 |
+
) -> Dict:
|
939 |
+
conv = conversation_lib.default_conversation.copy()
|
940 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
941 |
+
|
942 |
+
# Apply prompt templates
|
943 |
+
conversations = []
|
944 |
+
for i, source in enumerate(sources):
|
945 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
946 |
+
# Skip the first one if it is not from human
|
947 |
+
source = source[1:]
|
948 |
+
|
949 |
+
conv.messages = []
|
950 |
+
for j, sentence in enumerate(source):
|
951 |
+
role = roles[sentence["from"]]
|
952 |
+
assert role == conv.roles[j % 2], f"{i}"
|
953 |
+
conv.append_message(role, sentence["value"])
|
954 |
+
conversations.append(conv.get_prompt())
|
955 |
+
|
956 |
+
# Tokenize conversations
|
957 |
+
|
958 |
+
if has_image:
|
959 |
+
input_ids = torch.stack(
|
960 |
+
[
|
961 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
962 |
+
for prompt in conversations
|
963 |
+
],
|
964 |
+
dim=0,
|
965 |
+
)
|
966 |
+
else:
|
967 |
+
input_ids = tokenizer(
|
968 |
+
conversations,
|
969 |
+
return_tensors="pt",
|
970 |
+
padding="longest",
|
971 |
+
max_length=tokenizer.model_max_length,
|
972 |
+
truncation=True,
|
973 |
+
).input_ids
|
974 |
+
|
975 |
+
# remove the first bos token
|
976 |
+
if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
|
977 |
+
input_ids = input_ids[:, 1:]
|
978 |
+
targets = input_ids.clone()
|
979 |
+
|
980 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_2
|
981 |
+
|
982 |
+
# Mask targets
|
983 |
+
sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
|
984 |
+
# sep = conv.sep + conv.roles[1] + ": "
|
985 |
+
for conversation, target in zip(conversations, targets):
|
986 |
+
total_len = int(target.shape[0])
|
987 |
+
|
988 |
+
rounds = conversation.split(conv.tokenizer.eos_token)
|
989 |
+
rounds = [rounds[0]] + [
|
990 |
+
rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
|
991 |
+
]
|
992 |
+
|
993 |
+
cur_len = 1
|
994 |
+
target[:cur_len] = IGNORE_INDEX
|
995 |
+
for i, rou in enumerate(rounds):
|
996 |
+
if rou == "":
|
997 |
+
break
|
998 |
+
|
999 |
+
parts = rou.split(sep)
|
1000 |
+
if len(parts) != 2 and i != 0:
|
1001 |
+
break
|
1002 |
+
|
1003 |
+
if i == 0:
|
1004 |
+
round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
|
1005 |
+
instruction_len = len(
|
1006 |
+
tokenizer(rou, add_special_tokens=False).input_ids
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
else:
|
1010 |
+
parts[0] += sep
|
1011 |
+
if has_image:
|
1012 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
|
1013 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
1014 |
+
else:
|
1015 |
+
round_len = len(tokenizer(rou).input_ids) + 1
|
1016 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
1017 |
+
|
1018 |
+
# if i > 0: round_len += 1
|
1019 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
1020 |
+
cur_len += round_len
|
1021 |
+
|
1022 |
+
target[cur_len:] = IGNORE_INDEX
|
1023 |
+
cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
|
1024 |
+
|
1025 |
+
# if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
|
1026 |
+
if cur_len < tokenizer.model_max_length:
|
1027 |
+
if cur_len != total_len:
|
1028 |
+
target[:] = IGNORE_INDEX
|
1029 |
+
print(
|
1030 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
1031 |
+
f" (ignored)"
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
return dict(
|
1035 |
+
input_ids=input_ids,
|
1036 |
+
labels=targets,
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
|
1040 |
+
def preprocess_phi3(
|
1041 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1042 |
+
sources,
|
1043 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1044 |
+
has_image: bool = False,
|
1045 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1046 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
1047 |
+
) -> Dict:
|
1048 |
+
conv = conversation_lib.conv_templates["phi3"].copy()
|
1049 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
1050 |
+
|
1051 |
+
# Apply prompt templates
|
1052 |
+
conversations = []
|
1053 |
+
for i, source in enumerate(sources):
|
1054 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
1055 |
+
# Skip the first one if it is not from human
|
1056 |
+
source = source[1:]
|
1057 |
+
|
1058 |
+
conv.messages = []
|
1059 |
+
for j, sentence in enumerate(source):
|
1060 |
+
role = roles[sentence["from"]]
|
1061 |
+
assert role == conv.roles[j % 2], f"{i}"
|
1062 |
+
conv.append_message(role, sentence["value"])
|
1063 |
+
conversations.append(conv.get_prompt())
|
1064 |
+
|
1065 |
+
# Tokenize conversations
|
1066 |
+
if has_image:
|
1067 |
+
input_ids = torch.stack(
|
1068 |
+
[
|
1069 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
1070 |
+
for prompt in conversations
|
1071 |
+
],
|
1072 |
+
dim=0,
|
1073 |
+
)
|
1074 |
+
else:
|
1075 |
+
input_ids = tokenizer(
|
1076 |
+
conversations,
|
1077 |
+
return_tensors="pt",
|
1078 |
+
padding="longest",
|
1079 |
+
max_length=tokenizer.model_max_length,
|
1080 |
+
truncation=True,
|
1081 |
+
).input_ids
|
1082 |
+
|
1083 |
+
targets = input_ids.clone()
|
1084 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
1085 |
+
|
1086 |
+
# Mask targets
|
1087 |
+
sep = conv.sep + conv.roles[1]
|
1088 |
+
for conversation, target in zip(conversations, targets):
|
1089 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
1090 |
+
|
1091 |
+
rounds = conversation.split(conv.sep)
|
1092 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
1093 |
+
for conv_idx in range(3, len(rounds), 2):
|
1094 |
+
re_rounds.append(
|
1095 |
+
conv.sep.join(rounds[conv_idx : conv_idx + 2])
|
1096 |
+
) # user + gpt
|
1097 |
+
cur_len = 0
|
1098 |
+
target[:cur_len] = IGNORE_INDEX
|
1099 |
+
for i, rou in enumerate(re_rounds):
|
1100 |
+
if rou == "":
|
1101 |
+
break
|
1102 |
+
|
1103 |
+
parts = rou.split(sep)
|
1104 |
+
if len(parts) != 2:
|
1105 |
+
break
|
1106 |
+
parts[0] += sep
|
1107 |
+
|
1108 |
+
if has_image:
|
1109 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
1110 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
1111 |
+
else:
|
1112 |
+
round_len = len(tokenizer(rou).input_ids)
|
1113 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
1114 |
+
|
1115 |
+
if i == 0:
|
1116 |
+
round_len += 1
|
1117 |
+
instruction_len += 1
|
1118 |
+
else:
|
1119 |
+
round_len -= 2
|
1120 |
+
instruction_len -= 2
|
1121 |
+
|
1122 |
+
if (
|
1123 |
+
i != 0
|
1124 |
+
and getattr(tokenizer, "legacy", False)
|
1125 |
+
and IS_TOKENIZER_GREATER_THAN_0_14
|
1126 |
+
):
|
1127 |
+
round_len += 1
|
1128 |
+
instruction_len += 1
|
1129 |
+
|
1130 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
1131 |
+
|
1132 |
+
cur_len += round_len
|
1133 |
+
target[cur_len:] = IGNORE_INDEX
|
1134 |
+
|
1135 |
+
if cur_len < tokenizer.model_max_length:
|
1136 |
+
if cur_len != total_len:
|
1137 |
+
target[:] = IGNORE_INDEX
|
1138 |
+
print(
|
1139 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
1140 |
+
f" (ignored)"
|
1141 |
+
)
|
1142 |
+
|
1143 |
+
return dict(
|
1144 |
+
input_ids=input_ids,
|
1145 |
+
labels=targets,
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
|
1149 |
+
def preprocess_mpt(
|
1150 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1151 |
+
sources,
|
1152 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1153 |
+
has_image: bool = False,
|
1154 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1155 |
+
) -> Dict:
|
1156 |
+
conv = conversation_lib.default_conversation.copy()
|
1157 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
1158 |
+
|
1159 |
+
# Apply prompt templates
|
1160 |
+
conversations = []
|
1161 |
+
for i, source in enumerate(sources):
|
1162 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
1163 |
+
# Skip the first one if it is not from human
|
1164 |
+
source = source[1:]
|
1165 |
+
|
1166 |
+
conv.messages = []
|
1167 |
+
for j, sentence in enumerate(source):
|
1168 |
+
role = roles[sentence["from"]]
|
1169 |
+
assert role == conv.roles[j % 2], f"{i}"
|
1170 |
+
conv.append_message(role, sentence["value"])
|
1171 |
+
conversations.append(conv.get_prompt())
|
1172 |
+
|
1173 |
+
# Tokenize conversations
|
1174 |
+
if has_image:
|
1175 |
+
input_ids = torch.stack(
|
1176 |
+
[
|
1177 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
1178 |
+
for prompt in conversations
|
1179 |
+
],
|
1180 |
+
dim=0,
|
1181 |
+
)
|
1182 |
+
else:
|
1183 |
+
input_ids = tokenizer(
|
1184 |
+
conversations,
|
1185 |
+
return_tensors="pt",
|
1186 |
+
padding="longest",
|
1187 |
+
max_length=tokenizer.model_max_length,
|
1188 |
+
truncation=True,
|
1189 |
+
).input_ids
|
1190 |
+
|
1191 |
+
targets = input_ids.clone()
|
1192 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
1193 |
+
|
1194 |
+
# Mask targets
|
1195 |
+
sep = conv.sep + conv.roles[1]
|
1196 |
+
for conversation, target in zip(conversations, targets):
|
1197 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
1198 |
+
|
1199 |
+
rounds = conversation.split(conv.sep)
|
1200 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
1201 |
+
for conv_idx in range(3, len(rounds), 2):
|
1202 |
+
re_rounds.append(
|
1203 |
+
conv.sep.join(rounds[conv_idx : conv_idx + 2])
|
1204 |
+
) # user + gpt
|
1205 |
+
cur_len = 0
|
1206 |
+
target[:cur_len] = IGNORE_INDEX
|
1207 |
+
for i, rou in enumerate(re_rounds):
|
1208 |
+
if rou == "":
|
1209 |
+
break
|
1210 |
+
|
1211 |
+
parts = rou.split(sep)
|
1212 |
+
if len(parts) != 2:
|
1213 |
+
break
|
1214 |
+
parts[0] += sep
|
1215 |
+
if has_image:
|
1216 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
1217 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
1218 |
+
else:
|
1219 |
+
round_len = len(tokenizer(rou).input_ids)
|
1220 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
1221 |
+
|
1222 |
+
if (
|
1223 |
+
i != 0
|
1224 |
+
and getattr(tokenizer, "legacy", False)
|
1225 |
+
and IS_TOKENIZER_GREATER_THAN_0_14
|
1226 |
+
):
|
1227 |
+
round_len += 1
|
1228 |
+
instruction_len += 1
|
1229 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
1230 |
+
|
1231 |
+
cur_len += round_len
|
1232 |
+
target[cur_len:] = IGNORE_INDEX
|
1233 |
+
|
1234 |
+
if cur_len < tokenizer.model_max_length:
|
1235 |
+
if cur_len != total_len:
|
1236 |
+
target[:] = IGNORE_INDEX
|
1237 |
+
print(
|
1238 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
1239 |
+
f" (ignored)"
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
return dict(
|
1243 |
+
input_ids=input_ids,
|
1244 |
+
labels=targets,
|
1245 |
+
)
|
1246 |
+
|
1247 |
+
|
1248 |
+
def preprocess_plain(
|
1249 |
+
sources: Sequence[str],
|
1250 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1251 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1252 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
1253 |
+
) -> Dict:
|
1254 |
+
# add end signal and concatenate together
|
1255 |
+
conversations = []
|
1256 |
+
for source in sources:
|
1257 |
+
assert len(source) == 2
|
1258 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
|
1259 |
+
# got `str`.
|
1260 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
|
1261 |
+
# pyre-fixme[16]: `str` has no attribute `__setitem__`.
|
1262 |
+
source[0]["value"] = DEFAULT_IMAGE_TOKEN
|
1263 |
+
conversation = (
|
1264 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
1265 |
+
# but got `str`.
|
1266 |
+
source[0]["value"]
|
1267 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
1268 |
+
# but got `str`.
|
1269 |
+
+ source[1]["value"]
|
1270 |
+
+ conversation_lib.default_conversation.sep
|
1271 |
+
)
|
1272 |
+
conversations.append(conversation)
|
1273 |
+
# tokenize conversations
|
1274 |
+
input_ids = [
|
1275 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
1276 |
+
for prompt in conversations
|
1277 |
+
]
|
1278 |
+
targets = copy.deepcopy(input_ids)
|
1279 |
+
for target, source in zip(targets, sources):
|
1280 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
|
1281 |
+
# got `str`.
|
1282 |
+
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
|
1283 |
+
target[:tokenized_len] = IGNORE_INDEX
|
1284 |
+
|
1285 |
+
return dict(input_ids=input_ids, labels=targets)
|
1286 |
+
|
1287 |
+
|
1288 |
+
def preprocess(
|
1289 |
+
sources: Sequence[str],
|
1290 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1291 |
+
has_image: bool = False,
|
1292 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1293 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
1294 |
+
) -> Dict:
|
1295 |
+
"""
|
1296 |
+
Given a list of sources, each is a conversation list. This transform:
|
1297 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
1298 |
+
2. Concatenate conversations together;
|
1299 |
+
3. Tokenize the concatenated conversation;
|
1300 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
1301 |
+
"""
|
1302 |
+
if (
|
1303 |
+
conversation_lib.default_conversation.sep_style
|
1304 |
+
== conversation_lib.SeparatorStyle.PLAIN
|
1305 |
+
):
|
1306 |
+
return preprocess_plain(sources, tokenizer)
|
1307 |
+
if (
|
1308 |
+
conversation_lib.default_conversation.sep_style
|
1309 |
+
== conversation_lib.SeparatorStyle.LLAMA_2
|
1310 |
+
):
|
1311 |
+
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
|
1312 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
1313 |
+
return preprocess_v1(sources, tokenizer, has_image=has_image)
|
1314 |
+
if conversation_lib.default_conversation.version == "mpt":
|
1315 |
+
return preprocess_mpt(sources, tokenizer, has_image=has_image)
|
1316 |
+
if conversation_lib.default_conversation.version == "llama3":
|
1317 |
+
return preprocess_llama3(sources, tokenizer, has_image=has_image)
|
1318 |
+
if conversation_lib.default_conversation.version == "llama3_1":
|
1319 |
+
return preprocess_llama_3_1(sources, tokenizer, has_image=has_image)
|
1320 |
+
if conversation_lib.default_conversation.version == "llama3_2":
|
1321 |
+
return preprocess_llama_3_2(sources, tokenizer, has_image=has_image)
|
1322 |
+
if conversation_lib.default_conversation.version == "phi3":
|
1323 |
+
return preprocess_phi3(sources, tokenizer, has_image=has_image)
|
1324 |
+
if conversation_lib.default_conversation.version == "qwen":
|
1325 |
+
return preprocess_qwen(sources, tokenizer, has_image=has_image)
|
1326 |
+
# add end signal and concatenate together
|
1327 |
+
conversations = []
|
1328 |
+
for source in sources:
|
1329 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
1330 |
+
conversation = _add_speaker_and_signal(header, source)
|
1331 |
+
conversations.append(conversation)
|
1332 |
+
|
1333 |
+
# tokenize conversations
|
1334 |
+
# pyre-fixme[3]: Return type must be annotated.
|
1335 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1336 |
+
def get_tokenize_len(prompts):
|
1337 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
1338 |
+
|
1339 |
+
if has_image:
|
1340 |
+
input_ids = [
|
1341 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
1342 |
+
for prompt in conversations
|
1343 |
+
]
|
1344 |
+
else:
|
1345 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
1346 |
+
input_ids = conversations_tokenized["input_ids"]
|
1347 |
+
|
1348 |
+
targets = copy.deepcopy(input_ids)
|
1349 |
+
for target, source in zip(targets, sources):
|
1350 |
+
if has_image:
|
1351 |
+
# pyre-fixme[61]: `header` is undefined, or not always defined.
|
1352 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]`
|
1353 |
+
# but got `str`.
|
1354 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
1355 |
+
else:
|
1356 |
+
tokenized_lens = _tokenize_fn(
|
1357 |
+
# pyre-fixme[61]: `header` is undefined, or not always defined.
|
1358 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice,
|
1359 |
+
# SupportsIndex]` but got `str`.
|
1360 |
+
[header] + [s["value"] for s in source],
|
1361 |
+
tokenizer,
|
1362 |
+
)["input_ids_lens"]
|
1363 |
+
# pyre-fixme[6]: For 1st argument expected `Union[slice, SupportsIndex]` but
|
1364 |
+
# got `str`.
|
1365 |
+
speakers = [sentence["from"] for sentence in source]
|
1366 |
+
_mask_targets(target, tokenized_lens, speakers)
|
1367 |
+
|
1368 |
+
return dict(input_ids=input_ids, labels=targets)
|
1369 |
+
|
1370 |
+
|
1371 |
+
class LazySupervisedDataset(Dataset):
|
1372 |
+
"""Dataset for supervised fine-tuning."""
|
1373 |
+
|
1374 |
+
def __init__(
|
1375 |
+
self,
|
1376 |
+
data_path: str,
|
1377 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1378 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1379 |
+
data_args,
|
1380 |
+
) -> None:
|
1381 |
+
super(LazySupervisedDataset, self).__init__()
|
1382 |
+
list_data_dict = json.load(open(data_path, "r"))
|
1383 |
+
|
1384 |
+
self.tokenizer = tokenizer
|
1385 |
+
# pyre-fixme[4]: Attribute must be annotated.
|
1386 |
+
self.list_data_dict = list_data_dict
|
1387 |
+
# pyre-fixme[4]: Attribute must be annotated.
|
1388 |
+
self.data_args = data_args
|
1389 |
+
|
1390 |
+
@property
|
1391 |
+
# pyre-fixme[3]: Return type must be annotated.
|
1392 |
+
def lengths(self):
|
1393 |
+
length_list = []
|
1394 |
+
for sample in self.list_data_dict:
|
1395 |
+
img_tokens = 128 if "image" in sample else 0
|
1396 |
+
length_list.append(
|
1397 |
+
sum(len(conv["value"].split()) for conv in sample["conversations"])
|
1398 |
+
+ img_tokens
|
1399 |
+
)
|
1400 |
+
return length_list
|
1401 |
+
|
1402 |
+
@property
|
1403 |
+
def modality_lengths(self) -> List[int]:
|
1404 |
+
length_list = []
|
1405 |
+
for sample in self.list_data_dict:
|
1406 |
+
cur_len = sum(
|
1407 |
+
len(conv["value"].split()) for conv in sample["conversations"]
|
1408 |
+
)
|
1409 |
+
cur_len = (
|
1410 |
+
cur_len if ("image" in sample) or ("video" in sample) else -cur_len
|
1411 |
+
)
|
1412 |
+
length_list.append(cur_len)
|
1413 |
+
return length_list
|
1414 |
+
|
1415 |
+
def __len__(self) -> int:
|
1416 |
+
return len(self.list_data_dict)
|
1417 |
+
|
1418 |
+
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
|
1419 |
+
sources = self.list_data_dict[i]
|
1420 |
+
if isinstance(i, int):
|
1421 |
+
sources = [sources]
|
1422 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
1423 |
+
has_image = True
|
1424 |
+
if "image" in sources[0]:
|
1425 |
+
image_file = self.list_data_dict[i]["image"]
|
1426 |
+
image_folder = self.data_args.image_folder
|
1427 |
+
processor = self.data_args.image_processor
|
1428 |
+
full_path = os.path.join(image_folder, image_file)
|
1429 |
+
if not os.path.exists(full_path):
|
1430 |
+
print(full_path)
|
1431 |
+
has_image = False
|
1432 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1433 |
+
else:
|
1434 |
+
image = Image.open(full_path).convert("RGB")
|
1435 |
+
if self.data_args.image_aspect_ratio == "sam":
|
1436 |
+
image = np.array(image)[:, :, ::-1]
|
1437 |
+
if self.data_args.image_aspect_ratio == "pad":
|
1438 |
+
# pyre-fixme[3]: Return type must be annotated.
|
1439 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1440 |
+
def expand2square(pil_img, background_color):
|
1441 |
+
width, height = pil_img.size
|
1442 |
+
if width == height:
|
1443 |
+
return pil_img
|
1444 |
+
elif width > height:
|
1445 |
+
result = Image.new(
|
1446 |
+
pil_img.mode, (width, width), background_color
|
1447 |
+
)
|
1448 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
1449 |
+
return result
|
1450 |
+
else:
|
1451 |
+
result = Image.new(
|
1452 |
+
pil_img.mode, (height, height), background_color
|
1453 |
+
)
|
1454 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
1455 |
+
return result
|
1456 |
+
|
1457 |
+
image = expand2square(
|
1458 |
+
image, tuple(int(x * 255) for x in processor.image_mean)
|
1459 |
+
)
|
1460 |
+
image = processor.preprocess(image, return_tensors="pt")[
|
1461 |
+
"pixel_values"
|
1462 |
+
][0]
|
1463 |
+
else:
|
1464 |
+
if self.data_args.image_aspect_ratio != "sam":
|
1465 |
+
image = processor.preprocess(image, return_tensors="pt")[
|
1466 |
+
"pixel_values"
|
1467 |
+
][0]
|
1468 |
+
sources = preprocess_multimodal(
|
1469 |
+
copy.deepcopy([e["conversations"] for e in sources]), self.data_args
|
1470 |
+
)
|
1471 |
+
elif "video" in sources[0]:
|
1472 |
+
video_file = self.list_data_dict[i]["video"]
|
1473 |
+
video_folder = self.data_args.image_folder
|
1474 |
+
if "webvid" in video_folder:
|
1475 |
+
video_file = os.path.join(video_folder, "videos", video_file)
|
1476 |
+
elif "ActivityNet" in video_folder:
|
1477 |
+
video_file = os.path.join(video_folder, "train_val", video_file)
|
1478 |
+
else:
|
1479 |
+
video_file = os.path.join(video_folder, video_file)
|
1480 |
+
if not os.path.exists(video_file):
|
1481 |
+
print("nonexist: {}".format(video_file), flush=True)
|
1482 |
+
for sub_folder in os.listdir(video_folder):
|
1483 |
+
if os.path.isdir(os.path.join(video_folder, sub_folder)):
|
1484 |
+
for sub_sub_folder in os.listdir(
|
1485 |
+
os.path.join(video_folder, sub_folder)
|
1486 |
+
):
|
1487 |
+
print("folder", sub_folder, sub_sub_folder)
|
1488 |
+
has_image = False
|
1489 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1490 |
+
else:
|
1491 |
+
if video_file.endswith(".webm"):
|
1492 |
+
has_image = False
|
1493 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1494 |
+
else:
|
1495 |
+
try:
|
1496 |
+
# if video_file.endswith(".webm"):
|
1497 |
+
# video_webm = VideoFileClip(video_file)
|
1498 |
+
# video_frames = np.array(list(video_webm.iter_frames()))
|
1499 |
+
# sample_fps = round(video_webm.fps / self.data_args.video_fps)
|
1500 |
+
# frame_idx = [i for i in range(0, len(video_frames), sample_fps)]
|
1501 |
+
# video = video_frames[frame_idx]
|
1502 |
+
# else:
|
1503 |
+
vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
|
1504 |
+
sample_fps = round(vr.get_avg_fps() / self.data_args.video_fps)
|
1505 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
1506 |
+
video = vr.get_batch(frame_idx).asnumpy()
|
1507 |
+
if self.data_args.image_aspect_ratio == "sam":
|
1508 |
+
image = video[:, :, :, ::-1][:100]
|
1509 |
+
else:
|
1510 |
+
processor = self.data_args.image_processor
|
1511 |
+
image = processor.preprocess(video, return_tensors="pt")[
|
1512 |
+
"pixel_values"
|
1513 |
+
]
|
1514 |
+
sources = preprocess_multimodal(
|
1515 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
1516 |
+
self.data_args,
|
1517 |
+
)
|
1518 |
+
except:
|
1519 |
+
has_image = False
|
1520 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1521 |
+
else:
|
1522 |
+
has_image = False
|
1523 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1524 |
+
data_dict = preprocess(
|
1525 |
+
# pyre-fixme[6]: For 1st argument expected `Sequence[str]` but got
|
1526 |
+
# `Union[Dict[typing.Any, typing.Any], List[typing.Any]]`.
|
1527 |
+
sources,
|
1528 |
+
self.tokenizer,
|
1529 |
+
has_image=has_image,
|
1530 |
+
)
|
1531 |
+
if isinstance(i, int):
|
1532 |
+
data_dict = dict(
|
1533 |
+
input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
|
1534 |
+
)
|
1535 |
+
|
1536 |
+
# image exist in the data
|
1537 |
+
if has_image:
|
1538 |
+
if "image" in self.list_data_dict[i]:
|
1539 |
+
# pyre-fixme[61]: Local variable `image` is undefined, or not always defined.
|
1540 |
+
data_dict["image"] = image
|
1541 |
+
elif "video" in self.list_data_dict[i]:
|
1542 |
+
# pyre-fixme[61]: Local variable `image` is undefined, or not always defined.
|
1543 |
+
data_dict["image"] = image
|
1544 |
+
elif self.data_args.is_multimodal:
|
1545 |
+
# image does not exist in the data, but the model is multimodal
|
1546 |
+
# crop_size = self.data_args.image_processor.crop_size
|
1547 |
+
# data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
|
1548 |
+
if self.data_args.image_aspect_ratio == "sam":
|
1549 |
+
if "video" in self.list_data_dict[i]:
|
1550 |
+
data_dict["image"] = np.zeros((1, 1024, 1024, 3)).astype(np.uint8)
|
1551 |
+
else:
|
1552 |
+
data_dict["image"] = np.zeros((1024, 1024, 3)).astype(np.uint8)
|
1553 |
+
else:
|
1554 |
+
crop_size = self.data_args.image_processor.crop_size
|
1555 |
+
if "video" in self.list_data_dict[i]:
|
1556 |
+
data_dict["image"] = torch.zeros(
|
1557 |
+
1, 3, crop_size["height"], crop_size["width"]
|
1558 |
+
)
|
1559 |
+
else:
|
1560 |
+
data_dict["image"] = torch.zeros(
|
1561 |
+
3, crop_size["height"], crop_size["width"]
|
1562 |
+
)
|
1563 |
+
|
1564 |
+
if has_image:
|
1565 |
+
if self.data_args.num_points > 0:
|
1566 |
+
if "box" in self.list_data_dict[i]:
|
1567 |
+
x1, y1, x2, y2 = self.list_data_dict[i]["box"]
|
1568 |
+
points = []
|
1569 |
+
x = random.uniform(x1, x2)
|
1570 |
+
y = random.uniform(y1, y2)
|
1571 |
+
points.append(torch.tensor([x, y, 1]))
|
1572 |
+
for _ in range(1, self.data_args.num_points):
|
1573 |
+
points.append(torch.tensor([0, 0, 0]))
|
1574 |
+
points = torch.stack(points, dim=0)
|
1575 |
+
data_dict["point"] = points
|
1576 |
+
else:
|
1577 |
+
if "point" in self.list_data_dict[i]:
|
1578 |
+
points = torch.tensor(self.list_data_dict[i]["point"])
|
1579 |
+
data_dict["point"] = points
|
1580 |
+
else:
|
1581 |
+
points = []
|
1582 |
+
grid = int(np.sqrt(self.data_args.num_points))
|
1583 |
+
height, width = image.shape[0], image.shape[1]
|
1584 |
+
for i in range(grid):
|
1585 |
+
for j in range(grid):
|
1586 |
+
points.append(
|
1587 |
+
torch.tensor(
|
1588 |
+
[
|
1589 |
+
width / grid / 2.0 + i / grid * width,
|
1590 |
+
height / grid / 2.0 + j / grid * height,
|
1591 |
+
1,
|
1592 |
+
]
|
1593 |
+
)
|
1594 |
+
)
|
1595 |
+
points = torch.stack(points, dim=0)
|
1596 |
+
data_dict["point"] = points
|
1597 |
+
elif self.data_args.is_multimodal:
|
1598 |
+
if self.data_args.num_points > 0:
|
1599 |
+
points = []
|
1600 |
+
grid = int(np.sqrt(self.data_args.num_points))
|
1601 |
+
height, width = data_dict["image"].shape[0], data_dict["image"].shape[1]
|
1602 |
+
for i in range(grid):
|
1603 |
+
for j in range(grid):
|
1604 |
+
points.append(
|
1605 |
+
torch.tensor(
|
1606 |
+
[
|
1607 |
+
width / grid / 2.0 + i / grid * width,
|
1608 |
+
height / grid / 2.0 + j / grid * height,
|
1609 |
+
1,
|
1610 |
+
]
|
1611 |
+
)
|
1612 |
+
)
|
1613 |
+
points = torch.stack(points, dim=0)
|
1614 |
+
data_dict["point"] = points
|
1615 |
+
|
1616 |
+
return data_dict
|
1617 |
+
|
1618 |
+
|
1619 |
+
@dataclass
|
1620 |
+
class DataCollatorForSupervisedDataset(object):
|
1621 |
+
"""Collate examples for supervised fine-tuning."""
|
1622 |
+
|
1623 |
+
tokenizer: transformers.PreTrainedTokenizer
|
1624 |
+
|
1625 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1626 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
1627 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
1628 |
+
input_ids, labels = tuple(
|
1629 |
+
[instance[key] for instance in instances] for key in ("input_ids", "labels")
|
1630 |
+
)
|
1631 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
1632 |
+
input_ids,
|
1633 |
+
batch_first=True,
|
1634 |
+
# pyre-fixme[6]: For 3rd argument expected `float` but got `Optional[int]`.
|
1635 |
+
padding_value=self.tokenizer.pad_token_id,
|
1636 |
+
)
|
1637 |
+
labels = torch.nn.utils.rnn.pad_sequence(
|
1638 |
+
labels, batch_first=True, padding_value=IGNORE_INDEX
|
1639 |
+
)
|
1640 |
+
input_ids = input_ids[:, : self.tokenizer.model_max_length]
|
1641 |
+
labels = labels[:, : self.tokenizer.model_max_length]
|
1642 |
+
batch = dict(
|
1643 |
+
input_ids=input_ids,
|
1644 |
+
labels=labels,
|
1645 |
+
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Optional[int]`.
|
1646 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
1647 |
+
)
|
1648 |
+
|
1649 |
+
# if "image" in instances[0]:
|
1650 |
+
# images = [instance["image"] for instance in instances]
|
1651 |
+
# if all(x is not None and x.shape == images[0].shape for x in images):
|
1652 |
+
# if type(images[0]) is torch.Tensor:
|
1653 |
+
# batch["images"] = torch.stack(images)
|
1654 |
+
# else:
|
1655 |
+
#
|
1656 |
+
# batch["images"] = np.stack(images)
|
1657 |
+
# else:
|
1658 |
+
#
|
1659 |
+
# # `List[typing.Any]`.
|
1660 |
+
# batch["images"] = images
|
1661 |
+
|
1662 |
+
if "image" in instances[0]:
|
1663 |
+
images = [instance["image"] for instance in instances]
|
1664 |
+
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `List[typing.Any]`.
|
1665 |
+
batch["images"] = images
|
1666 |
+
|
1667 |
+
if "point" in instances[0]:
|
1668 |
+
points = [instance["point"] for instance in instances]
|
1669 |
+
batch["points"] = torch.stack(points)
|
1670 |
+
|
1671 |
+
return batch
|
1672 |
+
|
1673 |
+
|
1674 |
+
def make_supervised_data_module(
|
1675 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1676 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
1677 |
+
data_args,
|
1678 |
+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
|
1679 |
+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
|
1680 |
+
) -> Dict:
|
1681 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
1682 |
+
train_dataset = LazySupervisedDataset(
|
1683 |
+
tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args
|
1684 |
+
)
|
1685 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
1686 |
+
return dict(
|
1687 |
+
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
|
1688 |
+
)
|
longvu/mm_utils.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import base64
|
3 |
+
import math
|
4 |
+
from io import BytesIO
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from longvu.constants import IMAGE_TOKEN_INDEX
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from transformers import StoppingCriteria
|
11 |
+
|
12 |
+
|
13 |
+
def select_best_resolution(original_size, possible_resolutions):
|
14 |
+
"""
|
15 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
19 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
tuple: The best fit resolution in the format (width, height).
|
23 |
+
"""
|
24 |
+
original_width, original_height = original_size
|
25 |
+
best_fit = None
|
26 |
+
max_effective_resolution = 0
|
27 |
+
min_wasted_resolution = float("inf")
|
28 |
+
|
29 |
+
for width, height in possible_resolutions:
|
30 |
+
scale = min(width / original_width, height / original_height)
|
31 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(
|
32 |
+
original_height * scale
|
33 |
+
)
|
34 |
+
effective_resolution = min(
|
35 |
+
downscaled_width * downscaled_height, original_width * original_height
|
36 |
+
)
|
37 |
+
wasted_resolution = (width * height) - effective_resolution
|
38 |
+
|
39 |
+
if effective_resolution > max_effective_resolution or (
|
40 |
+
effective_resolution == max_effective_resolution
|
41 |
+
and wasted_resolution < min_wasted_resolution
|
42 |
+
):
|
43 |
+
max_effective_resolution = effective_resolution
|
44 |
+
min_wasted_resolution = wasted_resolution
|
45 |
+
best_fit = (width, height)
|
46 |
+
|
47 |
+
return best_fit
|
48 |
+
|
49 |
+
|
50 |
+
def resize_and_pad_image(image, target_resolution):
|
51 |
+
"""
|
52 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
image (PIL.Image.Image): The input image.
|
56 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
PIL.Image.Image: The resized and padded image.
|
60 |
+
"""
|
61 |
+
original_width, original_height = image.size
|
62 |
+
target_width, target_height = target_resolution
|
63 |
+
|
64 |
+
scale_w = target_width / original_width
|
65 |
+
scale_h = target_height / original_height
|
66 |
+
|
67 |
+
if scale_w < scale_h:
|
68 |
+
new_width = target_width
|
69 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
70 |
+
else:
|
71 |
+
new_height = target_height
|
72 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
73 |
+
|
74 |
+
# Resize the image
|
75 |
+
resized_image = image.resize((new_width, new_height))
|
76 |
+
|
77 |
+
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
|
78 |
+
paste_x = (target_width - new_width) // 2
|
79 |
+
paste_y = (target_height - new_height) // 2
|
80 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
81 |
+
|
82 |
+
return new_image
|
83 |
+
|
84 |
+
|
85 |
+
def divide_to_patches(image, patch_size):
|
86 |
+
"""
|
87 |
+
Divides an image into patches of a specified size.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
image (PIL.Image.Image): The input image.
|
91 |
+
patch_size (int): The size of each patch.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
95 |
+
"""
|
96 |
+
patches = []
|
97 |
+
width, height = image.size
|
98 |
+
for i in range(0, height, patch_size):
|
99 |
+
for j in range(0, width, patch_size):
|
100 |
+
box = (j, i, j + patch_size, i + patch_size)
|
101 |
+
patch = image.crop(box)
|
102 |
+
patches.append(patch)
|
103 |
+
|
104 |
+
return patches
|
105 |
+
|
106 |
+
|
107 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
108 |
+
"""
|
109 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
113 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
114 |
+
patch_size (int): The size of each image patch.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
118 |
+
"""
|
119 |
+
if type(grid_pinpoints) is list:
|
120 |
+
possible_resolutions = grid_pinpoints
|
121 |
+
else:
|
122 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
123 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
124 |
+
return width // patch_size, height // patch_size
|
125 |
+
|
126 |
+
|
127 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
128 |
+
"""
|
129 |
+
Process an image with variable resolutions.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
image (PIL.Image.Image): The input image to be processed.
|
133 |
+
processor: The image processor object.
|
134 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
torch.Tensor: A tensor containing the processed image patches.
|
138 |
+
"""
|
139 |
+
if type(grid_pinpoints) is list:
|
140 |
+
possible_resolutions = grid_pinpoints
|
141 |
+
else:
|
142 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
143 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
144 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
145 |
+
|
146 |
+
patches = divide_to_patches(image_padded, processor.crop_size["height"])
|
147 |
+
|
148 |
+
image_original_resize = image.resize(
|
149 |
+
(processor.size["shortest_edge"], processor.size["shortest_edge"])
|
150 |
+
)
|
151 |
+
|
152 |
+
image_patches = [image_original_resize] + patches
|
153 |
+
image_patches = [
|
154 |
+
processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
|
155 |
+
for image_patch in image_patches
|
156 |
+
]
|
157 |
+
return torch.stack(image_patches, dim=0)
|
158 |
+
|
159 |
+
|
160 |
+
def load_image_from_base64(image):
|
161 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
162 |
+
|
163 |
+
|
164 |
+
def expand2square(pil_img, background_color):
|
165 |
+
width, height = pil_img.size
|
166 |
+
if width == height:
|
167 |
+
return pil_img
|
168 |
+
elif width > height:
|
169 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
170 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
171 |
+
return result
|
172 |
+
else:
|
173 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
174 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
175 |
+
return result
|
176 |
+
|
177 |
+
|
178 |
+
# def process_images(images, image_processor, model_cfg):
|
179 |
+
# image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
180 |
+
# new_images = []
|
181 |
+
# if image_aspect_ratio == 'pad':
|
182 |
+
# for image in images:
|
183 |
+
# image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
184 |
+
# image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
185 |
+
# new_images.append(image)
|
186 |
+
# elif image_aspect_ratio == "anyres":
|
187 |
+
# for image in images:
|
188 |
+
# image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
189 |
+
# new_images.append(image)
|
190 |
+
# else:
|
191 |
+
# return image_processor(images, return_tensors='pt')['pixel_values']
|
192 |
+
# if all(x.shape == new_images[0].shape for x in new_images):
|
193 |
+
# new_images = torch.stack(new_images, dim=0)
|
194 |
+
# return new_images
|
195 |
+
|
196 |
+
|
197 |
+
# multiple vision towers
|
198 |
+
def process_images(images, image_processor, model_cfg):
|
199 |
+
processor_aux_list = image_processor
|
200 |
+
new_images_aux_list = []
|
201 |
+
for image in images:
|
202 |
+
image_aux_list = []
|
203 |
+
for processor_aux in processor_aux_list:
|
204 |
+
image_aux = image
|
205 |
+
if hasattr(processor_aux, "image_mean"):
|
206 |
+
try:
|
207 |
+
target_resolution = processor_aux.crop_size["height"]
|
208 |
+
except:
|
209 |
+
target_resolution = processor_aux.size["height"]
|
210 |
+
image_aux = expand2square(
|
211 |
+
image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
|
212 |
+
).resize((target_resolution, target_resolution))
|
213 |
+
image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
|
214 |
+
"pixel_values"
|
215 |
+
][0]
|
216 |
+
image_aux_list.append(image_aux)
|
217 |
+
new_images_aux_list.append(image_aux_list)
|
218 |
+
new_images_aux_list = [
|
219 |
+
list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
|
220 |
+
]
|
221 |
+
new_images_aux_list = [
|
222 |
+
torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
|
223 |
+
]
|
224 |
+
return new_images_aux_list
|
225 |
+
|
226 |
+
|
227 |
+
def tokenizer_image_token(
|
228 |
+
prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
|
229 |
+
):
|
230 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
231 |
+
|
232 |
+
def insert_separator(X, sep):
|
233 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
234 |
+
|
235 |
+
input_ids = []
|
236 |
+
offset = 0
|
237 |
+
if (
|
238 |
+
len(prompt_chunks) > 0
|
239 |
+
and len(prompt_chunks[0]) > 0
|
240 |
+
and prompt_chunks[0][0] == tokenizer.bos_token_id
|
241 |
+
):
|
242 |
+
offset = 1
|
243 |
+
input_ids.append(prompt_chunks[0][0])
|
244 |
+
|
245 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
246 |
+
input_ids.extend(x[offset:])
|
247 |
+
|
248 |
+
if return_tensors is not None:
|
249 |
+
if return_tensors == "pt":
|
250 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
251 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
252 |
+
return input_ids
|
253 |
+
|
254 |
+
|
255 |
+
def tokenizer_image_token_llama3(
|
256 |
+
prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
|
257 |
+
):
|
258 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
259 |
+
|
260 |
+
def insert_separator(X, sep):
|
261 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
262 |
+
|
263 |
+
input_ids = []
|
264 |
+
for x in insert_separator(prompt_chunks, [image_token_index]):
|
265 |
+
input_ids.extend(x)
|
266 |
+
|
267 |
+
if return_tensors is not None:
|
268 |
+
if return_tensors == "pt":
|
269 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
270 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
271 |
+
return input_ids
|
272 |
+
|
273 |
+
|
274 |
+
def get_model_name_from_path(model_path):
|
275 |
+
model_path = model_path.strip("/")
|
276 |
+
model_paths = model_path.split("/")
|
277 |
+
if model_paths[-1].startswith("checkpoint-"):
|
278 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
279 |
+
else:
|
280 |
+
return model_paths[-1]
|
281 |
+
|
282 |
+
|
283 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
284 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
285 |
+
self.keywords = keywords
|
286 |
+
self.keyword_ids = []
|
287 |
+
self.max_keyword_len = 0
|
288 |
+
for keyword in keywords:
|
289 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
290 |
+
if (
|
291 |
+
len(cur_keyword_ids) > 1
|
292 |
+
and cur_keyword_ids[0] == tokenizer.bos_token_id
|
293 |
+
):
|
294 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
295 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
296 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
297 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
298 |
+
self.tokenizer = tokenizer
|
299 |
+
self.start_len = input_ids.shape[1]
|
300 |
+
|
301 |
+
def call_for_batch(
|
302 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
303 |
+
) -> bool:
|
304 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
305 |
+
self.keyword_ids = [
|
306 |
+
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
|
307 |
+
]
|
308 |
+
for keyword_id in self.keyword_ids:
|
309 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0] :]
|
310 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
311 |
+
return True
|
312 |
+
outputs = self.tokenizer.batch_decode(
|
313 |
+
output_ids[:, -offset:], skip_special_tokens=True
|
314 |
+
)[0]
|
315 |
+
for keyword in self.keywords:
|
316 |
+
if keyword in outputs:
|
317 |
+
return True
|
318 |
+
return False
|
319 |
+
|
320 |
+
def __call__(
|
321 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
322 |
+
) -> bool:
|
323 |
+
outputs = []
|
324 |
+
for i in range(output_ids.shape[0]):
|
325 |
+
# pyre-fixme[6]: For 1st argument expected `LongTensor` but got `Tensor`.
|
326 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
327 |
+
return all(outputs)
|
longvu/multimodal_encoder/__pycache__/base_encoder.cpython-310.pyc
ADDED
Binary file (4.33 kB). View file
|
|
longvu/multimodal_encoder/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (1 kB). View file
|
|
longvu/multimodal_encoder/__pycache__/dino_encoder.cpython-310.pyc
ADDED
Binary file (3.67 kB). View file
|
|
longvu/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc
ADDED
Binary file (2.6 kB). View file
|
|
longvu/multimodal_encoder/base_encoder.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ProcessorWrapper:
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
transform,
|
11 |
+
height=378,
|
12 |
+
width=378,
|
13 |
+
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
14 |
+
):
|
15 |
+
self._crop_size = {
|
16 |
+
"height": height,
|
17 |
+
"width": width,
|
18 |
+
}
|
19 |
+
self._transforms = transform
|
20 |
+
# print(transform)
|
21 |
+
self.image_mean = image_mean
|
22 |
+
|
23 |
+
@property
|
24 |
+
def crop_size(self):
|
25 |
+
return self._crop_size
|
26 |
+
|
27 |
+
def preprocess(self, image, return_tensors="pt"):
|
28 |
+
# Ensure image is a PIL Image
|
29 |
+
output = {}
|
30 |
+
output["pixel_values"] = [self._transforms(image)]
|
31 |
+
return output
|
32 |
+
|
33 |
+
|
34 |
+
class BaseVisionTower(nn.Module):
|
35 |
+
def __init__(self, vision_tower_name, args, delay_load=False):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.is_loaded = False
|
39 |
+
self.args = args
|
40 |
+
|
41 |
+
self.vision_tower_name = vision_tower_name
|
42 |
+
self.select_layer = args.mm_vision_select_layer
|
43 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
44 |
+
self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
|
45 |
+
self.delay_load = delay_load
|
46 |
+
|
47 |
+
@abstractmethod
|
48 |
+
def load_model(self, device_map=None):
|
49 |
+
raise NotImplementedError("Subclasses must implement load_model")
|
50 |
+
|
51 |
+
@abstractmethod
|
52 |
+
def _forward(self, images):
|
53 |
+
raise NotImplementedError("Subclasses must implement forward")
|
54 |
+
|
55 |
+
def forward(self, images):
|
56 |
+
if type(images) is list:
|
57 |
+
image_features = [self._forward(image.unsqueeze(0)) for image in images]
|
58 |
+
else:
|
59 |
+
image_features = self._forward(images)
|
60 |
+
|
61 |
+
return image_features
|
62 |
+
|
63 |
+
@property
|
64 |
+
def dummy_feature(self):
|
65 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
66 |
+
|
67 |
+
@property
|
68 |
+
def dtype(self):
|
69 |
+
# Dynamically infer the dtype from the first parameter, if not explicitly specified
|
70 |
+
if hasattr(self.vision_tower, "dtype"):
|
71 |
+
return self.vision_tower.dtype
|
72 |
+
else:
|
73 |
+
params = list(self.vision_tower.parameters())
|
74 |
+
return (
|
75 |
+
params[0].dtype if len(params) > 0 else torch.float32
|
76 |
+
) # Default to torch.float32 if no parameters
|
77 |
+
|
78 |
+
@property
|
79 |
+
def device(self):
|
80 |
+
# Dynamically infer the device from the first parameter, if not explicitly specified
|
81 |
+
if hasattr(self.vision_tower, "device"):
|
82 |
+
return self.vision_tower.device
|
83 |
+
else:
|
84 |
+
params = list(self.vision_tower.parameters())
|
85 |
+
return (
|
86 |
+
params[0].device if len(params) > 0 else torch.device("cpu")
|
87 |
+
) # Default to CPU if no parameters
|
88 |
+
|
89 |
+
@property
|
90 |
+
def config(self):
|
91 |
+
if self.is_loaded:
|
92 |
+
return self.vision_tower.config
|
93 |
+
else:
|
94 |
+
return self.cfg_only
|
95 |
+
|
96 |
+
@property
|
97 |
+
def hidden_size(self):
|
98 |
+
try:
|
99 |
+
return self.config.hidden_size
|
100 |
+
except:
|
101 |
+
return self._hidden_size
|
102 |
+
|
103 |
+
@property
|
104 |
+
def image_size(self): # resolution
|
105 |
+
# return self.config.image_size
|
106 |
+
try:
|
107 |
+
return self.config.image_size
|
108 |
+
except:
|
109 |
+
return self._image_size
|
110 |
+
|
111 |
+
@property
|
112 |
+
def patch_size(self):
|
113 |
+
# return self.config.patch_size
|
114 |
+
try:
|
115 |
+
return self.config.patch_size
|
116 |
+
except:
|
117 |
+
return self._patch_size
|
118 |
+
|
119 |
+
@property
|
120 |
+
def num_patches_per_side(self):
|
121 |
+
if self._interp_size is not None:
|
122 |
+
return int(self._interp_size**0.5)
|
123 |
+
try:
|
124 |
+
return self.image_size // self.patch_size
|
125 |
+
except:
|
126 |
+
return self._num_patches_per_side
|
127 |
+
|
128 |
+
@property
|
129 |
+
def num_patches(self):
|
130 |
+
if self._interp_size is not None:
|
131 |
+
return self._interp_size
|
132 |
+
try:
|
133 |
+
return self.num_patches_per_side**2
|
134 |
+
except:
|
135 |
+
return self._num_patches
|
longvu/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
import copy
|
3 |
+
|
4 |
+
from .dino_encoder import DinoVisionTower
|
5 |
+
from .siglip_encoder import SiglipVisionTower
|
6 |
+
|
7 |
+
|
8 |
+
def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
|
9 |
+
vision_tower_aux_name_list = getattr(
|
10 |
+
vision_tower_cfg,
|
11 |
+
"mm_vision_tower_aux_list",
|
12 |
+
getattr(vision_tower_cfg, "vision_tower_aux_list", None),
|
13 |
+
)
|
14 |
+
vision_tower_aux_token_len_list = getattr(
|
15 |
+
vision_tower_cfg,
|
16 |
+
"mm_vision_tower_aux_token_len_list",
|
17 |
+
getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
|
18 |
+
)
|
19 |
+
vision_tower_aux_list = []
|
20 |
+
for vision_tower_aux_name, vision_tower_aux_token_len in zip(
|
21 |
+
vision_tower_aux_name_list, vision_tower_aux_token_len_list
|
22 |
+
):
|
23 |
+
config = copy.deepcopy(vision_tower_cfg)
|
24 |
+
vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
|
25 |
+
if "siglip" in vision_tower_aux_name.lower():
|
26 |
+
vision_tower_aux_list.append(
|
27 |
+
SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
|
28 |
+
)
|
29 |
+
|
30 |
+
# SSL-based Vision Towers
|
31 |
+
elif "dinov2" in vision_tower_aux_name.lower():
|
32 |
+
vision_tower_aux_list.append(
|
33 |
+
DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
|
34 |
+
)
|
35 |
+
else:
|
36 |
+
raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
|
37 |
+
return vision_tower_aux_list
|
longvu/multimodal_encoder/dino_encoder.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model
|
5 |
+
|
6 |
+
from .base_encoder import BaseVisionTower, ProcessorWrapper
|
7 |
+
|
8 |
+
|
9 |
+
class DinoVisionTower(BaseVisionTower):
|
10 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
11 |
+
super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)
|
12 |
+
|
13 |
+
model_path = "./checkpoints/dinov2-giant"
|
14 |
+
base_model_name, res, interp = model_path, 378, 576
|
15 |
+
self._vision_tower_name = vision_tower
|
16 |
+
self.vision_tower_name = base_model_name
|
17 |
+
self._image_size = res
|
18 |
+
self._interp_size = interp
|
19 |
+
self._patch_size = 14 # default patch size
|
20 |
+
|
21 |
+
if not self.delay_load:
|
22 |
+
self.load_model()
|
23 |
+
else:
|
24 |
+
self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)
|
25 |
+
|
26 |
+
def load_model(self, device_map=None):
|
27 |
+
|
28 |
+
self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
|
29 |
+
"""ValueError: Dinov2Model does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute."""
|
30 |
+
self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]
|
31 |
+
|
32 |
+
_image_size = self.vision_tower.config.image_size
|
33 |
+
if self._image_size is None:
|
34 |
+
self._image_size = _image_size
|
35 |
+
|
36 |
+
# increase shortest edge to prevent edge case crops
|
37 |
+
default_shortest_ratio = 8 / 7 # 224/256
|
38 |
+
# shortest_edge = int(default_shortest_ratio * self._image_size)
|
39 |
+
shortest_edge = self._image_size
|
40 |
+
|
41 |
+
processor = AutoImageProcessor.from_pretrained(
|
42 |
+
self.vision_tower_name,
|
43 |
+
crop_size=dict(height=self._image_size, width=self._image_size),
|
44 |
+
size=dict(shortest_edge=shortest_edge),
|
45 |
+
)
|
46 |
+
self.image_processor = processor
|
47 |
+
|
48 |
+
# Assign the output channels of the projection convolution as the hidden size
|
49 |
+
self._hidden_size = (
|
50 |
+
self.vision_tower.embeddings.patch_embeddings.projection.out_channels
|
51 |
+
)
|
52 |
+
# Assign the first value of the stride of the projection convolution as the patch size
|
53 |
+
self._patch_size = (
|
54 |
+
self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
|
55 |
+
)
|
56 |
+
|
57 |
+
# print(self._hidden_size, self._patch_size)
|
58 |
+
|
59 |
+
self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
|
60 |
+
self.is_loaded = True
|
61 |
+
|
62 |
+
@property
|
63 |
+
def image_size(self):
|
64 |
+
return self._image_size
|
65 |
+
|
66 |
+
def feature_select(self, outputs):
|
67 |
+
sequence_output = outputs[
|
68 |
+
"last_hidden_state"
|
69 |
+
] # batch_size, sequence_length, hidden_size
|
70 |
+
|
71 |
+
if self.select_feature == "cls_patch":
|
72 |
+
image_features = sequence_output
|
73 |
+
elif self.select_feature == "patch":
|
74 |
+
image_features = sequence_output[:, 1:]
|
75 |
+
elif self.select_feature == "cls":
|
76 |
+
image_features = sequence_output[:, 0]
|
77 |
+
else:
|
78 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
79 |
+
return image_features
|
80 |
+
|
81 |
+
def interpolate(self, image_features):
|
82 |
+
if self._interp_size is None:
|
83 |
+
return image_features
|
84 |
+
|
85 |
+
b, num_tokens, dim = image_features.shape
|
86 |
+
|
87 |
+
if num_tokens != self.num_patches:
|
88 |
+
target_h = target_w = int(self._interp_size**0.5)
|
89 |
+
h = w = int(num_tokens**0.5)
|
90 |
+
|
91 |
+
image_features = image_features.view(b, h, w, dim)
|
92 |
+
image_features = image_features.permute(0, 3, 1, 2).contiguous()
|
93 |
+
|
94 |
+
image_features = F.interpolate(
|
95 |
+
image_features.to(torch.float32),
|
96 |
+
size=(target_h, target_w),
|
97 |
+
mode="bilinear",
|
98 |
+
align_corners=False,
|
99 |
+
).to(image_features.dtype)
|
100 |
+
|
101 |
+
# Permute the dimensions back to (b, target_h, target_w, dim)
|
102 |
+
image_features = image_features.permute(0, 2, 3, 1).contiguous()
|
103 |
+
|
104 |
+
# Flatten the spatial dimensions (target_h, target_w) into a single dimension
|
105 |
+
image_features = image_features.flatten(1, 2)
|
106 |
+
|
107 |
+
return image_features
|
108 |
+
|
109 |
+
def _forward(self, images):
|
110 |
+
# logger.warning(f"images shape: {images.shape}")
|
111 |
+
with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
|
112 |
+
image_forward_outs = self.vision_tower.forward(
|
113 |
+
images.to(device=self.device, dtype=self.dtype)
|
114 |
+
)
|
115 |
+
# logger.warning(f"image_forward_outs shape: {image_forward_outs['last_hidden_state'].shape}")
|
116 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
117 |
+
# logger.warning(f"image_features shape: {image_features.shape}")
|
118 |
+
interp_features = self.interpolate(image_features)
|
119 |
+
# logger.warning(f"interp_features shape: {interp_features.shape}")
|
120 |
+
return interp_features
|
121 |
+
|
122 |
+
@property
|
123 |
+
def num_patches_per_side(self):
|
124 |
+
return int(self.num_patches**0.5)
|
125 |
+
|
126 |
+
@property
|
127 |
+
def num_patches(self):
|
128 |
+
if self._interp_size is None:
|
129 |
+
return (self._image_size // self._patch_size) ** 2
|
130 |
+
else:
|
131 |
+
return self._interp_size
|
longvu/multimodal_encoder/drop.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Drop regularization layers."""
|
19 |
+
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
|
23 |
+
class DropPathV(nn.Module):
|
24 |
+
"""Set examples to zero randomly."""
|
25 |
+
|
26 |
+
def __init__(self, p=0.1, inplace=False):
|
27 |
+
super(DropPathV, self).__init__()
|
28 |
+
self.p = p
|
29 |
+
self.inplace = inplace
|
30 |
+
|
31 |
+
def forward(self, input):
|
32 |
+
if not self.training or self.p <= 0:
|
33 |
+
return input
|
34 |
+
keep_p = 1 - self.p
|
35 |
+
shape = (input.shape[0],) + (1,) * (input.dim() - 1)
|
36 |
+
scale = input.new_empty(shape).bernoulli_(keep_p).div_(keep_p)
|
37 |
+
return input.mul_(scale) if self.inplace else input.mul(scale)
|
38 |
+
|
39 |
+
def extra_repr(self):
|
40 |
+
inplace_str = ", inplace" if self.inplace else ""
|
41 |
+
return "p={}{}".format(self.p, inplace_str)
|
longvu/multimodal_encoder/image.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Image utilities."""
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import PIL.Image
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
def im_resize(img, size=None, scale=None, mode="linear"):
|
26 |
+
"""Resize image by the scale or size."""
|
27 |
+
if size is None:
|
28 |
+
if not isinstance(scale, (tuple, list)):
|
29 |
+
scale = (scale, scale)
|
30 |
+
h, w = img.shape[:2]
|
31 |
+
size = int(h * scale[0] + 0.5), int(w * scale[1] + 0.5)
|
32 |
+
else:
|
33 |
+
if not isinstance(size, (tuple, list)):
|
34 |
+
size = (size, size)
|
35 |
+
resize_modes = {"linear": PIL.Image.BILINEAR}
|
36 |
+
from torchvision.transforms import ToPILImage
|
37 |
+
|
38 |
+
to_pil = ToPILImage()
|
39 |
+
img = to_pil(img.to(torch.float32).cpu())
|
40 |
+
# img = PIL.Image.fromarray(img)
|
41 |
+
return np.array(img.resize(size[::-1], resize_modes[mode]))
|
42 |
+
|
43 |
+
|
44 |
+
def im_rescale(img, scales, max_size=0):
|
45 |
+
"""Rescale image to match the detecting scales."""
|
46 |
+
im_shape = img.shape
|
47 |
+
img_list, img_scales = [], []
|
48 |
+
size_min = np.min(im_shape[:2])
|
49 |
+
size_max = np.max(im_shape[:2])
|
50 |
+
for target_size in scales:
|
51 |
+
im_scale = float(target_size) / float(size_min)
|
52 |
+
target_size_max = max_size if max_size > 0 else target_size
|
53 |
+
if np.round(im_scale * size_max) > target_size_max:
|
54 |
+
im_scale = float(target_size_max) / float(size_max)
|
55 |
+
img_list.append(im_resize(img, scale=im_scale))
|
56 |
+
img_scales.append((im_scale, im_scale))
|
57 |
+
return img_list, img_scales
|
58 |
+
|
59 |
+
|
60 |
+
def im_vstack(arrays, fill_value=None, dtype=None, size=None, align=None):
|
61 |
+
"""Stack image arrays in sequence vertically."""
|
62 |
+
if fill_value is None:
|
63 |
+
return np.vstack(arrays)
|
64 |
+
# Compute the max stack shape.
|
65 |
+
max_shape = np.max(np.stack([arr.shape for arr in arrays]), 0)
|
66 |
+
if size is not None and min(size) > 0:
|
67 |
+
max_shape[: len(size)] = size
|
68 |
+
if align is not None and min(align) > 0:
|
69 |
+
align_size = np.ceil(max_shape[: len(align)] / align)
|
70 |
+
max_shape[: len(align)] = align_size.astype("int64") * align
|
71 |
+
# Fill output with the given value.
|
72 |
+
output_dtype = dtype or arrays[0].dtype
|
73 |
+
output_shape = [len(arrays)] + list(max_shape)
|
74 |
+
output = np.empty(output_shape, output_dtype)
|
75 |
+
output[:] = fill_value
|
76 |
+
# Copy arrays.
|
77 |
+
for i, arr in enumerate(arrays):
|
78 |
+
copy_slices = (slice(0, d) for d in arr.shape)
|
79 |
+
output[(i,) + tuple(copy_slices)] = arr
|
80 |
+
return output
|
longvu/multimodal_encoder/logging.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Logging utilities."""
|
19 |
+
|
20 |
+
import inspect
|
21 |
+
import logging as _logging
|
22 |
+
import os
|
23 |
+
import sys as _sys
|
24 |
+
import threading
|
25 |
+
|
26 |
+
|
27 |
+
_logger = None
|
28 |
+
_logger_lock = threading.Lock()
|
29 |
+
|
30 |
+
|
31 |
+
def get_logger():
|
32 |
+
global _logger
|
33 |
+
# Use double-checked locking to avoid taking lock unnecessarily.
|
34 |
+
if _logger:
|
35 |
+
return _logger
|
36 |
+
_logger_lock.acquire()
|
37 |
+
try:
|
38 |
+
if _logger:
|
39 |
+
return _logger
|
40 |
+
logger = _logging.getLogger("tokenize-anything")
|
41 |
+
logger.setLevel("INFO")
|
42 |
+
logger.propagate = False
|
43 |
+
logger._is_root = True
|
44 |
+
if True:
|
45 |
+
# Determine whether we are in an interactive environment.
|
46 |
+
_interactive = False
|
47 |
+
try:
|
48 |
+
# This is only defined in interactive shells.
|
49 |
+
if _sys.ps1:
|
50 |
+
_interactive = True
|
51 |
+
except AttributeError:
|
52 |
+
# Even now, we may be in an interactive shell with `python -i`.
|
53 |
+
_interactive = _sys.flags.interactive
|
54 |
+
# If we are in an interactive environment (like Jupyter), set loglevel
|
55 |
+
# to INFO and pipe the output to stdout.
|
56 |
+
if _interactive:
|
57 |
+
logger.setLevel("INFO")
|
58 |
+
_logging_target = _sys.stdout
|
59 |
+
else:
|
60 |
+
_logging_target = _sys.stderr
|
61 |
+
# Add the output handler.
|
62 |
+
_handler = _logging.StreamHandler(_logging_target)
|
63 |
+
_handler.setFormatter(_logging.Formatter("%(levelname)s %(message)s"))
|
64 |
+
logger.addHandler(_handler)
|
65 |
+
_logger = logger
|
66 |
+
return _logger
|
67 |
+
finally:
|
68 |
+
_logger_lock.release()
|
69 |
+
|
70 |
+
|
71 |
+
def _detailed_msg(msg):
|
72 |
+
file, lineno = inspect.stack()[:3][2][1:3]
|
73 |
+
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
|
74 |
+
|
75 |
+
|
76 |
+
def log(level, msg, *args, **kwargs):
|
77 |
+
get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
|
78 |
+
|
79 |
+
|
80 |
+
def debug(msg, *args, **kwargs):
|
81 |
+
if is_root():
|
82 |
+
get_logger().debug(_detailed_msg(msg), *args, **kwargs)
|
83 |
+
|
84 |
+
|
85 |
+
def error(msg, *args, **kwargs):
|
86 |
+
get_logger().error(_detailed_msg(msg), *args, **kwargs)
|
87 |
+
assert 0
|
88 |
+
|
89 |
+
|
90 |
+
def fatal(msg, *args, **kwargs):
|
91 |
+
get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
|
92 |
+
assert 0
|
93 |
+
|
94 |
+
|
95 |
+
def info(msg, *args, **kwargs):
|
96 |
+
if is_root():
|
97 |
+
get_logger().info(_detailed_msg(msg), *args, **kwargs)
|
98 |
+
|
99 |
+
|
100 |
+
def warning(msg, *args, **kwargs):
|
101 |
+
if is_root():
|
102 |
+
get_logger().warning(_detailed_msg(msg), *args, **kwargs)
|
103 |
+
|
104 |
+
|
105 |
+
def get_verbosity():
|
106 |
+
"""Return how much logging output will be produced."""
|
107 |
+
return get_logger().getEffectiveLevel()
|
108 |
+
|
109 |
+
|
110 |
+
def set_verbosity(v):
|
111 |
+
"""Set the threshold for what messages will be logged."""
|
112 |
+
get_logger().setLevel(v)
|
113 |
+
|
114 |
+
|
115 |
+
def set_formatter(fmt=None, datefmt=None):
|
116 |
+
"""Set the formatter."""
|
117 |
+
handler = _logging.StreamHandler(_sys.stderr)
|
118 |
+
handler.setFormatter(_logging.Formatter(fmt, datefmt))
|
119 |
+
logger = get_logger()
|
120 |
+
logger.removeHandler(logger.handlers[0])
|
121 |
+
logger.addHandler(handler)
|
122 |
+
|
123 |
+
|
124 |
+
def set_root(is_root=True):
|
125 |
+
"""Set logger to the root."""
|
126 |
+
get_logger()._is_root = is_root
|
127 |
+
|
128 |
+
|
129 |
+
def is_root():
|
130 |
+
"""Return logger is the root."""
|
131 |
+
return get_logger()._is_root
|
longvu/multimodal_encoder/loss.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Loss layers."""
|
19 |
+
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
|
23 |
+
def reduce_loss(loss, reduction="mean"):
|
24 |
+
"""Reduce the loss."""
|
25 |
+
if reduction == "mean" or reduction == "sum":
|
26 |
+
return getattr(loss, reduction)()
|
27 |
+
if reduction == "batch_mean":
|
28 |
+
return loss.sum().mul_(1.0 / loss.size(0))
|
29 |
+
return loss
|
30 |
+
|
31 |
+
|
32 |
+
class BinaryFocalLoss(nn.Module):
|
33 |
+
"""Binary focal loss."""
|
34 |
+
|
35 |
+
def __init__(self, alpha=0.25, reduction="none"):
|
36 |
+
super(BinaryFocalLoss, self).__init__()
|
37 |
+
self.alpha = alpha
|
38 |
+
self.reduction = reduction
|
39 |
+
|
40 |
+
def forward(self, input, target):
|
41 |
+
alpha, p = self.alpha, input.sigmoid()
|
42 |
+
neg_alpha, neg_target = 1.0 - alpha, 1.0 - target
|
43 |
+
alpha_weight = target.mul(alpha).add_(neg_target.mul(neg_alpha))
|
44 |
+
focal_weight = (1.0 - p).mul_(target).add_(p.mul(neg_target)).square()
|
45 |
+
loss = nn.functional.binary_cross_entropy_with_logits(
|
46 |
+
input, target, reduction="none"
|
47 |
+
)
|
48 |
+
return reduce_loss(loss * focal_weight.mul_(alpha_weight), self.reduction)
|
49 |
+
|
50 |
+
|
51 |
+
class BinaryDiceLoss(nn.Module):
|
52 |
+
"""Binary dice loss."""
|
53 |
+
|
54 |
+
def __init__(self, eps=1.0, reduction="none"):
|
55 |
+
super(BinaryDiceLoss, self).__init__()
|
56 |
+
self.eps = eps
|
57 |
+
self.reduction = reduction
|
58 |
+
|
59 |
+
def forward(self, input, target):
|
60 |
+
input = input.sigmoid()
|
61 |
+
num = input.mul(target).sum(-1).mul_(2).add_(self.eps)
|
62 |
+
den = input.add(target).sum(-1).add_(self.eps)
|
63 |
+
return reduce_loss(1.0 - num / den, self.reduction)
|
64 |
+
|
65 |
+
|
66 |
+
class CrossEntropyLoss(nn.Module):
|
67 |
+
"""Cross entropy loss with label smoothing."""
|
68 |
+
|
69 |
+
def __init__(self, epsilon=0, reduction="none"):
|
70 |
+
super(CrossEntropyLoss, self).__init__()
|
71 |
+
self.epsilon = epsilon
|
72 |
+
self.reduction = reduction
|
73 |
+
|
74 |
+
def forward_dense(self, input, target):
|
75 |
+
dim, target = input.shape[-1], target.squeeze_()
|
76 |
+
x = nn.functional.log_softmax(input, dim=-1)
|
77 |
+
y = nn.functional.one_hot(target, dim).float()
|
78 |
+
x = (
|
79 |
+
x.permute([0, x.dim() - 1] + list(range(x.dim()))[1:-1])
|
80 |
+
if x.dim() > 2
|
81 |
+
else x
|
82 |
+
)
|
83 |
+
y = (
|
84 |
+
y.permute([0, y.dim() - 1] + list(range(y.dim()))[1:-1])
|
85 |
+
if y.dim() > 2
|
86 |
+
else y
|
87 |
+
)
|
88 |
+
loss = nn.functional.cross_entropy(
|
89 |
+
x, y, reduction="none", label_smoothing=self.epsilon
|
90 |
+
)
|
91 |
+
return reduce_loss(loss, self.reduction)
|
92 |
+
|
93 |
+
def forward(self, input, target):
|
94 |
+
if self.epsilon > 0:
|
95 |
+
return self.forward_dense(input, target)
|
96 |
+
return nn.functional.cross_entropy(input, target, reduction=self.reduction)
|
longvu/multimodal_encoder/registry.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Registry utilities."""
|
19 |
+
|
20 |
+
import collections
|
21 |
+
import functools
|
22 |
+
|
23 |
+
|
24 |
+
class Registry(object):
|
25 |
+
"""Registry class."""
|
26 |
+
|
27 |
+
def __init__(self, name):
|
28 |
+
self.name = name
|
29 |
+
self.registry = collections.OrderedDict()
|
30 |
+
|
31 |
+
def has(self, key):
|
32 |
+
return key in self.registry
|
33 |
+
|
34 |
+
def register(self, name, func=None, **kwargs):
|
35 |
+
def decorated(inner_function):
|
36 |
+
for key in name if isinstance(name, (tuple, list)) else [name]:
|
37 |
+
self.registry[key] = functools.partial(inner_function, **kwargs)
|
38 |
+
return inner_function
|
39 |
+
|
40 |
+
if func is not None:
|
41 |
+
return decorated(func)
|
42 |
+
return decorated
|
43 |
+
|
44 |
+
def get(self, name, default=None):
|
45 |
+
if name is None:
|
46 |
+
return None
|
47 |
+
if not self.has(name):
|
48 |
+
if default is not None:
|
49 |
+
return default
|
50 |
+
raise KeyError("`%s` is not registered in <%s>." % (name, self.name))
|
51 |
+
return self.registry[name]
|
52 |
+
|
53 |
+
def try_get(self, name):
|
54 |
+
if self.has(name):
|
55 |
+
return self.get(name)
|
56 |
+
return None
|
longvu/multimodal_encoder/siglip_encoder.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
|
5 |
+
|
6 |
+
from .base_encoder import BaseVisionTower, ProcessorWrapper
|
7 |
+
|
8 |
+
|
9 |
+
class SiglipVisionTower(BaseVisionTower):
|
10 |
+
def __init__(self, vision_tower_name, args, delay_load=False):
|
11 |
+
super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
|
12 |
+
|
13 |
+
model_path = "./checkpoints/siglip-so400m-patch14-384"
|
14 |
+
base_model_name, res, interp = model_path, 384, 576
|
15 |
+
self.vision_tower_name = base_model_name
|
16 |
+
self._image_size = res if res is not None else 512
|
17 |
+
self._interp_size = interp
|
18 |
+
if not self.delay_load:
|
19 |
+
self.load_model()
|
20 |
+
elif self.unfreeze_mm_vision_tower:
|
21 |
+
self.load_model()
|
22 |
+
else:
|
23 |
+
self._hidden_size = 1152
|
24 |
+
|
25 |
+
def load_model(self, device_map=None):
|
26 |
+
self.vision_model = "siglip"
|
27 |
+
# clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
|
28 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
|
29 |
+
|
30 |
+
# self.vision_tower = clip_model.visual.trunk
|
31 |
+
self.vision_tower.output_tokens = True
|
32 |
+
|
33 |
+
self._hidden_size = self.vision_tower.config.hidden_size
|
34 |
+
self._image_size = self.vision_tower.config.image_size
|
35 |
+
self._patch_size = self.vision_tower.config.patch_size
|
36 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(
|
37 |
+
self.vision_tower_name
|
38 |
+
)
|
39 |
+
|
40 |
+
self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
|
41 |
+
self.is_loaded = True
|
42 |
+
|
43 |
+
def interpolate(self, image_features):
|
44 |
+
if self._interp_size is None:
|
45 |
+
return image_features
|
46 |
+
|
47 |
+
b, num_tokens, dim = image_features.shape
|
48 |
+
|
49 |
+
if num_tokens != self.num_patches:
|
50 |
+
target_h = target_w = int(self._interp_size**0.5)
|
51 |
+
h = w = int(num_tokens**0.5)
|
52 |
+
|
53 |
+
image_features = image_features.view(b, h, w, dim)
|
54 |
+
image_features = image_features.permute(0, 3, 1, 2).contiguous()
|
55 |
+
|
56 |
+
image_features = F.interpolate(
|
57 |
+
image_features.to(torch.float32),
|
58 |
+
size=(target_h, target_w),
|
59 |
+
mode="bilinear",
|
60 |
+
align_corners=False,
|
61 |
+
).to(image_features.dtype)
|
62 |
+
|
63 |
+
# Permute the dimensions back to (b, target_h, target_w, dim)
|
64 |
+
image_features = image_features.permute(0, 2, 3, 1).contiguous()
|
65 |
+
|
66 |
+
# Flatten the spatial dimensions (target_h, target_w) into a single dimension
|
67 |
+
image_features = image_features.flatten(1, 2)
|
68 |
+
|
69 |
+
return image_features
|
70 |
+
|
71 |
+
def _forward(self, images, interpolate_token=576):
|
72 |
+
with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
|
73 |
+
image_features = self.vision_tower.forward(
|
74 |
+
images.to(device=self.device, dtype=self.dtype),
|
75 |
+
output_hidden_states=True,
|
76 |
+
).hidden_states[-1]
|
77 |
+
interp_features = self.interpolate(image_features)
|
78 |
+
return interp_features
|
longvu/multimodal_encoder/utils.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
|
17 |
+
# pyre-unsafe
|
18 |
+
"""Layer utilities."""
|
19 |
+
|
20 |
+
import cv2
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
def init_cross_conv(blocks):
|
26 |
+
"""Initialize convolutional cross attention."""
|
27 |
+
for m in blocks.modules():
|
28 |
+
if isinstance(m, torch.nn.Conv2d):
|
29 |
+
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
30 |
+
for blk in blocks:
|
31 |
+
torch.nn.init.constant_(blk.norm3.weight, 0)
|
32 |
+
|
33 |
+
|
34 |
+
def set_dropout(module, dropout):
|
35 |
+
"""Initialize dropout."""
|
36 |
+
for m in [m for m in module.modules() if isinstance(m, torch.nn.Dropout)]:
|
37 |
+
m.p = dropout
|
38 |
+
|
39 |
+
|
40 |
+
def set_drop_path(blocks, drop_path):
|
41 |
+
"""Initialize drop path."""
|
42 |
+
if not isinstance(blocks, torch.nn.ModuleList):
|
43 |
+
blocks = getattr(blocks, "blocks", getattr(blocks, "layers", None))
|
44 |
+
for i, blk in enumerate(blocks):
|
45 |
+
for m in [m for m in blk.modules() if type(m).__name__ == "DropPath"]:
|
46 |
+
m.p = i * drop_path / (len(blocks) - 1)
|
47 |
+
|
48 |
+
|
49 |
+
def set_sync_batch_norm(module, ddp_group):
|
50 |
+
"""Set data parallelism group for sync batch norm."""
|
51 |
+
for m in module.modules():
|
52 |
+
if isinstance(m, torch.nn.SyncBatchNorm):
|
53 |
+
m.process_group = ddp_group
|
54 |
+
|
55 |
+
|
56 |
+
def resize_pos_embed(weight, out_len):
|
57 |
+
"""Resize position embedding weights."""
|
58 |
+
out_h = out_w = int(out_len**0.5)
|
59 |
+
h = w = int(weight.shape[0] ** 0.5)
|
60 |
+
weight = weight.reshape((h, w, weight.shape[1]))
|
61 |
+
out_weight = [
|
62 |
+
cv2.resize(x, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
|
63 |
+
for x in np.split(weight.astype("float32", copy=False), 4, axis=-1)
|
64 |
+
]
|
65 |
+
out_weight = np.concatenate(out_weight, axis=-1)
|
66 |
+
return out_weight.reshape((-1, weight.shape[-1])).astype(weight.dtype, copy=False)
|
longvu/multimodal_projector/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (2.01 kB). View file
|
|
longvu/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class IdentityMap(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
def forward(self, x, *args, **kwargs):
|
12 |
+
return x
|
13 |
+
|
14 |
+
@property
|
15 |
+
def config(self):
|
16 |
+
return {"mm_projector_type": "identity"}
|
17 |
+
|
18 |
+
|
19 |
+
class SimpleResBlock(nn.Module):
|
20 |
+
def __init__(self, channels):
|
21 |
+
super().__init__()
|
22 |
+
self.pre_norm = nn.LayerNorm(channels)
|
23 |
+
|
24 |
+
self.proj = nn.Sequential(
|
25 |
+
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
|
26 |
+
)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.pre_norm(x)
|
30 |
+
return x + self.proj(x)
|
31 |
+
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
34 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
35 |
+
config.mm_hidden_size = 256
|
36 |
+
|
37 |
+
if projector_type == "linear":
|
38 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
39 |
+
|
40 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
41 |
+
if mlp_gelu_match:
|
42 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
43 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
44 |
+
for _ in range(1, mlp_depth):
|
45 |
+
modules.append(nn.GELU())
|
46 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
47 |
+
return nn.Sequential(*modules)
|
48 |
+
|
49 |
+
if projector_type == "identity":
|
50 |
+
return IdentityMap()
|
51 |
+
|
52 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
longvu/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
from transformers import AutoConfig
|
3 |
+
|
4 |
+
|
5 |
+
def auto_upgrade(config):
|
6 |
+
cfg = AutoConfig.from_pretrained(config)
|
7 |
+
if "llava" in config and "llava" not in cfg.model_type:
|
8 |
+
assert cfg.model_type == "llama"
|
9 |
+
print(
|
10 |
+
"You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
|
11 |
+
)
|
12 |
+
print(
|
13 |
+
"You must upgrade the checkpoint to the new code base (this can be done automatically)."
|
14 |
+
)
|
15 |
+
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
16 |
+
if confirm.lower() in ["y", "yes"]:
|
17 |
+
print("Upgrading checkpoint...")
|
18 |
+
assert len(cfg.architectures) == 1
|
19 |
+
setattr(cfg.__class__, "model_type", "llava")
|
20 |
+
cfg.architectures[0] = "LlavaLlamaForCausalLM"
|
21 |
+
cfg.save_pretrained(config)
|
22 |
+
print("Checkpoint upgraded.")
|
23 |
+
else:
|
24 |
+
print("Checkpoint upgrade aborted.")
|
25 |
+
exit(1)
|
longvu/vision_sampler.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
10 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
11 |
+
"""
|
12 |
+
grid_size: int of the grid height and width
|
13 |
+
return:
|
14 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
15 |
+
"""
|
16 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
17 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
18 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
19 |
+
grid = np.stack(grid, axis=0)
|
20 |
+
|
21 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
22 |
+
|
23 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
24 |
+
if cls_token:
|
25 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
26 |
+
return pos_embed
|
27 |
+
|
28 |
+
|
29 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
30 |
+
assert embed_dim % 2 == 0
|
31 |
+
|
32 |
+
# use half of dimensions to encode grid_h
|
33 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
34 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
35 |
+
|
36 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
37 |
+
return emb
|
38 |
+
|
39 |
+
|
40 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
41 |
+
"""
|
42 |
+
embed_dim: output dimension for each position
|
43 |
+
pos: a list of positions to be encoded: size (M,)
|
44 |
+
out: (M, D)
|
45 |
+
"""
|
46 |
+
assert embed_dim % 2 == 0
|
47 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
48 |
+
omega /= embed_dim / 2.0
|
49 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
50 |
+
|
51 |
+
pos = pos.reshape(-1) # (M,)
|
52 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
53 |
+
|
54 |
+
emb_sin = np.sin(out) # (M, D/2)
|
55 |
+
emb_cos = np.cos(out) # (M, D/2)
|
56 |
+
|
57 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
58 |
+
return emb
|
59 |
+
|
60 |
+
|
61 |
+
class CrossAttention(nn.Module):
|
62 |
+
|
63 |
+
def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
|
64 |
+
super().__init__()
|
65 |
+
self.hidden_dim = hidden_dim
|
66 |
+
self.num_heads = num_heads
|
67 |
+
self.head_dim = self.hidden_dim // self.num_heads
|
68 |
+
|
69 |
+
if (self.head_dim * self.num_heads) != self.hidden_dim:
|
70 |
+
raise ValueError(
|
71 |
+
f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
|
72 |
+
f" and `num_heads`: {self.num_heads})."
|
73 |
+
)
|
74 |
+
|
75 |
+
self.q_proj = nn.Sequential(
|
76 |
+
nn.LayerNorm(q_dim),
|
77 |
+
nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
|
78 |
+
)
|
79 |
+
self.k_proj = nn.Sequential(
|
80 |
+
nn.LayerNorm(kv_dim),
|
81 |
+
nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
|
82 |
+
)
|
83 |
+
self.v_proj = nn.Sequential(
|
84 |
+
nn.LayerNorm(kv_dim),
|
85 |
+
nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
|
86 |
+
)
|
87 |
+
self.o_proj = nn.Linear(
|
88 |
+
self.num_heads * self.head_dim, q_dim, bias=attention_bias
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, vision_latents, queries, attention_mask):
|
92 |
+
|
93 |
+
bsz, q_len, _ = queries.size()
|
94 |
+
bsz, v_len, _ = vision_latents.size()
|
95 |
+
|
96 |
+
query_states = self.q_proj(queries)
|
97 |
+
key_states = self.k_proj(vision_latents)
|
98 |
+
value_states = self.v_proj(vision_latents)
|
99 |
+
|
100 |
+
query_states = query_states.view(
|
101 |
+
bsz, q_len, self.num_heads, self.head_dim
|
102 |
+
).transpose(1, 2)
|
103 |
+
key_states = key_states.view(
|
104 |
+
bsz, v_len, self.num_heads, self.head_dim
|
105 |
+
).transpose(1, 2)
|
106 |
+
value_states = value_states.view(
|
107 |
+
bsz, v_len, self.num_heads, self.head_dim
|
108 |
+
).transpose(1, 2)
|
109 |
+
|
110 |
+
if attention_mask is not None:
|
111 |
+
if attention_mask.size() != (bsz, 1, q_len, v_len):
|
112 |
+
raise ValueError(
|
113 |
+
f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
|
114 |
+
)
|
115 |
+
|
116 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
117 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
118 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
119 |
+
query_states = query_states.contiguous()
|
120 |
+
key_states = key_states.contiguous()
|
121 |
+
value_states = value_states.contiguous()
|
122 |
+
|
123 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
124 |
+
query_states,
|
125 |
+
key_states,
|
126 |
+
value_states,
|
127 |
+
attn_mask=attention_mask,
|
128 |
+
)
|
129 |
+
|
130 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
131 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
|
132 |
+
|
133 |
+
attn_output = self.o_proj(attn_output)
|
134 |
+
|
135 |
+
return attn_output
|
136 |
+
|
137 |
+
|
138 |
+
class AggregationBlock(nn.Module):
|
139 |
+
def __init__(
|
140 |
+
self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
self.hidden_dim = hidden_dim
|
144 |
+
self.num_heads = num_heads
|
145 |
+
self.head_dim = self.hidden_dim // self.num_heads
|
146 |
+
|
147 |
+
if (self.head_dim * self.num_heads) != self.hidden_dim:
|
148 |
+
raise ValueError(
|
149 |
+
f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
|
150 |
+
f" and `num_heads`: {self.num_heads})."
|
151 |
+
)
|
152 |
+
|
153 |
+
self.attention = attention
|
154 |
+
if attention:
|
155 |
+
self.attention_layer = CrossAttention(
|
156 |
+
q_dim, kv_dim, hidden_dim, num_heads, attention_bias
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
self.attention_layer = MLP(kv_dim, q_dim, q_dim)
|
160 |
+
|
161 |
+
def forward(self, vision_latents, queries, attention_mask):
|
162 |
+
if self.attention:
|
163 |
+
queries = self.attention_layer(vision_latents, queries, attention_mask)
|
164 |
+
else:
|
165 |
+
queries = self.attention_layer(vision_latents)
|
166 |
+
|
167 |
+
return queries
|
168 |
+
|
169 |
+
|
170 |
+
class MultiKVCrossAttention(nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.hidden_dim = hidden_dim
|
176 |
+
self.num_heads = num_heads
|
177 |
+
self.head_dim = self.hidden_dim // self.num_heads
|
178 |
+
|
179 |
+
if (self.head_dim * self.num_heads) != self.hidden_dim:
|
180 |
+
raise ValueError(
|
181 |
+
f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
|
182 |
+
f" and `num_heads`: {self.num_heads})."
|
183 |
+
)
|
184 |
+
|
185 |
+
self.q_proj = nn.Sequential(
|
186 |
+
nn.LayerNorm(q_dim),
|
187 |
+
nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
|
188 |
+
)
|
189 |
+
self.num_of_kvs = len(kv_dim_list)
|
190 |
+
for i, kv_dim in enumerate(kv_dim_list):
|
191 |
+
setattr(
|
192 |
+
self,
|
193 |
+
"k_proj_{}".format(i),
|
194 |
+
nn.Sequential(
|
195 |
+
nn.LayerNorm(kv_dim),
|
196 |
+
nn.Linear(
|
197 |
+
kv_dim, self.num_heads * self.head_dim, bias=attention_bias
|
198 |
+
),
|
199 |
+
),
|
200 |
+
)
|
201 |
+
setattr(
|
202 |
+
self,
|
203 |
+
"v_proj_{}".format(i),
|
204 |
+
nn.Sequential(
|
205 |
+
nn.LayerNorm(kv_dim),
|
206 |
+
nn.Linear(
|
207 |
+
kv_dim, self.num_heads * self.head_dim, bias=attention_bias
|
208 |
+
),
|
209 |
+
),
|
210 |
+
)
|
211 |
+
self.o_proj = nn.Linear(
|
212 |
+
self.num_heads * self.head_dim, q_dim, bias=attention_bias
|
213 |
+
)
|
214 |
+
|
215 |
+
def forward(
|
216 |
+
self,
|
217 |
+
queries,
|
218 |
+
*vision_latents_attention_mask_list,
|
219 |
+
):
|
220 |
+
|
221 |
+
vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
|
222 |
+
attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
|
223 |
+
|
224 |
+
bsz, q_len, _ = queries.size()
|
225 |
+
|
226 |
+
query_states = self.q_proj(queries)
|
227 |
+
key_states = torch.cat(
|
228 |
+
[
|
229 |
+
getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
|
230 |
+
for i in range(self.num_of_kvs)
|
231 |
+
],
|
232 |
+
dim=1,
|
233 |
+
)
|
234 |
+
value_states = torch.cat(
|
235 |
+
[
|
236 |
+
getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
|
237 |
+
for i in range(self.num_of_kvs)
|
238 |
+
],
|
239 |
+
dim=1,
|
240 |
+
)
|
241 |
+
|
242 |
+
v_len = key_states.shape[1]
|
243 |
+
|
244 |
+
query_states = query_states.view(
|
245 |
+
bsz, q_len, self.num_heads, self.head_dim
|
246 |
+
).transpose(1, 2)
|
247 |
+
key_states = key_states.view(
|
248 |
+
bsz, v_len, self.num_heads, self.head_dim
|
249 |
+
).transpose(1, 2)
|
250 |
+
value_states = value_states.view(
|
251 |
+
bsz, v_len, self.num_heads, self.head_dim
|
252 |
+
).transpose(1, 2)
|
253 |
+
|
254 |
+
# if kv_weight is not None:
|
255 |
+
# kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
|
256 |
+
|
257 |
+
attention_mask = torch.cat(attention_mask_list, dim=-1)
|
258 |
+
|
259 |
+
if attention_mask is not None:
|
260 |
+
if attention_mask.size() != (bsz, 1, q_len, v_len):
|
261 |
+
raise ValueError(
|
262 |
+
f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
|
263 |
+
)
|
264 |
+
|
265 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
266 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
267 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
268 |
+
query_states = query_states.contiguous()
|
269 |
+
key_states = key_states.contiguous()
|
270 |
+
value_states = value_states.contiguous()
|
271 |
+
|
272 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
273 |
+
query_states,
|
274 |
+
key_states,
|
275 |
+
value_states,
|
276 |
+
attn_mask=attention_mask,
|
277 |
+
)
|
278 |
+
# attn_output = spda(
|
279 |
+
# query_states,
|
280 |
+
# key_states,
|
281 |
+
# value_states,
|
282 |
+
# attn_mask=attention_mask,
|
283 |
+
# additional_score=kv_weight
|
284 |
+
# )
|
285 |
+
|
286 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
287 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
|
288 |
+
|
289 |
+
attn_output = self.o_proj(attn_output)
|
290 |
+
|
291 |
+
return attn_output
|
292 |
+
|
293 |
+
|
294 |
+
class MLP(nn.Module):
|
295 |
+
def __init__(self, d_in, d_hidden, d_out):
|
296 |
+
super().__init__()
|
297 |
+
self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
|
298 |
+
self.act = nn.GELU()
|
299 |
+
self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)
|
300 |
+
|
301 |
+
def forward(self, x):
|
302 |
+
return self.linear_2(self.act(self.linear_1(x)))
|
303 |
+
|
304 |
+
|
305 |
+
class VisionCrossAttentionLayer(nn.Module):
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
q_dim,
|
309 |
+
context_dim,
|
310 |
+
kv_dim_list,
|
311 |
+
kv_size_list,
|
312 |
+
hidden_dim=1024,
|
313 |
+
layer_idx=0,
|
314 |
+
):
|
315 |
+
super().__init__()
|
316 |
+
num_heads = 16
|
317 |
+
self.num_of_kvs = len(kv_dim_list)
|
318 |
+
|
319 |
+
self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
|
320 |
+
self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
|
321 |
+
# if self.num_of_kvs > 1:
|
322 |
+
# self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
|
323 |
+
# self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
|
324 |
+
self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
|
325 |
+
|
326 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
327 |
+
|
328 |
+
self.cross_attn = MultiKVCrossAttention(
|
329 |
+
hidden_dim, kv_dim_list, hidden_dim, num_heads
|
330 |
+
)
|
331 |
+
self.kv_size_list = kv_size_list
|
332 |
+
for i, kv_size in enumerate(kv_size_list):
|
333 |
+
if kv_size > 1:
|
334 |
+
setattr(
|
335 |
+
self,
|
336 |
+
"pos_embed_{}".format(i),
|
337 |
+
nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
|
338 |
+
)
|
339 |
+
# self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)
|
340 |
+
|
341 |
+
def forward(
|
342 |
+
self,
|
343 |
+
queries,
|
344 |
+
context_feature,
|
345 |
+
*vision_latents_attention_mask_list,
|
346 |
+
) -> torch.FloatTensor:
|
347 |
+
|
348 |
+
residual = queries
|
349 |
+
# queries = self.proj_in(queries)
|
350 |
+
context_feature = self.proj_context(context_feature)
|
351 |
+
# queries = queries + context_feature
|
352 |
+
queries = torch.cat([queries, context_feature], -1)
|
353 |
+
|
354 |
+
# if self.num_of_kvs > 1:
|
355 |
+
# kv_weight = self.weight_mlp(queries) # B * 1 * num_tower
|
356 |
+
# kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
|
357 |
+
# kv_weight = kv_weight.softmax(-1)
|
358 |
+
# kv_number_list = [size**2 for size in self.kv_size_list]
|
359 |
+
# kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
|
360 |
+
# else:
|
361 |
+
# kv_weight = None
|
362 |
+
|
363 |
+
queries = self.proj_in(queries)
|
364 |
+
|
365 |
+
vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
|
366 |
+
attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
|
367 |
+
|
368 |
+
attention_mask_list_reshaped = []
|
369 |
+
if attention_mask_list is not None:
|
370 |
+
for attention_mask in attention_mask_list:
|
371 |
+
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
|
372 |
+
attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
|
373 |
+
attention_mask_list_reshaped.append(attention_mask)
|
374 |
+
|
375 |
+
vision_latents_pos_list = []
|
376 |
+
for i, vision_latents in enumerate(vision_latents_list):
|
377 |
+
if vision_latents.shape[1] > 1:
|
378 |
+
vision_latents_pos_list.append(
|
379 |
+
vision_latents
|
380 |
+
+ getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
|
381 |
+
vision_latents.dtype
|
382 |
+
)
|
383 |
+
)
|
384 |
+
else:
|
385 |
+
vision_latents_pos_list.append(vision_latents)
|
386 |
+
|
387 |
+
# Cross Attention
|
388 |
+
attention_output = self.cross_attn(
|
389 |
+
queries, *vision_latents_pos_list, *attention_mask_list_reshaped
|
390 |
+
)
|
391 |
+
|
392 |
+
# attention_output = (attention_output * combination_weight).sum(2)
|
393 |
+
queries = queries + attention_output
|
394 |
+
|
395 |
+
queries = self.norm(queries)
|
396 |
+
|
397 |
+
queries = self.proj_out(queries)
|
398 |
+
|
399 |
+
queries = queries + residual
|
400 |
+
|
401 |
+
return queries
|
402 |
+
|
403 |
+
|
404 |
+
class VisionAggregationLayer(nn.Module):
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
q_dim,
|
408 |
+
context_dim,
|
409 |
+
kv_dim_list,
|
410 |
+
kv_size_list,
|
411 |
+
hidden_dim=1024,
|
412 |
+
layer_idx=0,
|
413 |
+
):
|
414 |
+
super().__init__()
|
415 |
+
num_heads = 16
|
416 |
+
self.num_of_kvs = len(kv_dim_list)
|
417 |
+
|
418 |
+
self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
|
419 |
+
self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
|
420 |
+
|
421 |
+
self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
|
422 |
+
|
423 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
424 |
+
|
425 |
+
if self.num_of_kvs > 1:
|
426 |
+
self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)
|
427 |
+
|
428 |
+
for i, kv_size in enumerate(kv_size_list):
|
429 |
+
if kv_size > 1:
|
430 |
+
setattr(
|
431 |
+
self,
|
432 |
+
"pos_embed_{}".format(i),
|
433 |
+
nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
|
434 |
+
)
|
435 |
+
setattr(
|
436 |
+
self,
|
437 |
+
"aggregate_{}".format(i),
|
438 |
+
AggregationBlock(
|
439 |
+
True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
|
440 |
+
),
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
setattr(
|
444 |
+
self,
|
445 |
+
"aggregate_{}".format(i),
|
446 |
+
AggregationBlock(
|
447 |
+
False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
|
448 |
+
),
|
449 |
+
)
|
450 |
+
|
451 |
+
def forward(
|
452 |
+
self,
|
453 |
+
queries,
|
454 |
+
context_feature,
|
455 |
+
*vision_latents_attention_mask_list,
|
456 |
+
) -> torch.FloatTensor:
|
457 |
+
|
458 |
+
residual = queries
|
459 |
+
# queries = self.proj_in(queries)
|
460 |
+
context_feature = self.proj_context(context_feature)
|
461 |
+
# queries = queries + context_feature
|
462 |
+
queries = torch.cat([queries, context_feature], -1)
|
463 |
+
|
464 |
+
if self.num_of_kvs > 1:
|
465 |
+
combination_weight = self.weight_mlp(queries).softmax(
|
466 |
+
-1
|
467 |
+
) # B * 1 * num_tower
|
468 |
+
combination_weight = combination_weight.unsqueeze(-1)
|
469 |
+
else:
|
470 |
+
combination_weight = 1
|
471 |
+
|
472 |
+
queries = self.proj_in(queries)
|
473 |
+
|
474 |
+
vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
|
475 |
+
attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
|
476 |
+
|
477 |
+
attention_mask_list_reshaped = []
|
478 |
+
if attention_mask_list is not None:
|
479 |
+
for attention_mask in attention_mask_list:
|
480 |
+
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
|
481 |
+
attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
|
482 |
+
attention_mask_list_reshaped.append(attention_mask)
|
483 |
+
|
484 |
+
vision_latents_pos_list = []
|
485 |
+
for i, vision_latents in enumerate(vision_latents_list):
|
486 |
+
if vision_latents.shape[1] > 1:
|
487 |
+
vision_latents_pos_list.append(
|
488 |
+
vision_latents
|
489 |
+
+ getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
|
490 |
+
vision_latents.dtype
|
491 |
+
)
|
492 |
+
)
|
493 |
+
else:
|
494 |
+
vision_latents_pos_list.append(vision_latents)
|
495 |
+
|
496 |
+
aggregated_vision_latents_list = []
|
497 |
+
for i, (vision_latents, attention_mask) in enumerate(
|
498 |
+
zip(vision_latents_pos_list, attention_mask_list_reshaped)
|
499 |
+
):
|
500 |
+
aggregated_vision_latents_list.append(
|
501 |
+
getattr(self, "aggregate_{}".format(i))(
|
502 |
+
vision_latents, queries, attention_mask
|
503 |
+
)
|
504 |
+
)
|
505 |
+
|
506 |
+
aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)
|
507 |
+
|
508 |
+
queries = queries + (aggregated_vision_latents * combination_weight).sum(2)
|
509 |
+
|
510 |
+
queries = self.norm(queries)
|
511 |
+
|
512 |
+
queries = self.proj_out(queries)
|
513 |
+
|
514 |
+
queries = queries + residual
|
515 |
+
|
516 |
+
return queries
|
517 |
+
|
518 |
+
|
519 |
+
class VisionTokenSampler(nn.Module):
|
520 |
+
def __init__(
|
521 |
+
self,
|
522 |
+
q_dim,
|
523 |
+
context_dim,
|
524 |
+
kv_dim_list,
|
525 |
+
kv_size_list,
|
526 |
+
vision_hidden_size,
|
527 |
+
num_of_layers=1,
|
528 |
+
layer_type="joint",
|
529 |
+
):
|
530 |
+
super().__init__()
|
531 |
+
assert layer_type in ["joint", "sep"]
|
532 |
+
if layer_type == "joint":
|
533 |
+
self.layers = nn.ModuleList(
|
534 |
+
[
|
535 |
+
VisionCrossAttentionLayer(
|
536 |
+
q_dim,
|
537 |
+
context_dim,
|
538 |
+
kv_dim_list,
|
539 |
+
kv_size_list,
|
540 |
+
vision_hidden_size,
|
541 |
+
idx,
|
542 |
+
)
|
543 |
+
for idx in range(num_of_layers)
|
544 |
+
]
|
545 |
+
)
|
546 |
+
else:
|
547 |
+
self.layers = nn.ModuleList(
|
548 |
+
[
|
549 |
+
VisionAggregationLayer(
|
550 |
+
q_dim,
|
551 |
+
context_dim,
|
552 |
+
kv_dim_list,
|
553 |
+
kv_size_list,
|
554 |
+
vision_hidden_size,
|
555 |
+
idx,
|
556 |
+
)
|
557 |
+
for idx in range(num_of_layers)
|
558 |
+
]
|
559 |
+
)
|
560 |
+
|
561 |
+
def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
|
562 |
+
for layer in self.layers:
|
563 |
+
queries = layer(
|
564 |
+
queries, context_feature, *vision_latents_attention_mask_list
|
565 |
+
)
|
566 |
+
return queries
|
requirements.txt
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub==0.22.2
|
2 |
+
torch==2.1.2
|
3 |
+
numpy==1.26.4
|
4 |
+
torchvision
|
5 |
+
transformers==4.42.4
|
6 |
+
tokenizers==0.15.2
|
7 |
+
sentencepiece==0.1.99
|
8 |
+
shortuuid
|
9 |
+
accelerate==0.34.2
|
10 |
+
peft==0.4.0
|
11 |
+
bitsandbytes==0.41.0
|
12 |
+
pydantic<2,>=1
|
13 |
+
markdown2
|
14 |
+
scikit-learn==1.2.2
|
15 |
+
gradio==3.35.2
|
16 |
+
gradio_client==0.2.9
|
17 |
+
requests
|
18 |
+
httpx==0.24.0
|
19 |
+
uvicorn
|
20 |
+
fastapi
|
21 |
+
einops==0.6.1
|
22 |
+
einops-exts==0.0.4
|
23 |
+
timm==0.9.16
|
24 |
+
decord
|
25 |
+
ninja
|
26 |
+
deepspeed==0.12.2
|
27 |
+
protobuf
|
28 |
+
iopath
|