from typing import Union, Optional

import PIL.Image
import torch
from torch.nn.functional import softmax, gumbel_softmax, pad
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoModel, AutoConfig

from ovis.util.constants import IMAGE_INDICATOR_IDS, IMAGE_ATOM_ID


class BaseVisualTokenizerConfig(PretrainedConfig):
    """Configuration for `BaseVisualTokenizer`.

    Args:
        vocab_size: Size of the last dimension returned by
            `BaseVisualTokenizer.forward`; `len(IMAGE_INDICATOR_IDS)` slots of
            it are reserved and always zero.
        tokenize_function: One of ``"softmax"``, ``"gumbel_argmax"`` or
            ``"st_argmax"``; selects how logits are turned into tokens.
        tau: Temperature passed to `gumbel_softmax` when
            `tokenize_function == "gumbel_argmax"`.
        depths: Either a list or a ``'|'``-separated string of ints (e.g.
            ``"1|2|3"``), stored as a list of ints. Not consumed in this file.
        drop_cls_token: If true, `encode` drops the first token of the
            backbone output.
        backbone_config: Backbone configuration, either a `PretrainedConfig`
            or a dict that contains a ``'model_type'`` key plus the keyword
            arguments for `AutoConfig.for_model`.
        hidden_stride: Side of the square window of neighbouring hidden states
            that `encode` merges into one token.
        **kwargs: Forwarded to `PretrainedConfig.__init__`.
    """

    def __init__(
            self,
            vocab_size=16384,
            tokenize_function="softmax",
            tau=1.0,
            depths=None,
            drop_cls_token=False,
            backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
            hidden_stride: int = 1,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        if isinstance(depths, str):
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        self.backbone_kwargs = {}
        self.drop_cls_token = drop_cls_token
        if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
            if not isinstance(backbone_config, PretrainedConfig):
                # Work on a shallow copy so the caller's dict keeps its
                # 'model_type' entry and can be reused for another config.
                backbone_kwargs = dict(backbone_config)
                model_type = backbone_kwargs.pop('model_type')
                backbone_config = AutoConfig.for_model(model_type, **backbone_kwargs)
        self.backbone_config = backbone_config
        self.hidden_stride = hidden_stride


class BaseVisualTokenizer(PreTrainedModel):
    """Base class mapping images to token distributions over a visual vocabulary.

    The model is a vision backbone followed by a bias-free linear layer and a
    LayerNorm (`self.head`). Subclasses are expected to set the
    `_image_processor_*` / `_backbone_*` class attributes and to implement
    `get_image_size` and `get_monitor_tensors`.
    """

    base_model_prefix = "backbone"
    main_input_name = None
    # Used only when the model is built with `train_from_scratch=True`.
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None
    _backbone_name_or_path = None

    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        """Build the image processor, backbone and projection head.

        Keyword Args:
            train_from_scratch: If truthy, the image processor and backbone are
                loaded with `from_pretrained(self._backbone_name_or_path)` and
                the backbone's config is written back to
                `config.backbone_config`.
            image_processor_name_or_path: Required when `train_from_scratch`
                is absent or falsy; passed to
                `AutoImageProcessor.from_pretrained`. The backbone is then
                instantiated from `config.backbone_config`.

        Raises:
            KeyError: If `image_processor_name_or_path` is needed but missing.
            AssertionError: If the image processor does not resize, rescale
                and normalize, or if it center-crops.
        """
        super().__init__(config, *inputs, **kwargs)
        if kwargs.get('train_from_scratch'):
            self.image_processor = self._image_processor_class.from_pretrained(
                self._backbone_name_or_path, **self._image_processor_kwargs)
            self.backbone = self._backbone_class.from_pretrained(
                self._backbone_name_or_path, **self.config.backbone_kwargs)
            self.config.backbone_config = self.backbone.config
        else:
            self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
            self.backbone = AutoModel.from_config(self.config.backbone_config)
        head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS)  # reserved tokens for IMAGE_INDICATORS
        self.head = torch.nn.Sequential(
            torch.nn.Linear(
                self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride,
                head_dim,
                bias=False
            ),
            torch.nn.LayerNorm(head_dim)
        )

        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        """Return the vision backbone module."""
        return self.backbone

    def get_monitor_tensors(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_image_processor(self):
        """Return the image processor built in `__init__`."""
        return self.image_processor

    def mock_input(self):
        """Return an all-zero `[1, 3, H, W]` image and the placeholders of a 1x1 grid."""
        height, width = self.get_image_size()
        return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))

    def get_head(self):
        """Return the projection head (Linear followed by LayerNorm)."""
        return self.head

    def get_image_size(self):
        """Return the `(height, width)` the backbone expects; subclass hook."""
        raise NotImplementedError

    @staticmethod
    def construct_image_placeholders(grid):
        """Build the placeholder id sequence for an image split into `grid`.

        Args:
            grid: `(rows, cols)` of the partition.

        Returns:
            A list that starts with
            `[IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]`.
            When the grid has more than one cell, one `IMAGE_ATOM_ID` per cell
            follows in row-major order, with `IMAGE_INDICATOR_IDS[2]` between
            columns and `IMAGE_INDICATOR_IDS[3]` between rows. The list always
            ends with `IMAGE_INDICATOR_IDS[4]`.
        """
        image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
        if grid[0] * grid[1] > 1:
            for r in range(grid[0]):
                for c in range(grid[1]):
                    image_placeholders.append(IMAGE_ATOM_ID)
                    if c < grid[1] - 1:
                        image_placeholders.append(IMAGE_INDICATOR_IDS[2])
                if r < grid[0] - 1:
                    image_placeholders.append(IMAGE_INDICATOR_IDS[3])
        image_placeholders.append(IMAGE_INDICATOR_IDS[4])
        return image_placeholders

    def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
        """Split an image into a grid of crops and preprocess each to a square.

        Args:
            image: Input image.
            max_partition: Maximum number of grid cells (`rows * cols`).
            covering_threshold: A grid is "good" when its covering ratio is
                strictly greater than this value.
            convert_to_rgb: Convert non-RGB images to RGB first.

        Returns:
            `(pixel_values, image_placeholders)`. `pixel_values` stacks the
            preprocessed squares along dim 0; when there is more than one crop
            the full image comes first. `image_placeholders` is the result of
            `construct_image_placeholders` for the chosen grid.

        Raises:
            ValueError: If `get_image_size()` is not square.
        """

        def _preprocess(img: PIL.Image.Image, side):
            # first resize and preprocess; the longer edge becomes `side` and
            # the shorter edge keeps the aspect ratio (clamped to >= 1 pixel so
            # extreme aspect ratios do not yield an empty resize target)
            w, h = img.size
            if w == h:
                new_width = new_height = side
            elif w > h:
                new_width = side
                new_height = max(1, int(h / w * new_width))
            else:
                new_height = side
                new_width = max(1, int(w / h * new_height))
            new_size = dict(height=new_height, width=new_width)
            pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

            # then pad to square, centering the image along its shorter axis
            square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
            new_height, new_width = pixel_values.shape[2:]
            if new_height == new_width:
                square_values[:, :, :, :] = pixel_values
            elif new_height > new_width:
                from_index = (side - new_width) // 2
                square_values[:, :, :, from_index:from_index + new_width] = pixel_values
            else:
                from_index = (side - new_height) // 2
                square_values[:, :, from_index:from_index + new_height, :] = pixel_values

            return square_values

        def _partition(img, grid):
            # Crop boxes (left, upper, right, lower) in row-major order; the
            # last row/column absorbs the remainder of the integer division.
            w, h = img.size
            row_height = h // grid[0]
            col_width = w // grid[1]

            partition = []
            for row in range(grid[0]):
                for col in range(grid[1]):
                    left = col * col_width
                    upper = row * row_height
                    right = w if col == grid[1] - 1 else (col + 1) * col_width
                    lower = h if row == grid[0] - 1 else (row + 1) * row_height
                    partition.append((left, upper, right, lower))

            return partition

        def _covering_area(left, upper, right, lower, side):
            # Area of the box after shrinking it (if needed) so that its
            # longer edge is at most `side`.
            w = right - left
            h = lower - upper
            w, h = max(w, h), min(w, h)
            if w > side:
                h = h / w * side
                w = side
            return w * h

        def _get_best_grid(img, side):
            img_area = img.size[0] * img.size[1]

            candidate_grids = [
                (i, j)
                for i in range(1, max_partition + 1)
                for j in range(1, max_partition + 1)
                if i * j <= max_partition
            ]

            all_grids = []
            good_grids = []
            for grid in candidate_grids:
                partition = _partition(img, grid)
                covering_ratio = sum(_covering_area(*p, side) for p in partition) / img_area
                assert covering_ratio <= 1.0
                all_grids.append((grid, covering_ratio))
                if covering_ratio > covering_threshold:
                    good_grids.append((grid, covering_ratio))

            if len(good_grids) > 0:
                # pick the good partition with minimum #sub_images and break the tie using covering_ratio
                return min(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0]
            else:
                # pick the partition with maximum covering_ratio and break the tie using #sub_images
                return min(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0]

        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]
        grid = _get_best_grid(image, side)
        partition = _partition(image, grid)
        crops = [image.crop(p) for p in partition]
        if len(crops) > 1:
            # the full image precedes the crops when the image is partitioned
            crops.insert(0, image)
        pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
        image_placeholders = self.construct_image_placeholders(grid)
        return pixel_values, image_placeholders

    def get_backbone_layer(self, index):
        """Return encoder layer `index`; requires `backbone.vision_model.encoder.layers`."""
        return self.backbone.vision_model.encoder.layers[index]

    def tokenize(self, logits):
        """Turn logits into tokens according to `config.tokenize_function`.

        Raises:
            ValueError: If `config.tokenize_function` is not one of
                ``softmax``, ``gumbel_argmax`` or ``st_argmax``.
        """

        def st_argmax(y_soft, dim):  # straight-through softmax
            # Forward value is the one-hot argmax; gradients flow through `y_soft`.
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax or gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
        return tokens

    def encode(self, pixel_values):
        """Run the backbone and return its last hidden states, optionally merged.

        When `config.drop_cls_token` is set the first token is dropped. When
        `config.hidden_stride > 1` the tokens are laid out on a square grid,
        zero-padded at the bottom/right to a multiple of the stride, and every
        `hidden_stride x hidden_stride` window is concatenated into one token.

        Raises:
            AssertionError: If merging is requested and the token count is not
                a perfect square.
        """
        output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
        features = output.hidden_states[-1]
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # merge number of `hidden_stride * hidden_stride` hidden states together to reduce token sequence length
        # e.g., for hidden_stride=3, this leads to a token length reduction: 729 -> 81 for siglip
        hs = self.config.hidden_stride
        if hs > 1:
            n, l, d = features.shape  # `d` is the hidden size of the backbone output
            sqrt_l = int(l ** 0.5)
            assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            # pad the two spatial dims (not `d`) up to a multiple of the stride
            pl = (hs - (sqrt_l % hs)) % hs
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            features = features.reshape(n, sqrt_l // hs, hs, sqrt_l // hs, hs, d)
            features = features.permute(0, 1, 3, 2, 4, 5)  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.reshape(n, -1, hs * hs * d)

        return features

    def forward(self, pixel_values) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        """Encode images, project to the visual vocabulary and tokenize.

        The head outputs `vocab_size - len(IMAGE_INDICATOR_IDS)` entries per
        token; zeros are appended for the reserved indicator slots so the last
        dimension equals `vocab_size`.
        """
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with [BatchSize, #Token, 5], after
        # which, tokens' shape should become [BatchSize, #Token, VocabSize]
        batch_size, token_len, _ = tokens.shape
        padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
                                     dtype=tokens.dtype,
                                     device=tokens.device,
                                     layout=tokens.layout,
                                     requires_grad=False)
        tokens = torch.cat((tokens, padding_tensor), dim=2)
        return tokens