File size: 24,664 Bytes
312e8ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 |
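"""Inference-only OpenVINO wrapper for the MLlama 11B (Llama 3.2 Vision) model.

All model-conversion components are excluded; this module only runs already-converted IR models.
The generation loop still goes through `transformers.GenerationMixin`, so torch and transformers
remain dependencies for now (removing them is a planned follow-up).
"""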
from pathlib import Path
from transformers import AutoConfig, GenerationConfig
from typing import Optional, Union, List, Tuple, Dict
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
import openvino.runtime.opset13 as ops
import openvino as ov
import torch
import numpy as np
from dataclasses import dataclass
from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher
import time
# Shared OpenVINO runtime core: used below to read, compile and query properties for both models.
core = ov.Core()
# Default IR file names looked up inside `model_dir` when the wrapper is not given explicit names.
LANGUAGE_MODEL = "llm_int4_asym_r10_gs64_max_activation_variance_scale_all_layers.xml"
IMAGE_ENCODER = "openvino_vision_encoder_int8.xml"
@dataclass
class MLlamaOutputWithPast(ModelOutput):
    """Output of one language-model step of the OpenVINO MLlama wrapper.

    Mirrors the usual causal-LM output fields and adds ``cross_attn_key_values`` so the
    vision-encoder key/values can be handed back to the generation loop and passed to
    the next decoding step.
    """

    # Not set by the wrapper in this file; presumably kept for ModelOutput API parity.
    loss: Optional[torch.FloatTensor] = None
    # Logits returned by the OpenVINO language model for the processed tokens.
    logits: torch.FloatTensor = None
    # The wrapper stores only the placeholder ``((),)`` here; the real cache is presumably
    # held as internal state of the stateful OpenVINO infer request.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Not populated by the wrapper in this file.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Not populated by the wrapper in this file.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Cross-attention key/values produced by the vision encoder. Despite the annotation these
    # are not necessarily torch tensors (the wrapper forwards OpenVINO tensors / numpy arrays).
    cross_attn_key_values: Optional[List[torch.FloatTensor]] = None
class InsertSlice(MatcherPass):
    """Graph pass that makes the LM head produce logits for the last position only.

    For every ``Result`` node whose output has rank 3 (presumably the ``logits`` output,
    ``[batch, seq_len, vocab]`` -- confirm against the exported IR), a ``Slice`` that keeps
    only the last position along axis 1 is inserted on the input of the node feeding that
    result (the LM-head projection). The batch dimension is kept intact, so batched /
    beam-search decoding still gets one row of logits per sequence.
    """

    def __init__(self):
        MatcherPass.__init__(self)
        self.model_changed = False

        result_pattern = WrapType("opset10.Result")

        def callback(matcher: Matcher) -> bool:
            root = matcher.get_match_root()
            if root is None:
                return False
            if len(root.get_output_partial_shape(0)) != 3:
                # Only rank-3 outputs are rewritten; report "no change" explicitly.
                return False

            # Result <- lm_head <- hidden states producer
            lm_head = root.input_value(0).get_node()
            hidden_states = lm_head.input(0).get_source_output()
            # Snapshot consumers *before* creating the Slice, which itself consumes `hidden_states`.
            consumers = hidden_states.get_target_inputs()

            # Pick index -1 on the sequence axis only (start=-1, stop=-2, step=-1).
            # Slicing axis 1 alone preserves the whole batch and does not require a
            # static hidden size.
            start = np.array([-1], dtype=np.int32)
            stop = np.array([-2], dtype=np.int32)
            step = np.array([-1], dtype=np.int32)
            axes = np.array([1], dtype=np.int32)
            last_token = ops.slice(hidden_states, start, stop, step, axes, name="inserted_slice")

            for consumer in consumers:
                consumer.replace_source_output(last_token.output(0))
            self.model_changed = True
            # Use new operation for additional matching
            self.register_new_node(last_token)
            print("applied slice for lm head")
            return True

        self.register_matcher(Matcher(result_pattern, "InsertSlice"), callback)
# Precision names accepted in ov_config["INFERENCE_PRECISION_HINT"], mapped to OpenVINO element types.
STR_TO_OV_TYPE = dict(
    boolean=ov.Type.boolean,
    f16=ov.Type.f16,
    f32=ov.Type.f32,
    f64=ov.Type.f64,
    i8=ov.Type.i8,
    i16=ov.Type.i16,
    i32=ov.Type.i32,
    i64=ov.Type.i64,
    u8=ov.Type.u8,
    u16=ov.Type.u16,
    u32=ov.Type.u32,
    u64=ov.Type.u64,
    bf16=ov.Type.bf16,
)
class OVMLlamaForConditionalGeneration(GenerationMixin):
    """OpenVINO inference wrapper for MLlama that can be driven by ``GenerationMixin.generate()``.

    Two OpenVINO models are used:

    * a vision encoder that maps ``pixel_values`` / ``aspect_ratio_ids`` /
      ``aspect_ratio_mask`` to ``cross_attn_key_values*`` output tensors, and
    * a language model whose self-attention cache lives inside the infer request
      (``reset_state()`` starts a new sequence; ``past_key_values`` is only the
      placeholder ``((),)``).

    Torch is only used for the tensors exchanged with ``GenerationMixin``.
    """

    def __init__(
        self,
        model_dir: Union[str, Path],
        device: str = "CPU",
        ov_config: Optional[Dict[str, str]] = None,
        language_model_name=None,
        image_encoder_name=None,
        slice_lm_head=True,
        use_remote_tensors=True,
        dynamic_shape=False,
    ):
        """Read, adapt and compile the language and vision models.

        Args:
            model_dir: Directory with the HF config, the generation config and the IR files.
            device: OpenVINO device used to compile both models.
            ov_config: Optional compile properties; ``INFERENCE_PRECISION_HINT`` also
                selects the cross-attention key/value precision.
            language_model_name: Language-model IR file name (defaults to ``LANGUAGE_MODEL``).
            image_encoder_name: Vision-encoder IR file name (defaults to ``IMAGE_ENCODER``).
            slice_lm_head: Insert a Slice so logits are computed for the last token only.
            use_remote_tensors: Share cross-attention key/values between the two models
                through GPU remote tensors; only honoured when ``device == "GPU"``.
            dynamic_shape: Keep the vision-encoder inputs dynamic instead of reshaping
                them to a fixed single-image shape.
        """
        model_dir = Path(model_dir)
        self.config = AutoConfig.from_pretrained(model_dir)
        self.generation_config = GenerationConfig.from_pretrained(model_dir)
        self.main_input_name = "input_ids"
        # GenerationMixin expects a torch device; the OpenVINO device name is kept in `_device`.
        self.device = torch.device("cpu")
        self._device = device
        self.ov_config = ov_config
        self.num_pkv = 2
        self._supports_cache_class = False
        self.next_beam_idx = None
        self._past_length = 0
        self.llm_infer_time = []
        self.vision_encoder_infer_time = []

        self.model = core.read_model(model_dir / (language_model_name or LANGUAGE_MODEL))
        self.vision_model = core.read_model(model_dir / (image_encoder_name or IMAGE_ENCODER))
        if not dynamic_shape:
            self.reshape_vision_model()
        self.update_pkv_precision()
        if slice_lm_head:
            self.slice_lm_head()

        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
        self.lm_cross_attn_inputs = [key for key in self.input_names if "cross_attn_key_values" in key]
        compiled_model = core.compile_model(self.model, device, ov_config)
        self.request = compiled_model.create_infer_request()

        self.cross_attn_outputs = [
            key.get_any_name() for key in self.vision_model.outputs if "cross_attn_key_values" in key.get_any_name()
        ]
        compiled_vision_model = core.compile_model(self.vision_model, device, ov_config)
        self.vision_request = compiled_vision_model.create_infer_request()

        self.use_remote_tensors = use_remote_tensors and self._device == "GPU"
        if self.use_remote_tensors:
            self.prepare_remote_tensors()
        # (image_size // patch_size)^2 patches plus one extra token (presumably CLS) per tile.
        self.num_patches = (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2 + 1

    def _get_past_length(self, past_key_values=None):
        """Return the number of tokens already fed to the language model (0 without a cache)."""
        if past_key_values is None:
            return 0
        return self._past_length

    def reshape_vision_model(self):
        """Fix the vision-encoder inputs (by port index) to a static single-image shape.

        NOTE(review): ports 0/1/2 are fed with pixel_values, aspect_ratio_ids and
        aspect_ratio_mask in ``visual_encoder``; the 4 is presumably the max tile count.
        """
        image_size = self.config.vision_config.image_size
        self.vision_model.reshape(
            {
                0: ov.PartialShape([1, 1, 4, 3, image_size, image_size]),
                1: ov.PartialShape([1, 1]),
                2: ov.PartialShape([1, 1, 4]),
            }
        )

    def update_pkv_precision(self, force_fp32=False):
        """Give the cross-attention key/value ports of both models one common element type.

        The precision is f32 when ``force_fp32`` is set. Otherwise it is the device's
        ``INFERENCE_PRECISION_HINT`` (when the device reports one), optionally overridden
        by ``ov_config["INFERENCE_PRECISION_HINT"]``. The language model's
        ``cross_attn_key_values*`` inputs and the vision encoder's matching outputs are
        converted to that type via PrePostProcessor, and it is stored in ``self._pkv_precision``.
        """
        pkv_precision = ov.Type.f32
        if not force_fp32:
            device = self._device.upper()
            try:
                if "INFERENCE_PRECISION_HINT" in core.get_property(device, "SUPPORTED_PROPERTIES"):
                    pkv_precision = core.get_property(device, "INFERENCE_PRECISION_HINT")
            except RuntimeError:  # use default precision when get_property fails, e.g. when device is "AUTO:GPU"
                pass
            # ov_config["INFERENCE_PRECISION_HINT"] may override the preferred precision
            if self.ov_config:
                inference_precision_hint = self.ov_config.get("INFERENCE_PRECISION_HINT", "")
                if inference_precision_hint in STR_TO_OV_TYPE:
                    pkv_precision = STR_TO_OV_TYPE[inference_precision_hint]

        # Applied in every case so that force_fp32 really forces f32 and
        # `_pkv_precision` is always defined for prepare_remote_tensors().
        ppp = ov.preprocess.PrePostProcessor(self.model)
        for key in self.model.inputs:
            if "cross_attn_key_values" in key.get_any_name() and pkv_precision != key.get_element_type():
                ppp.input(key.get_any_name()).tensor().set_element_type(pkv_precision)
        self.model = ppp.build()

        ppp_v = ov.preprocess.PrePostProcessor(self.vision_model)
        for key in self.vision_model.outputs:
            if "cross_attn_key_values" in key.get_any_name() and pkv_precision != key.get_element_type():
                ppp_v.output(key.get_any_name()).tensor().set_element_type(pkv_precision)
        self.vision_model = ppp_v.build()
        self._pkv_precision = pkv_precision

    def slice_lm_head(self):
        """Run the InsertSlice pass so the language model returns last-position logits only."""
        manager = Manager()
        manager.register_pass(InsertSlice())
        manager.run_passes(self.model)
        self.model.validate_nodes_and_infer_types()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[List[List[int]]] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[List[List[List[int]]]] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cross_attn_key_values: Optional[List[torch.Tensor]] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, MLlamaOutputWithPast]:
        r"""Run one generation step.

        If ``pixel_values`` is given, the vision encoder is run first and its key/values
        replace ``cross_attn_key_values``. The cross-attention masks are then expanded to
        vision-token resolution and the OpenVINO language model is called.

        ``inputs_embeds`` and ``cross_attention_states`` are only checked for mutual
        exclusivity and are never forwarded. ``labels``, ``use_cache``, ``output_attentions``,
        ``output_hidden_states``, ``return_dict`` and ``num_logits_to_keep`` are accepted
        for ``GenerationMixin`` compatibility but are ignored (no loss is computed).

        Raises:
            ValueError: on conflicting inputs, or ``pixel_values`` without ``aspect_ratio_ids``.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")
        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one")
        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            cross_attn_key_values = self.visual_encoder(pixel_values, aspect_ratio_ids, aspect_ratio_mask)
        cross_attention_mask, full_text_row_masked_out_mask = self._prepare_cross_attention_mask(
            cross_attention_mask,
            past_key_values=past_key_values,
            num_vision_tokens=self.num_patches,
            cross_attention_layers=cross_attn_key_values if past_key_values is not None else None,
            # Non-None placeholder: enables the "extend with past cross-attn length" branch.
            cross_attention_states=((),),
            device=self.device,
            dtype=torch.float32,
        )
        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
        return self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            cross_attention_key_values=cross_attn_key_values,
        )

    def language_model(
        self,
        input_ids,
        attention_mask,
        position_ids,
        cross_attention_mask,
        full_text_row_masked_out_mask,
        past_key_values,
        cache_position,
        cross_attention_key_values,
    ):
        """Run the OpenVINO language model for one step and wrap its logits.

        ``past_key_values is None`` marks the start of a new sequence. It resets the
        request state, beam indices, processed-token counter and timing list.
        Cross-attention key/values are bound as inputs unless GPU remote tensors
        already connect them to the vision encoder.

        NOTE(review): the returned logits share memory with the request's output tensor
        and may be overwritten by the next inference; clone them if they must persist.
        """
        model_inputs = {
            "input_ids": ov.Tensor(np.array(input_ids)),
            "attention_mask": ov.Tensor(np.array(attention_mask)),
            "position_ids": ov.Tensor(np.array(position_ids)),
            "cross_attention_mask": ov.Tensor(np.array(cross_attention_mask)),
            "full_text_row_masked_out_mask": ov.Tensor(np.array(full_text_row_masked_out_mask)),
            "cache_position": ov.Tensor(np.array(cache_position)),
        }
        if past_key_values is None:
            self.request.reset_state()
            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
            self._past_length = 0
            self.llm_infer_time = []
        if not self.use_remote_tensors:
            model_inputs.update(dict(zip(self.lm_cross_attn_inputs, cross_attention_key_values)))
        if "beam_idx" in self.input_names:
            model_inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(input_ids.shape[0], dtype=int)
        start = time.perf_counter()
        self.request.start_async(model_inputs, share_inputs=True)
        self.request.wait()
        end = time.perf_counter()
        self.llm_infer_time.append(end - start)
        logits = torch.from_numpy(self.request.get_tensor("logits").data)
        # The real cache is internal to the request; GenerationMixin only needs a non-None value.
        past_key_values = ((),)
        self._past_length += input_ids.shape[1]
        return MLlamaOutputWithPast(logits=logits, past_key_values=past_key_values, cross_attn_key_values=cross_attention_key_values)

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(self, *args, **kwargs) -> MLlamaOutputWithPast:
        """Alias for ``forward`` so the wrapper can be called like an ``nn.Module``."""
        return self.forward(*args, **kwargs)

    def _reorder_cache(self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
        return past_key_values

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        cross_attn_key_values=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Build the keyword arguments for ``forward`` at each generation step."""
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
                "cross_attn_key_values": cross_attn_key_values,
            }
        )
        # Image inputs are only needed while the image token is still among the tokens to process;
        # afterwards the vision key/values are reused through `cross_attn_key_values`.
        if (input_ids == self.config.image_token_index).any():
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        """Extend the cross-attention mask by one token and carry the vision key/values forward."""
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
        # add cross-attn mask for new token (repeat the last token's row)
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat([cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1)
        model_kwargs["cross_attn_key_values"] = outputs.cross_attn_key_values
        return model_kwargs

    def _prepare_cross_attention_mask(
        self,
        cross_attention_mask: torch.Tensor,
        past_key_values: Tuple,
        num_vision_tokens: int,
        cross_attention_states: torch.Tensor,
        cross_attention_layers: List[int],
        device: str,
        dtype: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a per-tile 0/1 cross-attention mask into additive masks for the LM.

        Each last-axis entry is repeated ``num_vision_tokens`` times and flattened, a head
        axis is inserted, and the mask is inverted into an additive mask (0 = attend,
        dtype min = masked). Also returns a multiplicative mask that is 0 for text rows
        with no visible vision token. When all of ``past_key_values``,
        ``cross_attention_states`` and ``cross_attention_layers`` are non-None, a zero block
        of width ``cross_attention_layers[0].shape[-2]`` is prepended on the last axis.
        ``cross_attention_layers`` receives key/value tensors here, not layer indices.

        Returns ``(None, None)`` when ``cross_attention_mask`` is None.
        """
        if cross_attention_mask is None:
            # should we raise error or prepare a full attn mask with all ones?
            return None, None
        # reshape so it can be used by attn module
        batch_size, text_total_length, *_ = cross_attention_mask.shape
        cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
        cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
        cross_attention_mask = cross_attention_mask.unsqueeze(1)
        # invert the mask
        inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
        cross_attention_mask = inverted_cross_attn_mask.masked_fill(inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min)
        # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
        # last dimension contains negative infinity values, otherwise it's 1
        negative_inf_value = torch.finfo(dtype).min
        full_text_row_masked_out_mask = (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
        cross_attention_mask *= full_text_row_masked_out_mask
        # In case we receive a new image but already have previous cross-attention key/values in cache,
        # then we need to extend the attention-mask and add previous images' lengths
        if past_key_values is not None and cross_attention_states is not None and cross_attention_layers is not None:
            # make an all-zeros mask for the previously cached cross-attn key/values,
            # i.e. extend the current cross-attn mask on the image-seq-length dimension
            past_cross_attn_kv_length = cross_attention_layers[0].shape[-2]
            past_cross_attn_mask = torch.zeros((*cross_attention_mask.shape[:-1], past_cross_attn_kv_length), dtype=dtype, device=device)
            # concatenate both on image-seq-length dimension
            cross_attention_mask = torch.cat([past_cross_attn_mask, cross_attention_mask], dim=-1)
        return cross_attention_mask, full_text_row_masked_out_mask

    def visual_encoder(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        """Run the vision encoder and return its cross-attention key/value tensors.

        Returns:
            One tensor per name in ``self.cross_attn_outputs``, as returned by the vision
            infer request's ``get_tensor`` (not copied; reused by the next call).

        Raises:
            ValueError: if ``pixel_values`` or ``aspect_ratio_ids`` is missing.
        """
        if pixel_values is None:
            # Nothing to encode; fail loudly instead of reaching the return with no result.
            raise ValueError("`pixel_values` must be provided to run the vision encoder")
        if aspect_ratio_ids is None:
            raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
        self.vision_encoder_infer_time = []
        start = time.perf_counter()
        # get vision tokens from vision model (inputs are fed positionally)
        self.vision_request.start_async([pixel_values, aspect_ratio_ids, aspect_ratio_mask], share_inputs=True)
        self.vision_request.wait()
        end = time.perf_counter()
        cross_attn_key_values = [self.vision_request.get_tensor(name) for name in self.cross_attn_outputs]
        self.vision_encoder_infer_time.append(end - start)
        return cross_attn_key_values

    def prepare_vision_outputs(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attention_mask=None, past_key_values=None, cache_position=None):
        """Run the vision encoder and build the cross-attention masks without calling the LM.

        Returns a dict whose keys match the parameters of ``prepare_llm_inputs``; the
        key/values are exposed as numpy arrays (``Tensor.data``).
        """
        cross_attn_key_values = self.visual_encoder(pixel_values, aspect_ratio_ids, aspect_ratio_mask)
        cross_attn_key_values = [v.data for v in cross_attn_key_values]
        cross_attention_mask, full_text_row_masked_out_mask = self._prepare_cross_attention_mask(
            cross_attention_mask,
            past_key_values=past_key_values,
            num_vision_tokens=self.num_patches,
            cross_attention_layers=cross_attn_key_values if past_key_values is not None else None,
            cross_attention_states=1,
            device=self.device,
            dtype=torch.float32,
        )
        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
        return {
            "cross_attention_mask": cross_attention_mask,
            "full_text_row_masked_out_mask": full_text_row_masked_out_mask,
            "past_key_values": past_key_values,
            "cache_position": cache_position,
            "cross_attention_key_values": cross_attn_key_values,
        }

    def prepare_llm_inputs(
        self,
        input_ids,
        attention_mask,
        position_ids,
        cross_attention_mask,
        full_text_row_masked_out_mask,
        past_key_values,
        cache_position,
        cross_attention_key_values,
    ):
        """Build the language-model input dict without running inference.

        Unlike ``language_model``, values are passed through unconverted and the
        cross-attention key/values are always included. ``past_key_values is None``
        resets the request state, beam indices and processed-token counter.
        """
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "cross_attention_mask": cross_attention_mask,
            "full_text_row_masked_out_mask": full_text_row_masked_out_mask,
            "cache_position": cache_position,
        }
        if past_key_values is None:
            self.request.reset_state()
            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
            self._past_length = 0
        model_inputs.update(dict(zip(self.lm_cross_attn_inputs, cross_attention_key_values)))
        if "beam_idx" in self.input_names:
            model_inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(input_ids.shape[0], dtype=int)
        return model_inputs

    def prepare_remote_tensors(self):
        """Bind one shared GPU remote tensor per cross-attention key/value pair.

        The i-th language-model ``cross_attn_key_values*`` input and the i-th matching
        vision-encoder output are bound to the same remote tensor, so the LM reads the
        encoder's result without a host round-trip. The tensor uses the negotiated
        ``self._pkv_precision`` (the type both ports were converted to) and the vision
        output's shape when it is static, falling back to ``[1, 32, 6404, 128]``.
        """
        context = core.get_default_context("GPU")
        fallback_shape = ov.Shape([1, 32, 6404, 128])
        for idx, name in enumerate(self.lm_cross_attn_inputs):
            output_name = self.cross_attn_outputs[idx]
            partial_shape = self.vision_model.output(output_name).get_partial_shape()
            shape = partial_shape.to_shape() if partial_shape.is_static else fallback_shape
            remote_tensor = context.create_tensor(self._pkv_precision, shape, {})
            self.vision_request.set_tensor(output_name, remote_tensor)
            self.request.set_tensor(name, remote_tensor)
|