|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Processor class for Img2HTML. |
|
""" |
|
|
|
from typing import Callable, List, Optional, Union |
|
from urllib.parse import urlparse |
|
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy |
|
from transformers.utils import TensorType, is_torch_available |
|
|
|
from .image_processing_img2html import Img2HTMLImageProcessor |
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
|
|
IMAGE_TOKEN = "<image>" |
|
|
|
|
|
def is_url(string): |
|
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately |
|
invalidated the url""" |
|
if " " in string: |
|
return False |
|
result = urlparse(string) |
|
return all([result.scheme, result.netloc]) |
|
|
|
class Img2HTMLProcessor(ProcessorMixin): |
|
r""" |
|
    Constructs an Img2HTML processor which wraps a LLaMA tokenizer and an Img2HTML image processor into a single
    processor.
|
|
|
[`Img2HTMLProcessor`] offers all the functionalities of [`Img2HTMLImageProcessor`] and [`LlamaTokenizerFast`]. See |
|
the docstring of [`~Img2HTMLProcessor.__call__`] and [`~Img2HTMLProcessor.decode`] for more information. |
|
|
|
Args: |
|
image_processor (`Img2HTMLImageProcessor`): |
|
An instance of [`Img2HTMLImageProcessor`]. The image processor is a required input. |
|
tokenizer (`LlamaTokenizerFast`): |
|
An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. |
|
        image_size (`int`, *optional*, defaults to 960): Image size (assuming a square image).
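
    Example:

    A minimal usage sketch, assuming the `HuggingFaceM4/Img2HTML-9b` checkpoint used in the
    [`~Img2HTMLProcessor.__call__`] example ships both a tokenizer and an image processor:

    ```python
    from transformers import AutoProcessor, LlamaTokenizerFast

    checkpoint = "HuggingFaceM4/Img2HTML-9b"

    # Load the bundled processor directly ...
    processor = AutoProcessor.from_pretrained(checkpoint)

    # ... or assemble it from its two parts (assumes both support standard `from_pretrained` loading).
    image_processor = Img2HTMLImageProcessor.from_pretrained(checkpoint)
    tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint)
    processor = Img2HTMLProcessor(image_processor, tokenizer)
    ```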
|
""" |
|
|
|
attributes = ["image_processor", "tokenizer"] |
|
image_processor_class = "Img2HTMLImageProcessor" |
|
tokenizer_class = "LlamaTokenizerFast" |
|
|
|
def __init__(self, image_processor, tokenizer=None, image_size=960, **kwargs): |
|
if image_processor is None: |
|
raise ValueError("You need to specify an `image_processor`.") |
|
if tokenizer is None: |
|
raise ValueError("You need to specify a `tokenizer`.") |
|
|
|
super().__init__(image_processor, tokenizer) |
|
self.current_processor = self.image_processor |
|
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) |
|
|
|
self.default_image_dims = ( |
|
self.image_processor.image_num_channels, |
|
self.image_processor.image_size, |
|
self.image_processor.image_size, |
|
) |
|
|
|
def __call__( |
|
self, |
|
prompts: Union[List[TextInput], List[List[TextInput]]], |
|
padding: Union[bool, str, PaddingStrategy] = False, |
|
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
|
max_length: Optional[int] = None, |
|
        transform: Optional[Callable] = None,
|
add_eos_token=False, |
|
debug=False, |
|
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, |
|
) -> BatchEncoding: |
|
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that |
|
the model was trained on and prepares the image pixel values for the model to process. |
|
|
|
Args: |
|
            prompts (`Union[List[TextInput], List[List[TextInput]]]`):
|
either a single prompt or a batched list of prompts - see the detailed description immediately after |
|
the end of the arguments doc section. |
|
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): |
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding |
|
index) among: |
|
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single |
|
                  sequence is provided).
|
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum |
|
acceptable input length for the model if that argument is not provided. |
|
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different |
|
lengths). |
|
max_length (`int`, *optional*): |
|
Maximum length of the returned list and optionally padding length (see above). |
|
truncation (`bool`, *optional*): |
|
Activates truncation to cut input sequences longer than `max_length` to `max_length`. |
|
transform (`Callable`, *optional*): |
|
A custom transform function that accepts a single image can be passed for training. For example, |
|
`torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific |
|
set of transforms will be applied to the images |
|
add_eos_token (`bool`, *optional*, defaults to `False`): |
|
                Adds `eos_token` at the end of the final prompt if `True`.
|
debug (`bool`, *optional*, defaults to `False`): |
|
                If `True`, dumps useful information to help debug prompt generation.
|
return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`): |
|
The type of tensors to return. Can be one of: |
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. |
|
|
|
Returns: |
|
            a dict with entries `input_ids`, `attention_mask` and `pixel_values` which can be directly passed to
            `model.generate`
|
|
|
Detailed explanation: |
|
|
|
Each entry in `prompts` is either a text to be passed as is or an image that will be processed. |
|
|
|
An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved. |
|
|
|
        When the processor encounters an image it'll inject a `<fake_token_around_image><image><fake_token_around_image>`
|
entry into the prompt. |
|
|
|
Example: |
|
|
|
```python |
|
checkpoint = "HuggingFaceM4/Img2HTML-9b" |
|
processor = AutoProcessor.from_pretrained(checkpoint) |
|
url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg" |
|
img = processor.image_processor.fetch_images([url])[0] |
|
|
|
prompts = [ |
|
"User:", |
|
img, |
|
"Describe this image.\nAssistant: An image of two kittens in grass.\n", |
|
"User:", |
|
"https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg", |
|
"Describe this image.\nAssistant:", |
|
] |
|
|
|
inputs = processor(prompts, return_tensors="pt") |
|
generated_ids = model.generate(**inputs, max_length=100) |
|
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
``` |
|
|
|
In this example the `prompts` will be converted into: |
|
|
|
``` |
|
<s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image. |
|
Assistant: An image of two kittens in grass. |
|
User:<fake_token_around_image><image><fake_token_around_image>Describe this image. |
|
        Assistant:
|
``` |
|
|
|
        and the two images will be massaged using the [`Img2HTMLImageProcessor.__call__`] method and placed inside the
|
`pixel_values` dict entry of the return value. |
|
|
|
        This example also exemplifies that images can be passed as objects or as text URLs: the first image is passed
        as an object and the second one as a URL.
|
|
|
        To prepare inputs for training, pass a custom image transform:
|
|
|
```python |
|
        from torchvision import transforms

        w = h = processor.image_processor.image_size  # target crop size
        image_transform = transforms.Compose(
|
[ |
|
transforms.RandomResizedCrop( |
|
(w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC |
|
), |
|
transforms.ToTensor(), |
|
                transforms.Normalize(mean=processor.image_processor.image_mean, std=processor.image_processor.image_std),
|
] |
|
) |
|
inputs = processor(prompts, transform=image_transform, return_tensors="pt") |
|
``` |
|
|
|
        To help debug prompt generation, enable `debug=True`, which will show you what's happening.
|
|
|
""" |
|
|
|
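        # A single (non-batched) prompt is wrapped into a batch of one so that both cases share the same code path.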
if not any(isinstance(i, list) for i in prompts): |
|
prompts = [prompts] |
|
|
|
fake_token = "<fake_token_around_image>" |
|
image_token = "<image>" |
|
|
|
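        # Consecutive images share a single `<fake_token_around_image>` separator instead of emitting it twice.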
def image_tokens(last_was_image): |
|
if last_was_image: |
|
return image_token + fake_token |
|
else: |
|
return fake_token + image_token + fake_token |
|
|
|
all_prompts = [] |
|
all_images = [] |
|
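        # Build the full text prompt (starting with BOS) and collect the image objects for every sample.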
for sample in prompts: |
|
|
|
full_text = f"{self.tokenizer.bos_token}" |
|
|
|
|
|
image_objects = [] |
|
last_was_image = False |
|
last_was_text = False |
|
for i, item in enumerate(sample): |
|
if i > 0: |
|
                    last_was_text = not last_was_image
|
|
|
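                # Strings are either URLs (fetched and treated as images) or plain text; any other item is assumed
                # to be an image object (e.g. `PIL.Image`).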
if isinstance(item, str): |
|
item = item.strip(" ") |
|
if is_url(item): |
|
image = self.image_processor.fetch_images(item) |
|
full_text += image_tokens(last_was_image) |
|
image_objects.append(image) |
|
last_was_image = True |
|
else: |
|
full_text += item |
|
last_was_image = False |
|
else: |
|
|
|
full_text += image_tokens(last_was_image) |
|
image_objects.append(item) |
|
last_was_image = True |
|
|
|
if add_eos_token: |
|
full_text += self.tokenizer.eos_token |
|
|
|
if debug is True: |
|
print(f"{full_text=}") |
|
|
|
image_objects = self.image_processor(image_objects, transform=transform) |
|
|
|
all_prompts.append(full_text) |
|
all_images.append(image_objects) |
|
|
|
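        # Tokenize all prompts at once; special tokens are disabled because BOS/EOS were added manually above.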
text_encoding = self.tokenizer( |
|
text=all_prompts, |
|
add_special_tokens=False, |
|
padding=padding, |
|
truncation=truncation, |
|
max_length=max_length, |
|
) |
|
all_texts = text_encoding["input_ids"] |
|
|
|
max_seq_len = max(len(x) for x in all_texts) |
|
|
|
|
|
max_num_images = max(len(x) for x in all_images) |
|
max_num_images = max(1, max_num_images) |
|
|
|
output_input_ids = [] |
|
output_images = [] |
|
output_attention_masks = [] |
|
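        # Left-pad every sequence of input ids to the longest one in the batch, build the matching attention mask
        # and pad the image tensors up to `max_num_images` per sample.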
for text, images in zip(all_texts, all_images): |
|
padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len |
|
unpadded_seq_len = len(text) |
|
start = max_seq_len - unpadded_seq_len |
|
padded_input_ids[start:] = text[:max_seq_len] |
|
|
|
attention_mask = torch.zeros((max_seq_len,), dtype=torch.long) |
|
attention_mask[start:] = 1 |
|
|
|
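            # The number of `<image>` tokens that survived truncation caps how many images are kept for this sample.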
image_count = padded_input_ids.count(self.image_token_id) |
|
local_max_num_images = min(image_count, max_num_images) |
|
|
|
current_images = images[:local_max_num_images] |
|
|
|
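            # Pad along the image dimension; samples without any image get an all-zero placeholder tensor.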
if len(current_images) > 0: |
|
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) |
|
padded_image_tensor[: current_images.size(0)] = current_images |
|
else: |
|
padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) |
|
|
|
output_images.append(padded_image_tensor) |
|
output_input_ids.append(torch.tensor(padded_input_ids)) |
|
|
|
output_attention_masks.append(attention_mask) |
|
|
|
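        # Stack the per-sample tensors into batch tensors.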
output_input_ids = torch.stack(output_input_ids) |
|
output_images = torch.stack(output_images) |
|
output_attention_masks = torch.stack(output_attention_masks) |
|
|
|
|
|
return BatchFeature( |
|
data={ |
|
"input_ids": output_input_ids, |
|
"attention_mask": output_attention_masks, |
|
"pixel_values": output_images, |
|
} |
|
) |
|
|
|
def batch_decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please |
|
refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
|
def decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to |
|
the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.decode(*args, **kwargs) |
|
|
|
@property |
|
def model_input_names(self): |
|
tokenizer_input_names = self.tokenizer.model_input_names |
|
image_processor_input_names = self.image_processor.model_input_names |
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) |
|
|