# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import transformers
from packaging import version
from transformers import AutoImageProcessor
from transformers.image_utils import ImageInput, is_valid_image, load_image

from .ar_tokenizer_text_tokenizer import TextTokenizer
from .log import log

# Configuration for different vision-language models
IMAGE_CONFIGS = {
"pixtral": {
"patch_size": 16,
"image_token": "[IMG]",
"image_break_token": "[IMG_BREAK]",
"image_end_token": "[IMG_END]",
}
}
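
# How these settings are used (derived from the expansion logic in ImageTextTokenizer.encode below):
# every processed image of size (height, width) becomes a grid of
# (height // patch_size) x (width // patch_size) "[IMG]" placeholder tokens, with "[IMG_BREAK]"
# closing each row and "[IMG_END]" replacing the final token of the grid.
# Illustrative example (hypothetical sizes): a 32x48 (height x width) image with patch_size=16
# expands to a 2x3 grid, i.e. "[IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]".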
# Chat template for Pixtral-12B-Instruct
PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}'
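# The template above folds an optional leading system message into the final user [INST] block,
# enforces strict user/assistant alternation, emits one "[IMG]" placeholder per image content
# chunk (expanded later by ImageTextTokenizer.encode), and appends the EOS token after each
# assistant turn.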
# Copied from transformers.models.pixtral.processing_pixtral.is_url
def is_url(val) -> bool:
"""Check if the given value is a URL."""
return isinstance(val, str) and val.startswith("http")
# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url
def is_image_or_image_url(elem):
"""Check if the given element is an image or an image URL."""
return is_url(elem) or is_valid_image(elem)
def load_image_list(
image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None
) -> List["PIL.Image.Image"]:
"""
Load a list of images.
Args:
image_list (List[Union[str, PIL.Image.Image]]): The list of images to load.
timeout (Optional[float]): The timeout for loading the image.
Returns:
List[PIL.Image.Image]: The list of loaded images.
"""
return [load_image(image, timeout=timeout) for image in image_list]
class ImageTextTokenizer(TextTokenizer):
"""
Image-text tokenizer class that extends the text tokenizer to support vision tokens as well.
"""
def __init__(
self,
model_family: str,
is_instruct_model: bool,
tokenizer_path: str,
image_processor_path: str,
):
"""
Initialize the ImageTextTokenizer.
Args:
model_family (str): The model family.
is_instruct_model (bool): Whether the model is an instruct model.
s3_credential_path (str): The path to the s3 credential file. Defaults to "credentials/pbss_dir.secret".
Raises:
AssertionError: If the model family is not supported or if the transformers version is incompatible.
"""
super().__init__(
model_family=model_family,
is_instruct_model=is_instruct_model,
local_path=tokenizer_path,
)
assert model_family in ["pixtral"], f"Unsupported model family: {model_family}"
if model_family == "pixtral":
            # Need transformers>=4.45.0; compare parsed versions rather than raw strings.
            assert version.parse(transformers.__version__) >= version.parse(
                "4.45.0"
            ), "Pixtral requires transformers>=4.45.0"
assert is_instruct_model, "Pixtral requires is_instruct_model=True"
            if getattr(self.tokenizer, "chat_template", None) is None:
                self.tokenizer.chat_template = PIXTRAL_CHAT_TEMPLATE
                log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}")
# Set up image-specific configurations
image_config = IMAGE_CONFIGS[model_family]
self.patch_size = image_config["patch_size"]
self.image_token = image_config["image_token"]
self.image_break_token = image_config["image_break_token"]
self.image_end_token = image_config["image_end_token"]
# Initialize the image processor
self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path)
def encode(
self,
text: Union[str, List[str], List[int]],
*, # Enforce keyword-only arguments
images: Optional[ImageInput] = None,
image_kwargs: Optional[Dict[str, Any]] = None,
**text_kwargs,
    ) -> Dict[str, Any]:
"""
Process the images and return the tokenized images and text.
Args:
            text (`str`, `List[str]`, `List[int]`):
                The sequence (or pre-tokenized sequence) to be encoded. When images are provided, this must be a
                single string containing one image token per image.
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared.
image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
**text_kwargs: Additional keyword arguments for text processing.
Returns:
            A dictionary with the following fields:
                - **input_ids** -- List of token ids to be fed to a model.
                - **pixel_values** -- Pixel values to be fed to a model (only present when images are provided).
                - **image_sizes** -- Sizes (height, width) of the processed images (only present when images are provided).
Raises:
ValueError: If the input images are in an invalid format.
"""
output_dict, image_inputs = {}, {}
if images is not None:
# Preprocess images
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
            elif (
                not isinstance(images, list)
                or not isinstance(images[0], list)
                or not is_image_or_image_url(images[0][0])
            ):
                raise ValueError(
                    "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
                )
# Load and process images
images = [load_image_list(sample) for sample in images]
image_kwargs = image_kwargs or {}
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs)
# Validate image inputs
assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs"
assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs"
            assert (
                len(image_inputs.keys()) == 2
            ), "Expected exactly the keys 'pixel_values' and 'image_sizes' in image_inputs, got {}".format(
                image_inputs.keys()
            )
# Extract pixel values and image sizes
pixel_values = image_inputs["pixel_values"][0]
image_sizes = image_inputs["image_sizes"][0]
unique_sizes = np.unique(image_sizes, axis=0)
assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes)
# Convert pixel values to PyTorch tensor
pixel_values = np.asarray(pixel_values)
pixel_values = torch.from_numpy(pixel_values)
output_dict["pixel_values"] = pixel_values
output_dict["image_sizes"] = image_sizes
# Expand image tokens in text
if image_inputs.get("pixel_values") is not None:
replace_strings = []
# Calculate the number of tokens needed for each image and create a placeholder
for image_size in image_sizes:
height, width = image_size
num_height_tokens = height // self.patch_size
num_width_tokens = width // self.patch_size
replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
# Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens[-1] = self.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
text = text.replace(self.image_token, "<placeholder>", 1)
# Replace placeholders with actual image token sequences
while "<placeholder>" in text:
replace_str = replace_strings.pop(0)
text = text.replace("<placeholder>", replace_str, 1)
# Encode the text
        text_inputs = super().encode(text, **text_kwargs)
output_dict["input_ids"] = text_inputs
return output_dict
def apply_chat_template(
self,
conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
*,
images: Optional[ImageInput] = None,
image_kwargs: Optional[Dict[str, Any]] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[str] = None,
return_dict: bool = True,
return_assistant_tokens_mask: bool = False,
generation_prefix: str = "",
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Apply the chat template to the conversation.
Args:
conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process.
images (Optional[ImageInput]): Images to include in the conversation.
image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
add_generation_prompt (bool): Whether to add a generation prompt.
tokenize (bool): Whether to tokenize the output.
padding (bool): Whether to pad the output.
truncation (bool): Whether to truncate the output.
max_length (Optional[int]): Maximum length of the output.
return_tensors (Optional[str]): The type of tensors to return.
return_dict (bool): Whether to return a dictionary.
return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask.
            generation_prefix (str): Prefix added before the model's generation; helpful for guiding the generation. Defaults to "".
tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
**kwargs: Additional keyword arguments.
Returns:
The processed conversation with applied chat template.
Raises:
AssertionError: If return_dict is False or if the conversation format is invalid.
"""
assert return_dict, "return_dict must be True for ImageTextTokenizer"
assert isinstance(conversation, list), "conversation must be a list"
if isinstance(conversation[0], list):
            assert len(conversation) == 1, "Only single-conversation input is supported, got {}".format(conversation)
conversation = conversation[0]
# Extract images from the conversation if not provided
if images is None:
images = []
            for msg in conversation:
                if msg.get("images", None) is not None:
                    images.extend(msg["images"])
            images = load_image_list(images)
        # If the conversation carries no images, fall back to text-only behavior.
        # This is useful when feeding the VLM inputs both with and without images.
if isinstance(images, list) and len(images) == 0:
images = None
# Apply the chat template to the text
text = super().apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=add_generation_prompt,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_dict=False,
return_assistant_tokens_mask=return_assistant_tokens_mask,
generation_prefix=generation_prefix,
tokenizer_kwargs=tokenizer_kwargs,
**kwargs,
)
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
# Encode the text and images
output = self.encode(
text,
images=images,
image_kwargs=image_kwargs,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
return output
@property
def model_input_names(self):
"""
Get the combined model input names from both the text tokenizer and image processor.
Returns:
List[str]: A list of unique input names.
"""
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
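

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library API). The
# checkpoint directory and image URL below are hypothetical placeholders; the
# block only runs when the module is executed directly (e.g. via `python -m`),
# and it assumes a Pixtral-style checkpoint is available at the given path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = ImageTextTokenizer(
        model_family="pixtral",
        is_instruct_model=True,
        tokenizer_path="checkpoints/Pixtral-12B",  # hypothetical local path
        image_processor_path="checkpoints/Pixtral-12B",  # hypothetical local path
    )
    # A single user turn with one image chunk; the chat template renders the image
    # chunk as "[IMG]", which encode() then expands into the full token grid.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "content": "Describe this image."},
            ],
            "images": ["https://example.com/sample.jpg"],  # hypothetical image URL
        }
    ]
    output = tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
    # Expected keys, per encode(): input_ids, pixel_values, image_sizes
    print(sorted(output.keys()))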