# OCR_V2_6_awq / modeling_minicpmv.py
# Uploaded by hp1992 with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 0e88af3 (verified).
import math
from typing import List, Optional
import json
import torch
import torchvision
from threading import Thread
from copy import deepcopy
from PIL import Image
from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
from .configuration_minicpm import MiniCPMVConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .resampler import Resampler
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
    """Qwen2 pre-trained base class bound to the MiniCPM-V configuration.

    Setting ``config_class`` makes ``from_pretrained`` / ``from_config`` build a
    ``MiniCPMVConfig`` for subclasses of this model.
    """

    config_class = MiniCPMVConfig
class MiniCPMV(MiniCPMVPreTrainedModel):
    """Vision-language model: SigLIP vision tower + resampler feeding a Qwen2 LM.

    Image slices are encoded by ``self.vpm`` (a ``SiglipVisionTransformer``),
    resampled by ``self.resampler`` (``num_queries=config.query_num``) to the LM
    hidden size, and scattered into the LM token embeddings at the positions
    listed in ``image_bound`` before ``self.llm`` (a ``Qwen2ForCausalLM``) runs.
    """

    def __init__(self, config):
        super().__init__(config)
        self.llm = Qwen2ForCausalLM(config)
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.embed_dim = self.llm.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        # Created lazily by ``chat`` when no processor is passed in.
        self.processor = None
        # Chat-template tokens that end generation; they are also stripped from
        # decoded and streamed text.
        self.terminators = ['<|im_end|>', '<|endoftext|>']

    def init_vision_module(self):
        """Build the ``tgt_sizes``-aware SigLIP vision tower.

        The vision tower uses ``flash_attention_2`` when the LM config selects it
        and ``eager`` attention otherwise (SDPA is not supported). When
        ``config.drop_vision_last_layer`` is set, the final encoder layer is
        removed. ``embed_dim`` and ``patch_size`` are copied from the embeddings
        module onto the returned model.
        """
        # Same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit, plus tgt_sizes.
        if self.config._attn_implementation == 'flash_attention_2':
            self.config.vision_config._attn_implementation = 'flash_attention_2'
        else:
            # SDPA is not supported by the vision tower.
            self.config.vision_config._attn_implementation = 'eager'
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        model.embed_dim = model.embeddings.embed_dim
        model.patch_size = model.embeddings.patch_size
        return model

    def init_resampler(self, embed_dim, vision_dim):
        """Create the adaptive resampler that maps vision features to LM width.

        Args:
            embed_dim: LM hidden size; the resampler uses ``embed_dim // 128`` heads.
            vision_dim: width of the vision-tower features (the key/value input).
        """
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True,
        )

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegate to the LM so that the module returned by get_input_embeddings()
        # (and read by get_vllm_embedding via self.llm.model.embed_tokens) is the
        # one that actually gets replaced, e.g. by resize_token_embeddings().
        self.llm.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.llm.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.llm = decoder

    def get_decoder(self):
        return self.llm

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Slice cached inputs and build position ids for one generation step.

        NOTE(review): this appears to mirror Qwen2's own hook; ``generate`` in this
        class calls ``self.llm.generate``, which uses the LM's hook instead.
        """
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _encode_images(self, pixel_values_list, tgt_sizes):
        """Encode every image slice of the batch; return one entry per sample.

        Args:
            pixel_values_list: per-sample sequences of image-slice tensors.
            tgt_sizes: per-sample patch-grid sizes; only ``torch.Tensor`` entries
                are used. They are stacked row-wise into an int32 tensor whose
                columns 0 and 1 multiply to each slice's patch count.

        Returns:
            A list with one element per sample: that sample's resampled slice
            features, or ``[]`` when the sample has no images. If the whole batch
            has no images, every element is the same dummy feature. In training
            mode it is a real vision forward on a blank 224x224 image (presumably
            to keep the vision branch in the autograd graph); otherwise it is ``[]``.
        """
        dtype = self.llm.model.embed_tokens.weight.dtype
        device = self.llm.model.embed_tokens.weight.device

        # One 2-D tensor per slice. flatten(end_dim=1) + permute(1, 0) puts the
        # variable-length axis first so pad_sequence can pad it.
        # NOTE(review): assumes each slice is 3-D, e.g. (C, patch, W) — TODO confirm.
        all_pixel_values = [
            slice_pixels.flatten(end_dim=1).permute(1, 0)
            for pixel_values in pixel_values_list
            for slice_pixels in pixel_values
        ]

        if not all_pixel_values:
            # No images anywhere in the batch.
            if self.training:
                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
                dummy_tgt_sizes = torch.Tensor(
                    [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                ).type(torch.int32)
                dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, dummy_tgt_sizes)
            else:
                dummy_feature = []
            return [dummy_feature for _ in pixel_values_list]

        tgt_sizes = torch.vstack([t for t in tgt_sizes if isinstance(t, torch.Tensor)]).type(torch.int32)
        patches_per_slice = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        max_patches = torch.max(patches_per_slice)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L).type(dtype)

        # True over each slice's real patches, False over the padding.
        patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
        for i in range(B):
            patch_attn_mask[i, 0, :patches_per_slice[i]] = True

        # Run the vision tower on at most config.vision_batch_size slices at a time.
        vision_batch_size = self.config.vision_batch_size
        if B > vision_batch_size:
            vision_embedding = torch.cat(
                [
                    self.vpm(
                        all_pixel_values[s:s + vision_batch_size],
                        patch_attention_mask=patch_attn_mask[s:s + vision_batch_size],
                        tgt_sizes=tgt_sizes[s:s + vision_batch_size],
                    ).last_hidden_state
                    for s in range(0, B, vision_batch_size)
                ],
                dim=0,
            )
        else:
            vision_embedding = self.vpm(
                all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
            ).last_hidden_state
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)

        # Split the flat slice dimension back into per-sample groups.
        vision_hidden_states = []
        start = 0
        for pixel_values in pixel_values_list:
            img_cnt = len(pixel_values)
            if img_cnt > 0:
                vision_hidden_states.append(vision_embedding[start: start + img_cnt])
                start += img_cnt
            else:
                vision_hidden_states.append([])
        return vision_hidden_states

    def get_vllm_embedding(self, data):
        """Build LM input embeddings with the image features spliced in.

        Args:
            data: mapping with ``input_ids`` and ``image_bound``, plus either
                ``vision_hidden_states`` (precomputed, used as-is) or
                ``pixel_values`` and ``tgt_sizes`` (encoded here).

        Returns:
            ``(vllm_embedding, vision_hidden_states)``. ``vllm_embedding`` is the
            embedding of ``input_ids``, multiplied by ``llm.config.scale_emb`` when
            that attribute exists. Each sample's vision features are scattered over
            the ``[start, end)`` ranges of its ``image_bound``. The second element
            holds the per-sample vision features, cast to the embedding dtype.
        """
        if 'vision_hidden_states' not in data:
            vision_hidden_states = self._encode_images(data['pixel_values'], data['tgt_sizes'])
        else:
            vision_hidden_states = data['vision_hidden_states']

        vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
        if hasattr(self.llm.config, 'scale_emb'):
            vllm_embedding = vllm_embedding * self.llm.config.scale_emb

        vision_hidden_states = [
            hs.type(vllm_embedding.dtype) if isinstance(hs, torch.Tensor) else hs
            for hs in vision_hidden_states
        ]

        for i in range(len(data['input_ids'])):
            cur_vs_hs = vision_hidden_states[i]
            if len(cur_vs_hs) == 0:
                continue
            # Basic indexing returns a view, so the in-place edits below write
            # straight into vllm_embedding.
            cur_vllm_emb = vllm_embedding[i]
            cur_image_bound = data['image_bound'][i]
            if len(cur_image_bound) > 0:
                image_indices = torch.stack(
                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                ).to(vllm_embedding.device)
                cur_vllm_emb.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                    cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                )
            elif self.training:
                # Zero-valued use of the features, presumably so the vision branch
                # still participates in backward when no image tokens are present.
                cur_vllm_emb += cur_vs_hs[0].mean() * 0

        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        """Run the LM on plain token ids or on a multimodal batch.

        A ``torch.Tensor`` ``data`` is passed to ``self.llm`` as ``input_ids``.
        Otherwise ``data`` is a mapping accepted by ``get_vllm_embedding`` that also
        holds ``position_ids`` (cast to int64 if needed). The spliced embeddings
        are passed as ``inputs_embeds``.
        """
        if isinstance(data, torch.Tensor):
            return self.llm(input_ids=data, **kwargs)

        vllm_embedding, _ = self.get_vllm_embedding(data)
        position_ids = data["position_ids"]
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()
        return self.llm(
            input_ids=None,
            position_ids=position_ids,
            inputs_embeds=vllm_embedding,
            **kwargs,
        )

    def _terminator_ids(self, tokenizer):
        """Return the token ids of ``self.terminators`` under ``tokenizer``."""
        return [tokenizer.convert_tokens_to_ids(t) for t in self.terminators]

    def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
        """Generate synchronously from ``inputs_embeds``.

        Padding uses id 0 and generation stops at any terminator id. Returns the
        generated ids, or decoded strings when ``decode_text`` is true.
        """
        output = self.llm.generate(
            inputs_embeds=inputs_embeds,
            pad_token_id=0,
            eos_token_id=self._terminator_ids(tokenizer),
            attention_mask=attention_mask,
            **kwargs,
        )
        if decode_text:
            return self._decode_text(output, tokenizer)
        return output

    def _decode_stream(self, inputs_embeds, tokenizer, attention_mask=None, **kwargs):
        """Start generation on a background thread and return its text streamer.

        Iterating the returned ``TextIteratorStreamer`` yields text as it is
        generated. The streamer supports a single sequence only.
        """
        streamer = TextIteratorStreamer(tokenizer=tokenizer)
        generation_kwargs = {
            'inputs_embeds': inputs_embeds,
            'pad_token_id': 0,
            'eos_token_id': self._terminator_ids(tokenizer),
            'attention_mask': attention_mask,
            'streamer': streamer,
        }
        generation_kwargs.update(kwargs)
        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def _decode_text(self, result_ids, tokenizer):
        """Decode generated id rows into stripped strings.

        ``_decode`` / ``_decode_stream`` generate with ``pad_token_id=0`` and stop
        at the terminators. So each row is cut at its first terminator, which
        removes the terminator and any padding after it. Genuine id-0 tokens
        before that point are kept. A leading BOS id is dropped when the
        tokenizer defines one (``bos_id``, else ``bos_token_id``). Empty rows
        decode to ``""``.
        """
        terminators = set(self._terminator_ids(tokenizer))
        bos_id = getattr(tokenizer, 'bos_id', getattr(tokenizer, 'bos_token_id', None))
        result_text = []
        for result in result_ids:
            ids = result.tolist()
            for pos, token_id in enumerate(ids):
                if token_id in terminators:
                    ids = ids[:pos]
                    break
            if ids and bos_id is not None and ids[0] == bos_id:
                ids = ids[1:]
            result_text.append(tokenizer.decode(ids).strip())
        return result_text

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        tgt_sizes=None,
        image_bound=None,
        attention_mask=None,
        tokenizer=None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        stream=False,
        decode_text=False,
        **kwargs
    ):
        """Generate from a processed multimodal batch.

        Args:
            input_ids: prompt token ids (required).
            pixel_values, tgt_sizes: per-sample image slices and their patch-grid
                sizes. Required unless ``vision_hidden_states`` is given.
            image_bound: per-sample ``[start, end)`` ranges of the image
                placeholder tokens in ``input_ids``.
            attention_mask: forwarded to ``self.llm.generate`` in both modes.
            tokenizer: used for terminator ids and text decoding.
            vision_hidden_states: precomputed per-sample vision features. When
                given, ``pixel_values`` are not encoded.
            return_vision_hidden_states: also return the vision features used.
            stream: return a ``TextIteratorStreamer`` fed by a background thread.
            decode_text: (non-stream only) return strings instead of ids.
            **kwargs: forwarded to ``self.llm.generate``.
        """
        assert input_ids is not None
        if vision_hidden_states is None:
            assert pixel_values is not None, "pixel_values is required when vision_hidden_states is not given"
        # pixel_values may legitimately be omitted when vision_hidden_states is reused.
        if pixel_values is not None:
            assert len(input_ids) == len(pixel_values)

        model_inputs = {
            "input_ids": input_ids,
            "image_bound": image_bound,
        }
        if vision_hidden_states is None:
            model_inputs["pixel_values"] = pixel_values
            model_inputs['tgt_sizes'] = tgt_sizes
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        with torch.inference_mode():
            model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)

            if stream:
                result = self._decode_stream(
                    model_inputs["inputs_embeds"], tokenizer, attention_mask=attention_mask, **kwargs
                )
            else:
                result = self._decode(
                    model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
                )

        if return_vision_hidden_states:
            return result, vision_hidden_states
        return result

    @staticmethod
    def _build_chat_prompt(image, msgs, processor, system_prompt):
        """Render one conversation into a chat-template prompt plus its images.

        ``msgs`` is a list of ``{"role", "content"}`` dicts, or a JSON string of
        one. Each ``PIL.Image`` in a message's content is collected and replaced
        by the ``(<image>./</image>)`` placeholder. String parts are joined with
        newlines. NOTE: ``image`` is only attached when the first message's
        content is a plain string; otherwise it is ignored.

        Returns:
            ``(prompt, images)``.
        """
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        assert len(msgs) > 0, "msgs is empty"
        copy_msgs = deepcopy(msgs)

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            assert role in ["user", "assistant"]
            if i == 0:
                assert role == "user", "The role of first msg should be user"
            content = msg["content"]
            if isinstance(content, str):
                content = [content]
            parts = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    parts.append("(<image>./</image>)")
                elif isinstance(c, str):
                    parts.append(c)
            msg["content"] = "\n".join(parts)

        if system_prompt:
            copy_msgs = [{'role': 'system', 'content': system_prompt}] + copy_msgs

        prompt = processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
        return prompt, images

    def chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=2048,
        min_new_tokens=0,
        sampling=True,
        max_inp_length=8192,
        system_prompt='',
        stream=False,
        max_slice_nums=None,
        use_image_id=None,
        **kwargs
    ):
        """Answer one conversation, or a batch of conversations.

        Args:
            image: optional image for a single conversation; must be ``None`` in
                batch mode (put images inside ``msgs`` instead).
            msgs: a list of message dicts (single conversation), or a list of
                such lists (batch). A single conversation may be a JSON string.
            tokenizer: used to decode the output.
            processor: defaults to ``AutoProcessor`` loaded from
                ``config._name_or_path`` and cached on ``self.processor``.
            sampling: top-p/top-k sampling if true, else 3-beam search.
            stream: return a generator of text chunks. Requires ``sampling`` and
                a single conversation.
            **kwargs: only keys already present in the chosen generation config
                override it; other keys are ignored.

        Returns:
            A string, a list of strings (batch), or a text generator (stream).
        """
        # A list whose first element is itself a list of messages means batch mode.
        batched = len(msgs) > 0 and isinstance(msgs[0], list)
        if batched:
            assert image is None, "Please integrate image to msgs when using batch inference."
            images_list, msgs_list = [None] * len(msgs), msgs
        else:
            images_list, msgs_list = [image], [msgs]

        assert sampling or not stream, "if use stream mode, make sure sampling=True"
        # TextIteratorStreamer handles one sequence only; with more, the generation
        # thread raises and iterating the stream would block forever.
        assert not stream or len(msgs_list) == 1, "stream mode only supports a single conversation"

        if processor is None:
            if self.processor is None:
                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
            processor = self.processor

        assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."

        prompts_lists = []
        input_images_lists = []
        for cur_image, cur_msgs in zip(images_list, msgs_list):
            prompt, images = self._build_chat_prompt(cur_image, cur_msgs, processor, system_prompt)
            prompts_lists.append(prompt)
            input_images_lists.append(images)

        inputs = processor(
            prompts_lists,
            input_images_lists,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            return_tensors="pt",
            max_length=max_inp_length,
        ).to(self.device)

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05,
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }
        if min_new_tokens > 0:
            generation_config['min_new_tokens'] = min_new_tokens
        generation_config.update(
            (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
        )

        # generate() does not accept image_sizes; tolerate processors that omit it.
        inputs.pop("image_sizes", None)
        with torch.inference_mode():
            res = self.generate(
                **inputs,
                tokenizer=tokenizer,
                max_new_tokens=max_new_tokens,
                vision_hidden_states=vision_hidden_states,
                stream=stream,
                decode_text=True,
                **generation_config
            )

        if stream:
            def stream_gen():
                for text in res:
                    for term in self.terminators:
                        text = text.replace(term, '')
                    yield text
            return stream_gen()

        return res if batched else res[0]