Upload modeling_mplug_owl2.py with huggingface_hub
modeling_mplug_owl2.py  CHANGED  (+58 -1)
@@ -37,6 +37,40 @@ IMAGE_TOKEN_INDEX = -200
 DEFAULT_IMAGE_TOKEN = "<|image|>"
 from icecream import ic
 
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
 class MPLUGOwl2MetaModel:
     def __init__(self, config):
         super(MPLUGOwl2MetaModel, self).__init__(config)
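The first hunk adds two standalone helpers. tokenizer_image_token splits the prompt on the <|image|> placeholder, tokenizes each text chunk separately, and splices the chunks back together with the sentinel id IMAGE_TOKEN_INDEX (-200), which the multimodal forward pass later replaces with projected visual features; the offset logic keeps exactly one BOS token even though a LLaMA-style tokenizer prepends BOS to every chunk. expand2square pastes a non-square PIL image onto a square canvas filled with background_color, so preprocessing does not distort the aspect ratio. A minimal sketch of the tokenizer helper in use (the checkpoint id is a placeholder and the helpers are assumed importable from the uploaded module; neither is pinned down by this diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder LLaMA-style tokenizer

ids = tokenizer_image_token("USER: Describe this.\n<|image|>\nASSISTANT:", tok)

# Each side of "<|image|>" is tokenized on its own, then the chunks are rejoined
# with -200 standing in for the image slot.
assert ids.count(IMAGE_TOKEN_INDEX) == 1   # exactly one image slot
assert ids[0] == tok.bos_token_id          # BOS kept once, not once per chunk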
@@ -218,13 +252,36 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
         self.model = MPLUGOwl2LlamaModel(config)
 
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.preferential_ids_ = [id_[1] for id_ in self.tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+        self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
 
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_model(self):
         return self.model
-
+
+    def score(self, images,
+              task_: str = "quality",
+              input_: str = "image",
+              ):
+        prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
+        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+        if input_ == "image":
+            images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
+            with torch.inference_mode():
+                image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
+                output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
+                                     images=image_tensor)["logits"][:, -1, self.preferential_ids_]
+                return torch.softmax(output_logits, -1) @ self.weight_tensor
+        else:
+            video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
+            with torch.inference_mode():
+                video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
+                output_logits = self(input_ids.repeat(len(video_tensors), 1),
+                                     images=video_tensors)["logits"][:, -1, self.preferential_ids_]
+                return torch.softmax(output_logits, -1) @ self.weight_tensor
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
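Together, the two hunks turn the checkpoint into a zero-shot visual-quality scorer. score() fills the template "USER: How would you rate the {task_} of this {input_}?\n<|image|>\nASSISTANT: The {task_} of the {input_} is", runs a single forward pass, and reads the next-token logits at the five rating words only. The softmax over those five logits is a distribution over {excellent, good, fair, poor, bad}, and the product with weight_tensor = [5., 4., 3., 2., 1.] collapses it into an expected rating in [1, 5]; e.g. probabilities (0.6, 0.3, 0.1, 0.0, 0.0) give 0.6*5 + 0.3*4 + 0.1*3 = 4.5. A hedged usage sketch (the repo id, trust_remote_code loading, and device handling are assumptions, not part of this diff):

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

# Placeholder repo id; any checkpoint that ships this modeling file should work.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/mplug-owl2-scorer",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda")

images = [Image.open("photo_a.jpg"), Image.open("photo_b.jpg")]
scores = model.score(images, task_="quality", input_="image")
print(scores)  # tensor of shape (2,); each entry is an expected rating in [1, 5]

# Videos are passed as lists of PIL frames, one score per video:
# scores = model.score([frames_v1, frames_v2], task_="quality", input_="video")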