Feature Extraction
Safetensors
English
minicpmv
VisRAG
custom_code
VisRAG-Ret / modeling_visrag_ret.py
tcy5156
upload weights
c7a0be3
raw
history blame
4.61 kB
import torch
from dataclasses import dataclass
from transformers.utils import ModelOutput
from typing import Optional
from .modeling_minicpmv import MiniCPMV
from .modeling_minicpm import MiniCPMForCausalLM
from .resampler import Resampler
from concurrent.futures import ThreadPoolExecutor
def transform_image_mp(img_list, transform, device, max_workers=None):
pixel_values = []
# 使用ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for img_batch in img_list:
img_inps = list(executor.map(transform, img_batch))
for i in range(len(img_inps)):
img_inps[i] = img_inps[i].to(device)
pixel_values.append(img_inps if img_inps else [])
return pixel_values
@dataclass
class BaseModelOutputWithAttentionMask(ModelOutput):
last_hidden_state: torch.FloatTensor = None
attention_mask: Optional[torch.Tensor] = None
class VisRAG_Ret(MiniCPMV): # -> MiniCPMV -> Ultimately a CausalLM
def fused_tokenize(
self,
data_list=None, # List[str]
img_list=None, # List[List[PIL.Image]]
tokenizer=None,
max_inp_length: Optional[int] = None,
vision_hidden_states=None, # default None
return_vision_hidden_states=False,
**kwargs):
assert data_list is not None
bs = len(data_list)
if img_list == None:
img_list = [[] for i in range(bs)]
assert bs == len(img_list)
model_inputs = self._process_list(tokenizer, data_list, max_inp_length, padding_side="right")
if vision_hidden_states is None:
pixel_values = transform_image_mp(img_list, self.transform, self.device, max_workers=8)
model_inputs["pixel_values"] = pixel_values
else:
model_inputs["vision_hidden_states"] = vision_hidden_states
return model_inputs
def prepare_context(self, inputs, tokenizer):
text_, image_ = inputs
if not isinstance(text_, str):
raise NotImplementedError(f"chatml format expected, expect outmost type to be str but got {type(text_)}")
# 1.add text
content = text_
# 2. add image
if image_:
if self.config.slice_mode:
images, final_placeholder = self.get_slice_image_placeholder(
image_, tokenizer
) # crop one image into multiple sub images -> List[Image]
content = final_placeholder + "\n" + content
else:
images = [image_] # only keep one image without cropping -> List[Image]
content = (
tokenizer.im_start
+ tokenizer.unk_token * self.config.query_num
+ tokenizer.im_end
+ "\n"
+ content
)
else:
images = []
return content, images
def forward(
self,
text, # List[str] B*str
image, # List[ PIL.Image ] B*PIL.Image, one image for each data
tokenizer,
vision_hidden_states=None,
max_inp_length=2048,
**kwargs):
processed_image = []
processed_text = []
with ThreadPoolExecutor(max_workers=8) as executor:
contexts = list(executor.map(lambda inputs: self.prepare_context(inputs, tokenizer), zip(text, image)))
for context in contexts:
content_, image_ = context
processed_text.append(content_)
processed_image.append(image_)
model_inputs = self.fused_tokenize(
data_list=processed_text, # List[str]
img_list=processed_image, # List[List[PIL.Image]]
tokenizer=tokenizer,
max_inp_length=max_inp_length
)
# this is vision encoder forward.
model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
vlm_outputs = self.llm.model(
input_ids=None, # because image and text have been merged into model_inputs["inputs_embeds"] here, we don't give input_ids
position_ids=None,
inputs_embeds=model_inputs["inputs_embeds"],
attention_mask=model_inputs["attention_mask"],
return_dict=True
)
return BaseModelOutputWithAttentionMask(
last_hidden_state=vlm_outputs.last_hidden_state,
attention_mask=model_inputs.attention_mask
)