File size: 3,118 Bytes
9862f98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0c6d64
9862f98
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Any, Dict, Optional
import PIL
import torch
import PIL
import torch
from typing import Dict
from io import BytesIO
from transformers import SiglipImageProcessor
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    """Sentence-transformers ``Transformer`` module that also accepts images.

    Plain strings are tokenized as usual. Multimodal items are sequences of
    ``{"type": ..., "content": ...}`` dicts; each image is replaced in the text
    by a fixed run of placeholder tokens, and its pixels are produced by a
    ``SiglipImageProcessor`` loaded from the same checkpoint.
    """

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Load the backbone (via the base class) and the SigLIP image processor.

        Args:
            model_name_or_path: Hub id or local path of the checkpoint.
            cache_dir: Download cache directory. It is forwarded to both the base
                ``Transformer`` and the image processor.
            tokenizer_args: Extra keyword arguments for
                ``SiglipImageProcessor.from_pretrained``. NOTE: despite the
                name, these are NOT passed to the base class's tokenizer.
            **kwargs: Passed through to the base ``Transformer`` constructor.
        """
        # Forward cache_dir explicitly. Because it is a named parameter here,
        # it never reaches **kwargs, and the backbone would otherwise ignore it.
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Run the backbone and store its output under ``"sentence_embedding"``.

        Args:
            features: Output of :meth:`tokenize`. Must contain ``input_ids``
                and ``attention_mask``; ``pixel_values`` is optional.
            **kwargs: Forwarded unchanged to ``self.auto_model``.

        Returns:
            ``features``, mutated in place to include ``"sentence_embedding"``.
        """
        model_inputs = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "pixel_values" in features:
            # Cast pixels to the model's parameter dtype so they match the
            # weights (which may be in reduced precision).
            model_inputs["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )

        # Assumes the (custom) backbone returns a mapping that contains a
        # "sentence_embedding" entry -- TODO confirm against the model code.
        outputs = self.auto_model(**model_inputs, **kwargs)
        features["sentence_embedding"] = outputs["sentence_embedding"]
        return features

    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
        """Convert a batch of text or multimodal items into model inputs.

        Each element of ``texts`` is either a plain string or an iterable of
        dicts with keys ``"type"`` and ``"content"``:

        * ``"text"``: ``content`` is a string, appended verbatim.
        * ``"image_bytes"``: ``content`` is encoded image bytes.
        * ``"image_path"``: ``content`` is a path to an image file.

        Every image is replaced in the text by a start token, 300 image
        placeholder tokens and an end token. The decoded RGB images of the
        whole batch are flattened, in order, into a single ``pixel_values``
        tensor.

        Returns:
            The tokenizer output (``input_ids``, ``attention_mask``, ...), plus
            ``pixel_values`` when the batch contains at least one image.

        Raises:
            ValueError: If an item has an unrecognised ``"type"``.
        """
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` is loaded.
        from PIL import Image

        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300
        # Build the per-image placeholder once rather than once per image.
        image_placeholder = (
            img_start_token + img_token * num_img_tokens + img_end_token
        )

        def load_rgb(source):
            # Decode to RGB, then close the source. convert() loads the pixel
            # data before the context manager releases the file handle.
            with Image.open(source) as img:
                return img.convert("RGB")

        def process_text_item(item):
            # Return (text, images) for one batch element.
            if isinstance(item, str):
                return item, []
            parts, images = [], []
            for sub_item in item:
                kind = sub_item["type"]
                if kind == "text":
                    parts.append(sub_item["content"])
                elif kind == "image_bytes":
                    parts.append(image_placeholder)
                    images.append(load_rgb(BytesIO(sub_item["content"])))
                elif kind == "image_path":
                    parts.append(image_placeholder)
                    images.append(load_rgb(sub_item["content"]))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return "".join(parts), images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # NOTE(review): truncation at max_seq_length could cut an image's
        # placeholder run, leaving text and pixel_values out of sync --
        # confirm max_seq_length leaves room for all images in an item.
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt