|
from typing import Union, Optional |
|
|
|
import PIL.Image |
|
import torch |
|
from torch.nn.functional import softmax, gumbel_softmax, pad |
|
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoModel, AutoConfig |
|
from ovis.util.constants import IMAGE_INDICATOR_IDS, IMAGE_ATOM_ID |
|
|
|
|
|
class BaseVisualTokenizerConfig(PretrainedConfig): |
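    """Shared configuration for visual tokenizers.

    `vocab_size` is the size of the visual vocabulary (including the indicator
    tokens); `tokenize_function` selects how logits become soft tokens
    ('softmax', 'gumbel_argmax', or 'st_argmax'); `tau` is the Gumbel-softmax
    temperature; `depths` may be an int list or a '|'-separated string;
    `hidden_stride` controls how many neighboring backbone features are merged
    into one visual token position.
    """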
|
def __init__( |
|
self, |
|
vocab_size=16384, |
|
tokenize_function="softmax", |
|
tau=1.0, |
|
depths=None, |
|
drop_cls_token=False, |
|
backbone_config: Optional[Union[PretrainedConfig, dict]] = None, |
|
hidden_stride: int = 1, |
|
**kwargs |
|
): |
|
super().__init__(**kwargs) |
|
self.vocab_size = vocab_size |
|
self.tokenize_function = tokenize_function |
|
self.tau = tau |
|
if isinstance(depths, str): |
|
depths = [int(x) for x in depths.split('|')] |
|
self.depths = depths |
|
self.backbone_kwargs = {} |
|
self.drop_cls_token = drop_cls_token |
|
if backbone_config is not None: |
|
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expected `backbone_config` to be an instance of PretrainedConfig or dict, but got {type(backbone_config)}"
|
if not isinstance(backbone_config, PretrainedConfig): |
|
                model_type = backbone_config.pop('model_type')
|
backbone_config = AutoConfig.for_model(model_type, **backbone_config) |
|
self.backbone_config = backbone_config |
|
self.hidden_stride = hidden_stride |
|
|
|
|
|
class BaseVisualTokenizer(PreTrainedModel): |
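    """Maps images to sequences of probabilistic ("soft") visual tokens.

    Concrete subclasses are expected to set `_backbone_class`,
    `_image_processor_class`, and `_backbone_name_or_path`, and to implement
    `get_image_size` and `get_monitor_tensors`.
    """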
|
base_model_prefix = "backbone" |
|
main_input_name = None |
|
_image_processor_class = None |
|
_image_processor_kwargs = {} |
|
_backbone_class = None |
|
_backbone_name_or_path = None |
|
|
|
def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs): |
|
super().__init__(config, *inputs, **kwargs) |
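        # Two initialization paths: when training from scratch, load both the
        # image processor and the backbone weights from `_backbone_name_or_path`;
        # otherwise, build the backbone from the stored config and load the
        # image processor from `image_processor_name_or_path`.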
|
if kwargs.get('train_from_scratch'): |
|
self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path, |
|
**self._image_processor_kwargs) |
|
self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path, |
|
**self.config.backbone_kwargs) |
|
self.config.backbone_config = self.backbone.config |
|
else: |
|
self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path']) |
|
self.backbone = AutoModel.from_config(self.config.backbone_config) |
|
head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS) |
|
self.head = torch.nn.Sequential( |
|
torch.nn.Linear( |
|
self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim, |
|
bias=False |
|
), |
|
torch.nn.LayerNorm(head_dim) |
|
) |
|
|
|
assert all((self.image_processor.do_resize, |
|
not getattr(self.image_processor, 'do_center_crop', False), |
|
self.image_processor.do_rescale, |
|
self.image_processor.do_normalize |
|
)), f"image_processor `{self.image_processor}` is not supported currently" |
|
|
|
def get_backbone(self): |
|
return self.backbone |
|
|
|
def get_monitor_tensors(self): |
|
raise NotImplementedError |
|
|
|
def get_image_processor(self): |
|
return self.image_processor |
|
|
|
def mock_input(self): |
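        # A dummy input: an all-zero pixel tensor of the expected size plus the
        # placeholder ids for a single-crop (1x1 grid) image.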
|
height, width = self.get_image_size() |
|
return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1)) |
|
|
|
def get_head(self): |
|
return self.head |
|
|
|
def get_image_size(self): |
|
raise NotImplementedError |
|
|
|
@staticmethod |
|
def construct_image_placeholders(grid): |
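        # Layout: [IND_0, ATOM, IND_1] for the global view; for multi-crop grids,
        # one ATOM per crop follows, with IND_2 separating columns and IND_3
        # separating rows, closed by IND_4. E.g. for grid (2, 2):
        # [IND_0, ATOM, IND_1, ATOM, IND_2, ATOM, IND_3, ATOM, IND_2, ATOM, IND_4]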
|
image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]] |
|
if grid[0] * grid[1] > 1: |
|
for r in range(grid[0]): |
|
for c in range(grid[1]): |
|
image_placeholders.append(IMAGE_ATOM_ID) |
|
if c < grid[1] - 1: |
|
image_placeholders.append(IMAGE_INDICATOR_IDS[2]) |
|
if r < grid[0] - 1: |
|
image_placeholders.append(IMAGE_INDICATOR_IDS[3]) |
|
image_placeholders.append(IMAGE_INDICATOR_IDS[4]) |
|
return image_placeholders |
|
|
|
def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True): |
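        # Multi-crop preprocessing: choose the partition grid (at most
        # `max_partition` crops) that best covers the image, crop accordingly,
        # prepend the full image as a global view when there is more than one
        # crop, and resize every crop to a zero-padded square of the backbone's
        # input size.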
|
def _preprocess(img: PIL.Image.Image, side): |
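            # Resize so the longer edge equals `side` (preserving aspect ratio),
            # then center the result on an all-zero square tensor of shape
            # [1, 3, side, side].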
|
|
|
w, h = img.size |
|
if w == h: |
|
new_width = new_height = side |
|
elif w > h: |
|
new_width = side |
|
new_height = int(h / w * new_width) |
|
else: |
|
new_height = side |
|
new_width = int(w / h * new_height) |
|
new_size = dict(height=new_height, width=new_width) |
|
pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values'] |
|
|
|
|
|
square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device) |
|
new_height, new_width = pixel_values.shape[2:] |
|
if new_height == new_width: |
|
square_values[:, :, :, :] = pixel_values |
|
elif new_height > new_width: |
|
from_index = (side - new_width) // 2 |
|
square_values[:, :, :, from_index:from_index + new_width] = pixel_values |
|
else: |
|
from_index = (side - new_height) // 2 |
|
square_values[:, :, from_index:from_index + new_height, :] = pixel_values |
|
|
|
return square_values |
|
|
|
def _partition(img, grid): |
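            # Split the image into grid[0] x grid[1] crop boxes; the last row and
            # column absorb any remainder so the boxes tile the image exactly.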
|
w, h = img.size |
|
row_height = h // grid[0] |
|
col_width = w // grid[1] |
|
|
|
partition = [] |
|
for row in range(grid[0]): |
|
for col in range(grid[1]): |
|
left = col * col_width |
|
upper = row * row_height |
|
right = w if col == grid[1] - 1 else (col + 1) * col_width |
|
lower = h if row == grid[0] - 1 else (row + 1) * row_height |
|
partition.append((left, upper, right, lower)) |
|
|
|
return partition |
|
|
|
def _covering_area(left, upper, right, lower, side): |
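            # Area of a crop after scaling its longer edge down to at most `side`;
            # crops already smaller than `side` keep their native area.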
|
w = right - left |
|
h = lower - upper |
|
w, h = max(w, h), min(w, h) |
|
if w > side: |
|
h = h / w * side |
|
w = side |
|
return w * h |
|
|
|
def _get_best_grid(img, side): |
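            # Among grids whose crops cover more than `covering_threshold` of the
            # image area, prefer the fewest crops (ties broken by higher coverage);
            # if no grid qualifies, fall back to the highest-coverage grid.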
|
img_area = img.size[0] * img.size[1] |
|
|
|
candidate_grids = [] |
|
for i in range(1, max_partition + 1): |
|
for j in range(1, max_partition + 1): |
|
if i * j <= max_partition: |
|
candidate_grids.append((i, j)) |
|
|
|
all_grids = [] |
|
good_grids = [] |
|
for grid in candidate_grids: |
|
partition = _partition(img, grid) |
|
covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area |
|
assert covering_ratio <= 1.0 |
|
all_grids.append((grid, covering_ratio)) |
|
if covering_ratio > covering_threshold: |
|
good_grids.append((grid, covering_ratio)) |
|
|
|
if len(good_grids) > 0: |
|
|
|
return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0] |
|
else: |
|
|
|
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] |
|
|
|
if convert_to_rgb and image.mode != 'RGB': |
|
image = image.convert('RGB') |
|
|
|
sides = self.get_image_size() |
|
if sides[0] != sides[1]: |
|
raise ValueError('get_image_size() returns non-square size') |
|
side = sides[0] |
|
grid = _get_best_grid(image, side) |
|
partition = _partition(image, grid) |
|
crops = [image.crop(p) for p in partition] |
|
if len(crops) > 1: |
|
crops.insert(0, image) |
|
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0) |
|
image_placeholders = self.construct_image_placeholders(grid) |
|
return pixel_values, image_placeholders |
|
|
|
def get_backbone_layer(self, index): |
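        # Assumes the backbone exposes its transformer blocks as
        # `vision_model.encoder.layers` (as CLIP-style vision models do).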
|
return self.backbone.vision_model.encoder.layers[index] |
|
|
|
def tokenize(self, logits): |
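        # Three tokenization schemes: plain softmax (fully soft tokens),
        # Gumbel-softmax with hard sampling, or straight-through argmax
        # (one-hot values in the forward pass, with gradients passed straight
        # through to the logits).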
|
def st_argmax(y_soft, dim): |
|
index = y_soft.max(dim, keepdim=True)[1] |
|
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) |
|
ret = y_hard - y_soft.detach() + y_soft |
|
return ret |
|
|
|
if self.config.tokenize_function == 'softmax': |
|
tokens = softmax(logits, dim=-1) |
|
elif self.config.tokenize_function == 'gumbel_argmax': |
|
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) |
|
elif self.config.tokenize_function == 'st_argmax': |
|
tokens = st_argmax(logits, dim=-1) |
|
else: |
|
            raise ValueError(
                f'Invalid `tokenize_function`, expected one of softmax, gumbel_argmax, or st_argmax, '
                f'but got {self.config.tokenize_function}')
|
return tokens |
|
|
|
def encode(self, pixel_values): |
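        # Run the backbone and use its last hidden state as the visual features,
        # optionally dropping the CLS token and merging spatially adjacent
        # features (see the `hidden_stride` block below).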
|
output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True) |
|
features = output.hidden_states[-1] |
|
if self.config.drop_cls_token: |
|
features = features[:, 1:, :] |
|
|
|
|
|
|
|
        # Merge each hidden_stride x hidden_stride neighborhood of the 2D feature
        # map into a single position by concatenating along the channel dim,
        # shortening the sequence by a factor of hidden_stride**2; the map is
        # zero-padded first so its side length is divisible by hidden_stride.
        if self.config.hidden_stride > 1:
|
n, l, d = features.shape |
|
sqrt_l = int(l ** 0.5) |
|
assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square." |
|
features = features.reshape(n, sqrt_l, sqrt_l, d) |
|
pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride |
|
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) |
|
sqrt_l += pl |
|
features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride, |
|
sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d) |
|
features = features.permute(0, 1, 3, 2, 4, 5) |
|
features = features.flatten(3) |
|
features = features.reshape( |
|
n, -1, self.config.hidden_stride * self.config.hidden_stride * d) |
|
|
|
return features |
|
|
|
def forward(self, pixel_values) -> torch.Tensor: |
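        # Encode pixels into features, project to logits over the visual
        # vocabulary minus the indicator slots, tokenize, then append zero
        # probability for the len(IMAGE_INDICATOR_IDS) indicator positions so
        # the output width matches `vocab_size`.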
|
features = self.encode(pixel_values) |
|
logits = self.head(features) |
|
tokens = self.tokenize(logits) |
|
|
|
|
|
batch_size, token_len, _ = tokens.shape |
|
padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)), |
|
dtype=tokens.dtype, |
|
device=tokens.device, |
|
layout=tokens.layout, |
|
requires_grad=False) |
|
tokens = torch.cat((tokens, padding_tensor), dim=2) |
|
return tokens |
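

# Minimal usage sketch (a hypothetical `MyVisualTokenizer` subclass that sets
# `_backbone_class`, `_image_processor_class`, and `_backbone_name_or_path` and
# implements `get_image_size` is assumed; none is defined in this module):
#
#   config = BaseVisualTokenizerConfig(vocab_size=16384, tokenize_function='softmax')
#   tokenizer = MyVisualTokenizer(config, train_from_scratch=True)
#   image = PIL.Image.open('example.jpg')
#   pixel_values, placeholders = tokenizer.preprocess_image(image, max_partition=9)
#   soft_tokens = tokenizer(pixel_values)  # [num_crops, seq_len, vocab_size]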
|
|