Tom Aarsen
Use self.max_seq_length to inform the maximum tokenize length
c0c6d64
raw
history blame
3.12 kB
from typing import Any, Dict, Optional
import PIL
import torch
import PIL
import torch
from typing import Dict
from io import BytesIO
from transformers import SiglipImageProcessor
from sentence_transformers.models import Transformer as BaseTransformer
class MultiModalTransformer(BaseTransformer):
    """Transformer module that accepts text interleaved with images.

    On top of the base ``Transformer`` module this class loads a
    ``SiglipImageProcessor``. ``tokenize`` uses it to turn image inputs into
    ``pixel_values``; ``forward`` passes those to ``self.auto_model`` together
    with the token ids and stores the returned ``sentence_embedding``.
    """

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Load the base transformer and the image processor.

        Args:
            model_name_or_path: Model identifier or local path, used for both
                the base module and the image processor.
            cache_dir: Optional cache directory, used for both the base
                module and the image processor.
            tokenizer_args: Optional extra keyword arguments, given to both
                the base module (as ``tokenizer_args``) and
                ``SiglipImageProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to the base class ``__init__``.
        """
        if tokenizer_args is None:
            tokenizer_args = {}
        # ``cache_dir`` and ``tokenizer_args`` are named parameters here, so
        # they never land in ``**kwargs``; forward them explicitly, otherwise
        # the base module silently ignores what the caller asked for.
        # A copy is passed so the base class cannot alter the dict that is
        # reused for the image processor below.
        # NOTE(review): assumes the base ``__init__`` accepts ``cache_dir``
        # and ``tokenizer_args`` keywords -- confirm for the pinned version.
        super().__init__(
            model_name_or_path,
            cache_dir=cache_dir,
            tokenizer_args=dict(tokenizer_args),
            **kwargs,
        )
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Run the model and attach the sentence embedding to ``features``.

        Args:
            features: Mapping with ``input_ids`` and ``attention_mask`` and,
                optionally, ``pixel_values``.
            **kwargs: Extra keyword arguments passed to ``self.auto_model``.

        Returns:
            The same ``features`` mapping, updated in place with a
            ``sentence_embedding`` entry.
        """
        model_inputs = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "pixel_values" in features:
            # Cast the image tensor to the dtype reported by the model.
            model_inputs["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )
        outputs = self.auto_model(**model_inputs, **kwargs)
        features["sentence_embedding"] = outputs["sentence_embedding"]
        return features

    @staticmethod
    def _load_rgb_image(source):
        """Open ``source`` (a path or binary file object) as an RGB image."""
        # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
        # import it explicitly instead of relying on another package having
        # done so as a side effect.
        from PIL import Image

        # ``convert`` returns a new, fully loaded image, so the opened image
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(source) as img:
            return img.convert("RGB")

    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
        """Tokenize a batch of plain strings and/or text-and-image inputs.

        Each element of ``texts`` is either a string, used as-is, or an
        iterable of dicts with a ``"type"`` and a ``"content"`` key:

        * ``"text"``: ``content`` is appended to the text.
        * ``"image_bytes"``: ``content`` holds encoded image bytes.
        * ``"image_path"``: ``content`` is a path to an image file.

        Every image contributes a fixed placeholder (start token, 300 image
        tokens, end token) to the text, and the decoded RGB image is queued
        for the image processor.

        Args:
            texts: The batch to tokenize.

        Returns:
            The tokenizer output, with an added ``pixel_values`` entry when
            the batch contains at least one image.

        Raises:
            ValueError: If a dict has an unsupported ``"type"``.
        """
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300
        # Identical for every image, so build it once.
        image_placeholder = (
            img_start_token + img_token * num_img_tokens + img_end_token
        )

        all_texts, all_images = [], []
        for item in texts:
            if isinstance(item, str):
                all_texts.append(item)
                continue

            parts = []
            for sub_item in item:
                kind = sub_item["type"]
                if kind == "text":
                    parts.append(sub_item["content"])
                elif kind == "image_bytes":
                    parts.append(image_placeholder)
                    all_images.append(
                        self._load_rgb_image(BytesIO(sub_item["content"]))
                    )
                elif kind == "image_path":
                    parts.append(image_placeholder)
                    all_images.append(self._load_rgb_image(sub_item["content"]))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            all_texts.append("".join(parts))

        # NOTE(review): truncation to ``self.max_seq_length`` can cut into an
        # image placeholder while its ``pixel_values`` are still attached --
        # confirm the model handles that case.
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt