diff --git a/egogpt/__pycache__/constants.cpython-310.pyc b/egogpt/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..303d5f29d6bd88200e99a795cc038fc0a9cac2f0
Binary files /dev/null and b/egogpt/__pycache__/constants.cpython-310.pyc differ
diff --git a/egogpt/__pycache__/conversation.cpython-310.pyc b/egogpt/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79a844f1b9b8675fb872712a81d3946b88a53129
Binary files /dev/null and b/egogpt/__pycache__/conversation.cpython-310.pyc differ
diff --git a/egogpt/__pycache__/mm_utils.cpython-310.pyc b/egogpt/__pycache__/mm_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39bc72c927621427b43722714856e3869b6857f8
Binary files /dev/null and b/egogpt/__pycache__/mm_utils.cpython-310.pyc differ
diff --git a/egogpt/__pycache__/utils.cpython-310.pyc b/egogpt/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15f50c1230391a619da6ede06807d073a9f17db2
Binary files /dev/null and b/egogpt/__pycache__/utils.cpython-310.pyc differ
diff --git a/egogpt/constants.py b/egogpt/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..b976de237c27a9819c4386365e12fe7d5f1dbb14
--- /dev/null
+++ b/egogpt/constants.py
@@ -0,0 +1,11 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+SPEECH_TOKEN_INDEX = -200
+DEFAULT_SPEECH_TOKEN = "<speech>"
+IMAGE_TOKEN_INDEX = -300
+DEFAULT_IMAGE_TOKEN = "<image>"
diff --git a/egogpt/conversation.py b/egogpt/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5df932f3e7a3ba3a89b20a0c8103291cb0306315
--- /dev/null
+++ b/egogpt/conversation.py
@@ -0,0 +1,287 @@
+# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
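# A minimal, hypothetical usage sketch (not part of the diff itself) for the Conversation
# templates defined below in this file: copy a template from conv_templates, append the
# user and assistant turns, then render the prompt string. The question text is a
# made-up placeholder.
from egogpt.conversation import conv_templates

conv = conv_templates["qwen_1_5"].copy()                 # CHATML-style Qwen template
conv.append_message(conv.roles[0], "What is the camera wearer doing?")
conv.append_message(conv.roles[1], None)                 # leave the assistant turn open
prompt = conv.get_prompt()                               # "<|im_start|>system\n...<|im_end|>\n..."
print(prompt)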
+ +import base64 +import dataclasses +from enum import Enum, auto +from io import BytesIO +from typing import Any, List, Tuple, Union + +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + + TWO = auto() + PLAIN = auto() + CHATML = auto() + LLAMA_2 = auto() + LLAMA_3 = auto() + QWEN2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.PLAIN + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + tokenizer_id: str = "" + tokenizer: Any = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + + if self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + wrap_sys = ( + lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" + if len(msg) > 0 + else msg + ) + ret = "<|begin_of_text|>" + wrap_sys(self.system) + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += message.strip() + self.sep2 + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = ( + lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + ) + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + message, images = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.QWEN2: + start = "<|im_start|>" + end = "<|im_end|>\n" + ret = start + "system\n" + self.system + end + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + + if message.endswith("<|endoftext|>"): + message = message.replace("<|endoftext|>", "") + ret += start + role + "\n" + message + end + "<|endoftext|>" + else: + assert ( + not "<|endoftext|>" in message + ), f"Invalid message: {message}" + ret 
+= start + role + "\n" + message + end + else: + ret += start + role + "\n" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, speech = msg + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [ + [x, y[0] if type(y) is tuple else y] for x, y in self.messages + ], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="You are a helpful language and speech assistant. " + "You are able to understand the speech content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llama_3 = Conversation( + system="You are a helpful language and speech assistant. 
" + "You are able to understand the speech content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("user", "assistant"), + version="llama_v3", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="", + sep2="<|eot_id|>", +) + + +conv_qwen_v1 = Conversation( + system="You are a helpful assistant.", + roles=("user", "assistant"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.QWEN2, +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +conv_qwen = Conversation( + system="""<|im_start|>system +You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +default_conversation = conv_llama_3 +conv_templates = { + "v1": conv_vicuna_v1, + "plain": conv_plain, + "llama_2": conv_llama_2, + "llama_3": conv_llama_3, + "v1_qwen2": conv_qwen_v1, + "qwen_1_5": conv_qwen, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/egogpt/mm_utils.py b/egogpt/mm_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..9a6ad4e8edaae46181031b14a38ab58d3295d63d --- /dev/null +++ b/egogpt/mm_utils.py @@ -0,0 +1,450 @@ +import ast +import base64 +import math +import re +from io import BytesIO + +import torch +from PIL import Image +from transformers import StoppingCriteria + + +def resize_and_center_crop(image, shortest_edge_length): + # Calculate new dimensions and resize + aspect_ratio = float(image.width) / float(image.height) + if aspect_ratio > 1: + new_width = int(shortest_edge_length * aspect_ratio) + new_height = shortest_edge_length + else: + new_width = shortest_edge_length + new_height = int(shortest_edge_length / aspect_ratio) + resized_image = image.resize((new_width, new_height), Image.ANTIALIAS) + + # Calculate the position and perform the center crop + left = (new_width - shortest_edge_length) / 2 + top = (new_height - shortest_edge_length) / 2 + right = (new_width + shortest_edge_length) / 2 + bottom = (new_height + shortest_edge_length) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + + return cropped_image + + +def auto_pad_images(image, grid_params): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert len(grid_params) > 0, "Grid parameters should not be empty" + + # Step 1: Calculate and find the closest aspect ratio + input_width, input_height = image.size + input_aspect_ratio = input_width / input_height + candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params] + closest_aspect_ratio = min( + candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]) + ) + + candidate_resolutions = [ + (x[1], x[2]) + for x in candidate_resolutions + if abs(x[0] - closest_aspect_ratio[0]) < 1e-3 + ] + + target_resolution = min( + candidate_resolutions, + key=lambda res: abs(max(input_width, input_height) / max(res) - 1), + ) + + resize_width, resize_height = target_resolution + if input_width > input_height: + resize_height = int(resize_width / input_aspect_ratio) + else: + resize_width = int(resize_height * input_aspect_ratio) + resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS) + + # Step 5: Pad the resized image if necessary to match the target resolution + pad_width = target_resolution[0] - resize_width + pad_height = 
target_resolution[1] - resize_height + padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0)) + padded_image.paste(resized_image, (pad_width // 2, pad_height // 2)) + + return padded_image + + +def extract_patches(image, patch_size, overlap_ratio): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert patch_size > 0, "Patch size should be greater than 0" + assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1" + + W, H = image.size + patches = [] + + stride = int(patch_size * (1 - overlap_ratio)) + + num_patches_y = (H - patch_size) // stride + 1 + num_patches_x = (W - patch_size) // stride + 1 + + y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2 + x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2 + + for y in range(y_start, y_start + num_patches_y * stride, stride): + for x in range(x_start, x_start + num_patches_x * stride, stride): + patch = image.crop((x, y, x + patch_size, y + patch_size)) + patches.append(patch) + + return patches + + +def process_highres_image_crop_split(image, data_args, processor=None): + crop_resolution = data_args.image_crop_resolution + split_resolution = data_args.image_split_resolution + if processor is None: + processor = data_args.image_processor + image_crop = resize_and_center_crop(image, crop_resolution) + image_patches = extract_patches( + image_crop, patch_size=split_resolution, overlap_ratio=0 + ) + image_patches = [ + processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def process_highres_image(image, processor, grid_pinpoints): + grid_params = [int(x) for x in grid_pinpoints.split(",")] + width_height = max(image.size) + fit_grid_params = [x for x in grid_params if x >= width_height] + if len(fit_grid_params) == 0: + select_size = max(grid_params) + else: + select_size = min(fit_grid_params) + # FIXME: always select the 448 + select_size = max(grid_params) + image_padded = expand2square( + image, tuple(int(x * 255) for x in processor.image_mean) + ) + + # FIXME: this seems to be a bug that it always resizes instead of padding + image_original_resize = image.resize( + (processor.size["shortest_edge"], processor.size["shortest_edge"]) + ) + image_padded = image_padded.resize((select_size, select_size)) + image_patches = extract_patches( + image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0 + ) + image_patches = [image_original_resize] + image_patches + image_patches = [ + processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
+ """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in possible_resolutions: + # Calculate the downscaled size to keep the aspect ratio + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int( + original_height * scale + ) + + # Calculate effective and wasted resolutions + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + # Determine which dimension (width or height) to fill + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + # Width will be filled completely + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + # Height will be filled completely + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + # Create a new image with the target size and paste the resized image onto it + new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def process_anyres_image(image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + # Convert grid_pinpoints from string to list + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size["height"]) + + # FIXME: this seems to be a bug that it resizes instead of pad. 
+ # but to keep it consistent with previous, i will keep it as it is + # TODO: uncomment below to ablate with the padding + if isinstance(processor.size, dict): + shortest_edge = processor.size["shortest_edge"] + else: + shortest_edge = min(processor.size) + image_original_resize = image.resize((shortest_edge, shortest_edge)) + # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [ + processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] + for image_patch in image_patches + ] + return torch.stack(image_patches, dim=0) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + new_images = [] + try: + image = images[0].convert("RGB") + except Exception as e: + print(f"Failed to open image {images[0]}. Exception:", e) + raise e + + image_sizes = image.size + if image_aspect_ratio == "highres": + for image in images: + image = process_highres_image( + image, image_processor, model_cfg.image_grid_pinpoints + ) + new_images.append(image) + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + for image in images: + image = process_anyres_image( + image, image_processor, model_cfg.image_grid_pinpoints + ) + new_images.append(image) + elif image_aspect_ratio == "crop_split": + for image in images: + image = process_highres_image_crop_split(image, model_cfg, image_processor) + new_images.append(image) + elif image_aspect_ratio == "pad": + for image in images: + image = expand2square( + image, tuple(int(x * 255) for x in image_processor.image_mean) + ) + image = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] + new_images.append(image) + else: + return image_processor.preprocess(images, return_tensors="pt")["pixel_values"] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if ( + len(cur_keyword_ids) > 1 + and cur_keyword_ids[0] == tokenizer.bos_token_id + ): + cur_keyword_ids = cur_keyword_ids[1:] + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + assert output_ids.shape[0] == 
1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, 3) + self.keyword_ids = [ + keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids + ] + for keyword_id in self.keyword_ids: + if output_ids[0, -keyword_id.shape[0] :] == keyword_id: + return True + outputs = self.tokenizer.batch_decode( + output_ids[:, -offset:], skip_special_tokens=True + )[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False diff --git a/egogpt/model/__init__.py b/egogpt/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf752ec121b0721958558021910d257ab9d5802 --- /dev/null +++ b/egogpt/model/__init__.py @@ -0,0 +1,2 @@ +from .language_model.egogpt_llama import EgoGPTConfig, EgoGPTLlamaForCausalLM +from .language_model.egogpt_qwen import EgoGPTConfigQwen, EgoGPTQwenForCausalLM diff --git a/egogpt/model/__pycache__/__init__.cpython-310.pyc b/egogpt/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6c556e02b05ce2d48d971f15b26c2e4cd80ae0a Binary files /dev/null and b/egogpt/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/egogpt/model/__pycache__/builder.cpython-310.pyc b/egogpt/model/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a21661e62cfaad86919ffa6073e64c51e341671 Binary files /dev/null and b/egogpt/model/__pycache__/builder.cpython-310.pyc differ diff --git a/egogpt/model/__pycache__/egogpt_arch.cpython-310.pyc b/egogpt/model/__pycache__/egogpt_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af31e81643804009195ca7855321ab2a35c9d199 Binary files /dev/null and b/egogpt/model/__pycache__/egogpt_arch.cpython-310.pyc differ diff --git a/egogpt/model/builder.py b/egogpt/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..00cf3d56283ac48c3bd470765617f85a0ffb71bf --- /dev/null +++ b/egogpt/model/builder.py @@ -0,0 +1,127 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import shutil +import warnings + +import torch +import torch.distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) + +from egogpt.model import * +from egogpt.model.speech_encoder.builder import build_speech_encoder + + +def load_pretrained_model( + model_path, + model_base=None, + is_lora=False, + load_8bit=False, + load_4bit=False, + device="cuda", + use_flash_attn=False, + **kwargs, +): + # if dist.is_available() and not dist.is_initialized(): + # dist.init_process_group(backend='nccl',init_method='env://') + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + else: + kwargs["torch_dtype"] = torch.float16 + + if use_flash_attn: + kwargs["attn_implementation"] = "flash_attention_2" + + model_cls = EgoGPTQwenForCausalLM + + # Load EgoGPT model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." + from egogpt.model.language_model.egogpt_llama import EgoGPTConfig + + lora_cfg_pretrained = EgoGPTConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print("Loading EgoGPT from base model...") + model = model_cls.from_pretrained( + model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs + ) + print("Loading additional EgoGPT weights...") + if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load( + os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu" + ) + non_lora_trainables = { + (k[11:] if k.startswith("base_model.") else k): v + for k, v in non_lora_trainables.items() + } + if any(k.startswith("model.model.") for k in non_lora_trainables): + non_lora_trainables = { + (k[6:] if k.startswith("model.") else k): v + for k, v in non_lora_trainables.items() + } + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + + print("Loading LoRA weights...") + model = PeftModel.from_pretrained(model, model_path) + print("Merging LoRA weights...") + model = model.merge_and_unload() + print("Model is loaded...") + elif model_base is not None: + print("Loading EgoGPT from base model...") + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained( + model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs + ) + + speech_projector_weights = torch.load( + os.path.join(model_path, "speech_projector.bin"), map_location="cpu" + ) + speech_projector_weights = { + k: v.to(torch.float16) for k, v in speech_projector_weights.items() + } + model.load_state_dict(speech_projector_weights, strict=False) + model = model.to(device=device) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + model = model.to(device=device) + + context_len = 4096 + # model.get_model().speech_encoder = build_speech_encoder(model.config) + # model.get_model().speech_encoder.to(device=device, dtype=torch.float16) + + # if hasattr(model.config, "max_sequence_length"): + # context_len = model.config.max_sequence_length + # else: + # context_len = 2048 + + return tokenizer, model, context_len diff 
--git a/egogpt/model/egogpt_arch.py b/egogpt/model/egogpt_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..eabd53ac1caf5ab5c23a0b745d2e185211d9b6e2 --- /dev/null +++ b/egogpt/model/egogpt_arch.py @@ -0,0 +1,1357 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import re +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from egogpt.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX +from egogpt.mm_utils import get_anyres_image_grid_shape +from egogpt.utils import lengths_to_padding_mask, rank0_print, rank_print + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_projector.builder import build_vision_projector +from .multimodal_resampler.builder import build_vision_resampler +from .speech_encoder.builder import build_speech_encoder +from .speech_projector.builder import build_speech_projector + + +class EgoGPTMetaModel: + def __init__(self, config): + super(EgoGPTMetaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + delay_load = getattr(config, "delay_load", False) + self.vision_tower = build_vision_tower(config, delay_load=delay_load) + self.vision_resampler = build_vision_resampler( + config, vision_tower=self.vision_tower + ) + self.mm_projector = build_vision_projector( + config, vision_cfg=self.vision_tower.config + ) + + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def get_speech_encoder(self): + speech_encoder = getattr(self, "speech_encoder", None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + self.config.vision_tower_pretrained = getattr( + model_args, "vision_tower_pretrained", "" + ) + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler( + model_args, vision_tower=vision_tower + ) + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + 
self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr( + model_args, "mm_projector_type", "linear" + ) + self.config.mm_hidden_size = getattr( + vision_resampler, "hidden_size", vision_tower.hidden_size + ) + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + if not hasattr(self.config, "add_faster_video"): + if model_args.add_faster_video: + embed_std = 1 / torch.sqrt( + torch.tensor(self.config.hidden_size, dtype=self.dtype) + ) + self.faster_token = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + + if getattr(self, "mm_projector", None) is None: + self.mm_projector = build_vision_projector( + self.config, vision_cfg=vision_tower.config + ) + + if "unpad" in mm_patch_merge_type: + embed_std = 1 / torch.sqrt( + torch.tensor(self.config.hidden_size, dtype=self.dtype) + ) + self.image_newline = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + + def get_w(weights, keyword): + return { + k.split(keyword + ".")[1]: v + for k, v in weights.items() + if keyword in k + } + + incompatible_keys = self.mm_projector.load_state_dict( + get_w(mm_projector_weights, "mm_projector") + ) + rank0_print( + f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}" + ) + incompatible_keys = self.vision_resampler.load_state_dict( + get_w(mm_projector_weights, "vision_resampler"), strict=False + ) + rank0_print( + f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. 
Incompatible keys: {incompatible_keys}" + ) + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr( + model_args, "speech_encoder_type", None + ) + self.config.speech_projector_type = getattr( + model_args, "speech_projector_type", "linear" + ) + self.config.speech_encoder_ds_rate = getattr( + model_args, "speech_encoder_ds_rate", 5 + ) + self.config.speech_encoder_hidden_size = getattr( + model_args, "speech_encoder_hidden_size", 1280 + ) + self.config.delay_load_audio = getattr(model_args, "delay_load_audio", True) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + else: + if fsdp is not None and len(fsdp) > 0: + speech_encoder = self.speech_encoder[0] + else: + speech_encoder = self.speech_encoder + speech_encoder.load_model(self.config) + + if getattr(self, "speech_projector", None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load( + model_args.pretrain_speech_projector, map_location="cpu" + ) + + def get_w(weights, keyword): + return { + k.split(keyword + ".")[1]: v + for k, v in weights.items() + if keyword in k + } + + self.speech_projector.load_state_dict( + get_w(pretrain_speech_projector_weights, "speech_projector"), + strict=False, + ) + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. 
+ """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + # Compute aspect ratios + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + # Determine padding size and direction + if original_aspect_ratio > current_aspect_ratio: + # Padding was added to the height + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + # Padding was added to the width + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +class EgoGPTMetaForCausalLM(ABC): + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_2dPool(self, image_feature, stride=2): + height = width = self.get_vision_tower().num_patches_per_side + num_frames, num_tokens, num_dim = image_feature.shape + image_feature = image_feature.view(num_frames, height, width, -1) + image_feature = image_feature.permute(0, 3, 1, 2).contiguous() + image_feature = nn.functional.avg_pool2d(image_feature, stride) + # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + # if self.config.mm_spatial_pool_mode == "average": + # image_feature = nn.functional.avg_pool2d(image_feature, stride) + # elif self.config.mm_spatial_pool_mode == "max": + # image_feature = nn.functional.max_pool2d(image_feature, stride) + # elif self.config.mm_spatial_pool_mode == "bilinear": + # height, width = image_feature.shape[2:] + # scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + # image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') + # else: + # raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(num_frames, -1, num_dim) + return image_feature + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images) + # image_features = self.get_model().vision_resampler(image_features, images=images) + image_features = self.get_model().mm_projector(image_features) + return image_features + + def encode_speech(self, speech, speech_lengths): + # audio cuttting + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + raise ValueError(f"Unknown speech encoder: {speech_encoder}") + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f"Unknown speech projector: {speech_projector_type}") + speech_features = [ + encoder_outs[i, : speech_lengths[i]] for i in range(len(encoder_outs)) + ] + return speech_features + + 
def add_token_per_grid(self, image_feature): + resize_h = int(math.sqrt(image_feature.shape[1])) + num_frames = image_feature.shape[0] + feature_dim = image_feature.shape[-1] + + image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + if getattr(self.config, "add_faster_video", False): + # import pdb; pdb.set_trace() + # (3584, 832, 14) -> (3584, 64, 13, 14) + image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1) + # (3584, 64, 13, 14) -> (64, 13, 14, 3584) + image_feature = image_feature.permute(1, 2, 3, 0).contiguous() + # (64, 13, 14, 3584) -> (64, 13*14, 3584) + image_feature = image_feature.flatten(1, 2) + # import pdb; pdb.set_trace() + return image_feature + # import pdb; pdb.set_trace() + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + return image_feature + + def prepare_inputs_labels_for_speech_and_text( + self, + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + images, + image_sizes=None, + modalities=["image"], + ): + vision_tower = self.get_vision_tower() + # rank_print(modalities) + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return ( + input_ids, + position_ids, + attention_mask, + past_key_values, + None, + labels, + ) + speech_encoder = self.get_speech_encoder() + if speech_encoder is None or speech is None or input_ids.shape[1] == 1: + return ( + input_ids, + position_ids, + attention_mask, + past_key_values, + None, + labels, + ) + + speech_features = self.encode_speech(speech, speech_lengths) + + if isinstance(modalities, str): + modalities = [modalities] + + # import pdb; pdb.set_trace() + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + # print(f"Images: {images}, {type(images)}, {len(images)}") + # print(f"Video idx in batch: {modalities}") + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + # concat_images = torch.cat([torch.tensor(image) for image in images_list], dim=0) + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + concat_images.requires_grad_(True) + encoded_image_features = self.encode_images(concat_images) + # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = 
torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + mm_newline_position = getattr( + self.config, "mm_newline_position", "one_token" + ) + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + # FIXME: now assume the image is square, and split to 2x2 patches + # num_patches = h * w, where h = w = sqrt(num_patches) + # currently image_feature is a tensor of shape (4, num_patches, hidden_size) + # we want to first unflatten it to (2, 2, h, w, hidden_size) + # rank0_print("At least we are reaching here") + # import pdb; pdb.set_trace() + if image_idx in video_idx_in_batch: # video operations + # rank0_print("Video") + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + if getattr(self.config, "add_faster_video", False): + faster_video_feature = self.add_token_per_grid( + all_faster_video_features[image_idx] + ) + # Add a token for each frame + concat_slow_fater_token = [] + # import pdb; pdb.set_trace() + for _ in range(image_feature.shape[0]): + if _ % self.config.faster_token_stride == 0: + concat_slow_fater_token.append( + torch.cat( + ( + image_feature[_], + self.model.faster_token[None].to( + image_feature.device + ), + ), + dim=0, + ) + ) + else: + concat_slow_fater_token.append( + torch.cat( + ( + faster_video_feature[_], + self.model.faster_token[None].to( + image_feature.device + ), + ), + dim=0, + ) + ) + # import pdb; pdb.set_trace() + image_feature = torch.cat(concat_slow_fater_token) + + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[None].to( + image_feature.device + ), + ), + dim=0, + ) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError( + f"Unexpected mm_newline_position: {mm_newline_position}" + ) + elif ( + image_feature.shape[0] > 1 + ): # multi patches and multi images operations + # rank0_print("Single-images") + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", image_aspect_ratio + ) + if matched_anyres_max_num_patches: + max_num_patches = int( + matched_anyres_max_num_patches.group(1) + ) + + if ( + image_aspect_ratio == "anyres" + or "anyres_max" in image_aspect_ratio + ): + if hasattr(self.get_vision_tower(), "image_size"): + vision_tower_image_size = ( + self.get_vision_tower().image_size + ) + else: + raise ValueError( + "vision_tower_image_size is not found in the vision tower." 
+ ) + try: + ( + num_patch_width, + num_patch_height, + ) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + vision_tower_image_size, + ) + except Exception as e: + rank0_print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + else: + image_feature = image_feature.view(2, 2, height, width, -1) + + if "maxpool2x2" in mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif ( + "unpad" in mm_patch_merge_type + and "anyres_max" in image_aspect_ratio + and matched_anyres_max_num_patches + ): + unit = image_feature.shape[2] + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate( + image_feature, + [int(h // times), int(w // times)], + mode="bilinear", + )[0] + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute( + 0, 2, 1, 3, 4 + ).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + new_image_features.append(image_feature) + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + (image_feature, self.model.image_newline[None]), dim=0 + ) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError( + f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}" + ) + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( + self.config, "mm_use_im_start_end", False + ): + raise NotImplementedError + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
+ _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange( + 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device + ) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [ + cur_labels[cur_attention_mask] + for cur_labels, cur_attention_mask in zip(labels, attention_mask) + ] + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_image = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # if num_speech: + # print("has ") + # if num_image: + # print("has ") + num_speech_images = num_speech + num_image + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat( + [cur_input_embeds_1, cur_speech_features[0:0]], dim=0 + ) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + + multimodal_token_indices = ( + [-1] + + torch.where( + (cur_input_ids == SPEECH_TOKEN_INDEX) + | (cur_input_ids == IMAGE_TOKEN_INDEX) + )[0].tolist() + + [cur_input_ids.shape[0]] + ) + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + for i in range(len(multimodal_token_indices) - 1): + cur_input_ids_nospeech_image.append( + cur_input_ids[ + multimodal_token_indices[i] + + 1 : multimodal_token_indices[i + 1] + ] + ) + cur_labels_nospeech_image.append( + cur_labels[ + multimodal_token_indices[i] + + 1 : multimodal_token_indices[i + 1] + ] + ) + + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().embed_tokens( + torch.cat(cur_input_ids_nospeech_image) + ) + cur_input_embeds_no_speech_image = torch.split( + cur_input_embeds, split_sizes, dim=0 + ) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + if i < num_speech_images: + if i < num_image: + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append( + torch.full( + (cur_images_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + else: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append( + torch.full( + (cur_speech_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_image == 0: + cur_new_input_embeds = torch.cat( + [cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0 + ) + cur_image_idx += 1 + + if 
num_speech == 0: + cur_new_input_embeds = torch.cat( + [cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0 + ) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr( + self.config, "tokenizer_model_max_length", None + ) + if tokenizer_model_max_length is not None: + new_input_embeds = [ + x[:tokenizer_model_max_length] for x in new_input_embeds + ] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), + IGNORE_INDEX, + dtype=new_labels[0].dtype, + device=new_labels[0].device, + ) + attention_mask = torch.zeros( + (batch_size, max_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + position_ids = torch.zeros( + (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device + ) + + for i, (cur_new_embed, cur_new_labels) in enumerate( + zip(new_input_embeds, new_labels) + ): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return ( + None, + position_ids, + attention_mask, + past_key_values, + new_input_embeds, + new_labels, + ) + + def prepare_inputs_labels_for_speech_and_text_debug( + self, + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + images, + image_sizes=None, + modalities=["image"], + ): + # vision_tower = self.get_vision_tower() + # # rank_print(modalities) + # if vision_tower is None or images is None or input_ids.shape[1] == 1: + # return input_ids, position_ids, attention_mask, past_key_values, None, labels + # speech_encoder = self.get_speech_encoder() + # if speech_encoder is None or speech is None or input_ids.shape[1] == 1: + # return input_ids, position_ids, attention_mask, past_key_values, None, labels + + speech_features = self.encode_speech(speech, speech_lengths) + + if isinstance(modalities, str): + modalities = [modalities] + + # import pdb; pdb.set_trace() + if type(images) is list or images.ndim == 5: + if type(images) 
is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + # print(f"Images: {images}, {type(images)}, {len(images)}") + # print(f"Video idx in batch: {modalities}") + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + # concat_images = torch.cat([torch.tensor(image) for image in images_list], dim=0) + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + concat_images.requires_grad_(True) + encoded_image_features = self.encode_images(concat_images) + # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + mm_newline_position = getattr( + self.config, "mm_newline_position", "one_token" + ) + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + # FIXME: now assume the image is square, and split to 2x2 patches + # num_patches = h * w, where h = w = sqrt(num_patches) + # currently image_feature is a tensor of shape (4, num_patches, hidden_size) + # we want to first unflatten it to (2, 2, h, w, hidden_size) + # rank0_print("At least we are reaching here") + # import pdb; pdb.set_trace() + if image_idx in video_idx_in_batch: # video operations + # rank0_print("Video") + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + new_image_features.append(image_feature.flatten(0, 1)) + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[None].to( + image_feature.device + ), + ), + dim=0, + ) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError( + f"Unexpected mm_newline_position: {mm_newline_position}" + ) + elif ( + image_feature.shape[0] > 1 + ): # multi patches and multi images operations + # rank0_print("Single-images") + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width 
== base_image_feature.shape[0] + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", image_aspect_ratio + ) + if matched_anyres_max_num_patches: + max_num_patches = int( + matched_anyres_max_num_patches.group(1) + ) + + if ( + image_aspect_ratio == "anyres" + or "anyres_max" in image_aspect_ratio + ): + if hasattr(self.get_vision_tower(), "image_size"): + vision_tower_image_size = ( + self.get_vision_tower().image_size + ) + else: + raise ValueError( + "vision_tower_image_size is not found in the vision tower." + ) + try: + ( + num_patch_width, + num_patch_height, + ) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + vision_tower_image_size, + ) + except Exception as e: + rank0_print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + else: + image_feature = image_feature.view(2, 2, height, width, -1) + + if "maxpool2x2" in mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif ( + "unpad" in mm_patch_merge_type + and "anyres_max" in image_aspect_ratio + and matched_anyres_max_num_patches + ): + unit = image_feature.shape[2] + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate( + image_feature, + [int(h // times), int(w // times)], + mode="bilinear", + )[0] + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute( + 0, 2, 1, 3, 4 + ).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + new_image_features.append(image_feature) + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + (image_feature, self.model.image_newline[None]), dim=0 + ) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError( + f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}" + ) + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. 
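+ # From here on this debug variant mirrors prepare_inputs_labels_for_speech_and_text: synthesize default attention_mask / position_ids / labels, strip padding, splice the encoded speech and image features in at their placeholder-token positions, then re-pad and truncate the batch.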
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( + self.config, "mm_use_im_start_end", False + ): + raise NotImplementedError + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange( + 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device + ) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [ + cur_labels[cur_attention_mask] + for cur_labels, cur_attention_mask in zip(labels, attention_mask) + ] + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_image = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_speech + num_image == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat( + [cur_input_embeds_1, cur_speech_features[0:0]], dim=0 + ) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + + multimodal_token_indices = sorted( + [-1] + + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + + [cur_input_ids.shape[0]] + ) + cur_input_ids_nospeech = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech = [] + for i in range(len(multimodal_token_indices) - 1): + cur_input_ids_nospeech.append( + cur_input_ids[ + multimodal_token_indices[i] + + 1 : multimodal_token_indices[i + 1] + ] + ) + cur_labels_nospeech.append( + cur_labels[ + multimodal_token_indices[i] + + 1 : multimodal_token_indices[i + 1] + ] + ) + + split_sizes = [x.shape[0] for x in cur_labels_nospeech] + cur_input_embeds = self.get_model().embed_tokens( + torch.cat(cur_input_ids_nospeech) + ) + cur_input_embeds_no_speech = torch.split( + cur_input_embeds, split_sizes, dim=0 + ) + cur_new_input_embeds = [] + cur_new_labels = [] + for i in range(num_speech + num_image + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech[i]) + cur_new_labels.append(cur_labels_nospeech[i]) + if cur_speech_idx < num_speech: + try: + cur_speech_features = speech_features[cur_speech_idx] + except: + cur_speech_features = speech_features[cur_speech_idx - 1] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append( + torch.full( + (cur_speech_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + if cur_image_idx < num_image: + try: + cur_image_features = image_features[cur_image_idx] + except: + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + 
dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr( + self.config, "tokenizer_model_max_length", None + ) + if tokenizer_model_max_length is not None: + new_input_embeds = [ + x[:tokenizer_model_max_length] for x in new_input_embeds + ] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), + IGNORE_INDEX, + dtype=new_labels[0].dtype, + device=new_labels[0].device, + ) + attention_mask = torch.zeros( + (batch_size, max_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + position_ids = torch.zeros( + (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device + ) + + for i, (cur_new_embed, cur_new_labels) in enumerate( + zip(new_input_embeds, new_labels) + ): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + print(f"new_input_embeds: {new_input_embeds[0].shape}") + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return ( + None, + position_ids, + attention_mask, + past_key_values, + new_input_embeds, + new_labels, + ) diff --git a/egogpt/model/language_model/__pycache__/egogpt_llama.cpython-310.pyc b/egogpt/model/language_model/__pycache__/egogpt_llama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa06ce8630c65d9ea8ccbe309c9235400a43011 Binary files /dev/null and b/egogpt/model/language_model/__pycache__/egogpt_llama.cpython-310.pyc differ diff --git a/egogpt/model/language_model/__pycache__/egogpt_qwen.cpython-310.pyc b/egogpt/model/language_model/__pycache__/egogpt_qwen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d4cc83b9695390eab5e91789f473f97ce5331ba Binary files /dev/null and b/egogpt/model/language_model/__pycache__/egogpt_qwen.cpython-310.pyc differ diff --git 
a/egogpt/model/language_model/egogpt_llama.py b/egogpt/model/language_model/egogpt_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5077c24cf8e42ce8df824312de4904903e3f3d --- /dev/null +++ b/egogpt/model/language_model/egogpt_llama.py @@ -0,0 +1,159 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + LlamaConfig, + LlamaForCausalLM, + LlamaModel, +) +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import CausalLMOutputWithPast + +from ..egogpt_arch import EgoGPTMetaForCausalLM, EgoGPTMetaModel + + +class EgoGPTConfig(LlamaConfig): + model_type = "egogpt_llama" + + +class EgoGPTLlamaModel(EgoGPTMetaModel, LlamaModel): + config_class = EgoGPTConfig + + def __init__(self, config: LlamaConfig): + super(EgoGPTLlamaModel, self).__init__(config) + + +class EgoGPTLlamaForCausalLM(LlamaForCausalLM, EgoGPTMetaForCausalLM): + config_class = EgoGPTConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = EgoGPTLlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_speech_and_text( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: 
Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if speech is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _, + ) = self.prepare_inputs_labels_for_speech_and_text( + inputs, position_ids, attention_mask, None, None, speech, speech_lengths + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + if speech is not None: + inputs["speech"] = speech + inputs["speech_lengths"] = speech_lengths + return inputs + + +AutoConfig.register("egogpt_llama", EgoGPTConfig) +AutoModelForCausalLM.register(EgoGPTConfig, EgoGPTLlamaForCausalLM) diff --git a/egogpt/model/language_model/egogpt_qwen.py b/egogpt/model/language_model/egogpt_qwen.py new file mode 100644 index 0000000000000000000000000000000000000000..25c2a199ed299e4260c3f203ed594d5781780804 --- /dev/null +++ b/egogpt/model/language_model/egogpt_qwen.py @@ -0,0 +1,164 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import transformers +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, + Qwen2Model, +) +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import CausalLMOutputWithPast + +from ..egogpt_arch import EgoGPTMetaForCausalLM, EgoGPTMetaModel + + +class EgoGPTConfigQwen(Qwen2Config): + model_type = "egogpt_qwen" + + +class EgoGPTQwenModel(EgoGPTMetaModel, Qwen2Model): + config_class = EgoGPTConfigQwen + + def __init__(self, config: Qwen2Config): + super(EgoGPTQwenModel, self).__init__(config) + + +class EgoGPTQwenForCausalLM(Qwen2ForCausalLM, EgoGPTMetaForCausalLM): + config_class = EgoGPTConfigQwen + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + + config.rope_scaling = None + self.model = EgoGPTQwenModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + return_dict: Optional[bool] = None, + 
cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_speech_and_text( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + images, + image_sizes, + modalities, + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if speech is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _, + ) = self.prepare_inputs_labels_for_speech_and_text( + inputs, + position_ids, + attention_mask, + None, + None, + speech, + speech_lengths, + images, + image_sizes, + modalities, + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + if speech is not None: + inputs["speech"] = speech + inputs["speech_lengths"] = speech_lengths + return inputs + + +AutoConfig.register("egogpt_qwen", EgoGPTConfigQwen) +AutoModelForCausalLM.register(EgoGPTConfigQwen, EgoGPTQwenForCausalLM) diff --git a/egogpt/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc b/egogpt/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e4dcf84f4e876c6d4559aaed1892df24d4400b Binary files /dev/null and b/egogpt/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc b/egogpt/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1d3f776e80712e6c22f65c0ae5431326a2b7b3 Binary files /dev/null and b/egogpt/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc b/egogpt/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a2ca802787c87844ed4d221891eae0e3d6332ed Binary files /dev/null and b/egogpt/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc differ diff 
--git a/egogpt/model/multimodal_encoder/builder.py b/egogpt/model/multimodal_encoder/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..483b0a50d0835d5608cd0c5b0f33c91a2a74c326 --- /dev/null +++ b/egogpt/model/multimodal_encoder/builder.py @@ -0,0 +1,36 @@ +import os + +from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 +from .siglip_encoder import SigLipVisionTower + +# from .eva_clip.eva_clip_encoder import EvaClipVisionTower +# from .dev_eva_clip.eva_vit import EvaViTWrapper + + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr( + vision_tower_cfg, + "mm_vision_tower", + getattr(vision_tower_cfg, "vision_tower", None), + ) + is_absolute_path_exists = os.path.exists(vision_tower) + use_s2 = getattr(vision_tower_cfg, "s2", False) + # if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: + if ( + vision_tower.startswith("openai") + or vision_tower.startswith("laion") + or "ShareGPT4V" in vision_tower + ): + if use_s2: + return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) + else: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif ( + "siglip" in vision_tower.lower() + or "open_clip_pytorch_model.bin" in vision_tower + ): + return SigLipVisionTower( + vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs + ) + + raise ValueError(f"Unknown vision tower: {vision_tower}") diff --git a/egogpt/model/multimodal_encoder/clip_encoder.py b/egogpt/model/multimodal_encoder/clip_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..57fe652a7c00ebd978052728397ff13ed4d4d43a --- /dev/null +++ b/egogpt/model/multimodal_encoder/clip_encoder.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel + +from egogpt.utils import rank0_print + +try: + from s2wrapper import forward as multiscale_forward +except: + pass + + +class CLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print( + f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True." + ) + self.load_model() + elif ( + hasattr(args, "mm_tunable_parts") + and "mm_vision_tower" in args.mm_tunable_parts + ): + rank0_print( + f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`." 
+ ) + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print( + "{} is already loaded, `load_model` called again, skipping.".format( + self.vision_tower_name + ) + ) + return + + self.image_processor = CLIPImageProcessor.from_pretrained( + self.vision_tower_name + ) + self.vision_tower = CLIPVisionModel.from_pretrained( + self.vision_tower_name, device_map=device_map + ) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat( + [ + image_forward_outs.hidden_states[i] + for i in range( + select_every_k_layer + self.select_layer, + len(image_forward_outs.hidden_states), + select_every_k_layer, + ) + ], + dim=-1, + ) + select_feature_type = select_feature_type.replace("slicefour_", "") + elif self.select_feature in [ + "slice_m25811_f6_patch", + "slice_m25811_f6_cls_patch", + ]: + select_layers = [-2, -5, -8, -11, 6] + image_features = torch.cat( + [image_forward_outs.hidden_states[i] for i in select_layers], dim=-1 + ) + select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower( + image.to(device=self.device, dtype=self.dtype).unsqueeze(0), + output_hidden_states=True, + ) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower( + images.to(device=self.device, dtype=self.dtype), + output_hidden_states=True, + ) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + _hidden_size = self.config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + if "slice_m25811_f6" in self.select_feature: + _hidden_size *= 5 + return _hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.config.image_size + + +class CLIPVisionTowerS2(CLIPVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + self.s2_scales = getattr(args, "s2_scales", "336,672,1008") + self.s2_scales = list(map(int, self.s2_scales.split(","))) + 
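# Scales stay sorted ascending: the smallest scale is used as the S2 split size and the largest as the working image size. +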
self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + super().__init__(vision_tower, args, delay_load) + + # change resize/crop size in preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size[ + "width" + ] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print( + "{} is already loaded, `load_model` called again, skipping.".format( + self.vision_tower_name + ) + ) + return + + self.image_processor = CLIPImageProcessor.from_pretrained( + self.vision_tower_name + ) + self.vision_tower = CLIPVisionModel.from_pretrained( + self.vision_tower_name, device_map=device_map + ) + self.vision_tower.requires_grad_(False) + + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size[ + "width" + ] = self.s2_image_size + + self.is_loaded = True + + def forward_feature(self, images): + image_forward_outs = self.vision_tower( + images.to(device=self.device, dtype=self.dtype), output_hidden_states=True + ) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = multiscale_forward( + self.forward_feature, + image.unsqueeze(0), + img_sizes=self.s2_scales, + max_split_size=self.s2_split_size, + split_forward=True, + ) + image_features.append(image_feature) + else: + image_features = multiscale_forward( + self.forward_feature, + images, + img_sizes=self.s2_scales, + max_split_size=self.s2_split_size, + split_forward=True, + ) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/egogpt/model/multimodal_encoder/siglip_encoder.py b/egogpt/model/multimodal_encoder/siglip_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..880216bd0623308379fcfabd33429aebf3bd57a3 --- /dev/null +++ b/egogpt/model/multimodal_encoder/siglip_encoder.py @@ -0,0 +1,742 @@ +""" +# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py +""" + +import os +from dataclasses import dataclass +from functools import partial, reduce +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from PIL import Image +from torch import nn +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.image_processing_utils import BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + ChannelDimension, + PILImageResampling, + to_numpy_array, +) +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from egogpt.utils import rank0_print + + +class SigLipImageProcessor: + def __init__( + self, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + size=(384, 384), + crop_size: Dict[str, int] = None, + resample=PILImageResampling.BICUBIC, + rescale_factor=1 / 255, + data_format=ChannelDimension.FIRST, + ): + 
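# crop_size defaults to a 384x384 square (normalized via get_size_dict) and is stored on the processor; preprocess() below only resizes to `size` and does not apply the crop. +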
crop_size = ( + crop_size if crop_size is not None else {"height": 384, "width": 384} + ) + crop_size = get_size_dict( + crop_size, default_to_square=True, param_name="crop_size" + ) + + self.image_mean = image_mean + self.image_std = image_std + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.data_format = data_format + self.crop_size = crop_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + # to adapt video data + images = [to_numpy_array(image) for image in images] + assert isinstance(images, list) + + transforms = [ + convert_to_rgb, + to_numpy_array, + partial( + resize, + size=self.size, + resample=self.resample, + data_format=self.data_format, + ), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial( + normalize, + mean=self.image_mean, + std=self.image_std, + data_format=self.data_format, + ), + partial( + to_channel_dimension_format, + channel_dim=self.data_format, + input_channel_dim=self.data_format, + ), + ] + + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.image_mean = image_mean + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + print( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip +class SigLipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SigLipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip +class SigLipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # 
Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip +class SigLipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SigLipEncoderLayer`]. + + Args: + config: SigLipVisionConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state): + 
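# hidden_state: (batch, num_patches, hidden_size). The learned probe is repeated per sample, cross-attends over all patch tokens, then goes through LayerNorm and an MLP with a residual connection; the single query position is returned as the pooled embedding. +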
batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + + self.vision_model = SigLipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SigLipVisionModel + + >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SigLipVisionTower(nn.Module): + def __init__(self, vision_tower, vision_tower_cfg, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.config = SigLipVisionConfig() + + self.vision_tower_name = vision_tower + + self.image_processor = SigLipImageProcessor() + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print( + f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True." + ) + self.load_model() + elif ( + hasattr(vision_tower_cfg, "mm_tunable_parts") + and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts + ): + rank0_print( + f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`." 
+ ) + self.load_model() + else: + self.cfg_only = self.config + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print( + "{} is already loaded, `load_model` called again, skipping.".format( + self.vision_tower_name + ) + ) + return + + self.vision_tower = SigLipVisionModel.from_pretrained( + self.vision_tower_name, device_map=device_map + ) + + # Drop the last encoder layer and replace the pooling head with an identity; features are taken from the hidden states of the truncated encoder. + del self.vision_tower.vision_model.encoder.layers[-1:] + self.vision_tower.vision_model.head = nn.Identity() + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def forward(self, images): + # With the default SigLipVisionConfig (image_size=384, patch_size=14) each image yields (384 // 14) ** 2 = 729 patch tokens. + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower( + image.to(device=self.device, dtype=self.dtype).unsqueeze(0), + output_hidden_states=True, + ) + image_feature = image_forward_out.hidden_states[-1].to(image.dtype) + assert image_feature.shape[-2] == 729 + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower( + images.to(device=self.device, dtype=self.dtype), + output_hidden_states=True, + ) + image_features = image_forward_outs.hidden_states[-1].to(images.dtype) + assert image_features.shape[-2] == 729 + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + for p in self.vision_tower.parameters(): + return p.dtype + + @property + def device(self): + for p in self.vision_tower.parameters(): + return p.device + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config.image_size diff --git a/egogpt/model/multimodal_projector/__pycache__/builder.cpython-310.pyc b/egogpt/model/multimodal_projector/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18e6274eb2e8d88939c231e564901d31d388a370 Binary files /dev/null and b/egogpt/model/multimodal_projector/__pycache__/builder.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc b/egogpt/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee3bdbe1613efb0a48bf0eedd9ae12b5851d47de Binary files /dev/null and b/egogpt/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_projector/builder.py b/egogpt/model/multimodal_projector/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..afb3e21b3751a4adac2eef455a99bb2bc13304ef --- /dev/null +++ b/egogpt/model/multimodal_projector/builder.py @@ -0,0 +1,68 @@ +import re + +import torch +import torch.nn as nn + +from .pooler_projector import PoolerProjector + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": "identity"} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels),
nn.GELU(), nn.Linear(channels, channels) + ) + + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, "mm_projector_type", "linear") + + if projector_type == "linear": + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + if projector_type == "pooler": + return PoolerProjector(config, kwargs["vision_cfg"]) + + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth = int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == "identity": + return IdentityMap() + + raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/egogpt/model/multimodal_projector/pooler_projector.py b/egogpt/model/multimodal_projector/pooler_projector.py new file mode 100755 index 0000000000000000000000000000000000000000..df0f95c269ee2421f844017fe7b965c16784f557 --- /dev/null +++ b/egogpt/model/multimodal_projector/pooler_projector.py @@ -0,0 +1,34 @@ +import math + +import torch +import torch.nn as nn +from transformers.models.clip.modeling_clip import CLIPVisionModel + + +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d( + config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2 + ) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": "pooler"} diff --git a/egogpt/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc b/egogpt/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98e3a27f7a17c63e452e77866ed89f0b832d390 Binary files /dev/null and b/egogpt/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc b/egogpt/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab6ce03499fa17542ecf3a3ede3bf0f37ac3f2f5 Binary files /dev/null and b/egogpt/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc b/egogpt/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc new file mode 100644 
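A minimal sketch of what the projector naming used above means: `mm_projector_type = "mlp2x_gelu"` matches the regex branch in `build_vision_projector` and expands to a two-layer GELU MLP. The sizes below are illustrative assumptions (SigLIP-style 1152-dim features into a 4096-dim LLM), not values read from this repo's configs.

```python
import torch
import torch.nn as nn

# Illustrative only: the module that the "mlp2x_gelu" branch above builds.
# mm_hidden_size / hidden_size are assumed example values.
mm_hidden_size, hidden_size = 1152, 4096

projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)

vision_features = torch.randn(2, 729, mm_hidden_size)  # (batch, tokens, dim)
print(projector(vision_features).shape)                # torch.Size([2, 729, 4096])
```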
index 0000000000000000000000000000000000000000..ef07a21d9f6a39b4928537df7ae5e8c2b02af724 Binary files /dev/null and b/egogpt/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc b/egogpt/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc5d8b9bd34dab7d42087e9d6a648ec7d9a4e76b Binary files /dev/null and b/egogpt/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc b/egogpt/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff18a11197fbbb5c7b86937bcacb5fcb75c18a65 Binary files /dev/null and b/egogpt/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc differ diff --git a/egogpt/model/multimodal_resampler/builder.py b/egogpt/model/multimodal_resampler/builder.py new file mode 100755 index 0000000000000000000000000000000000000000..c5eeeb49285a4d0367743a536490c00d64326e88 --- /dev/null +++ b/egogpt/model/multimodal_resampler/builder.py @@ -0,0 +1,34 @@ +import torch + +from .masked_drop import MaskedDrop +from .perceiver import PerceiverResampler +from .qformer import Qformer +from .spatial_pool import SpatialPool + + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + + +def build_vision_resampler(model_args, delay_load=False, **kwargs): + resampler_type = getattr(model_args, "mm_resampler_type", None) + if resampler_type == "masked_drop": + return MaskedDrop(model_args) + elif resampler_type == "spatial_pool": + return SpatialPool(model_args, **kwargs) + elif resampler_type == "perceiver": + return PerceiverResampler(model_args, **kwargs) + elif resampler_type == "qformer": + return Qformer(model_args, **kwargs) + elif resampler_type is None: + return IdentityMap() + + raise ValueError(f"Unknown resampler type: {resampler_type}") diff --git a/egogpt/model/multimodal_resampler/masked_drop.py b/egogpt/model/multimodal_resampler/masked_drop.py new file mode 100755 index 0000000000000000000000000000000000000000..3a45fbf96b927d397cc6cf32c7b7c7e30bdb781b --- /dev/null +++ b/egogpt/model/multimodal_resampler/masked_drop.py @@ -0,0 +1,89 @@ +import random + +import torch +import torch.nn as nn + + +class MaskedDrop(nn.Module): + def __init__(self, model_args): + super().__init__() + + self.mode = model_args.mm_mask_drop_mode + self.skip_percentage = model_args.mm_mask_drop_skip_percentage + self.ratio = model_args.mm_mask_drop_ratio + self.ratio_upper = model_args.mm_mask_drop_ratio_upper + self.ratio_lower = model_args.mm_mask_drop_ratio_lower + + def forward(self, image_features, *args, **kwargs): + if not self.training: + return image_features + + if self.skip_percentage > random.random(): + return image_features + + masked_features = [] + + for image_feature in image_features: + num_tokens = image_feature.shape[0] + if self.mode == "fixed": + num_keep = int(num_tokens * self.ratio) + masked_features.append( + self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0] + ) + elif self.mode == "range": + num_keep = int( + num_tokens * random.uniform(self.ratio_lower, self.ratio_upper) + ) + masked_features.append( + 
self.random_masking(image_feature.unsqueeze(0), num_keep)[0] + ) + elif self.mode == "cls_only": + masked_features.append(image_feature[0:1]) + else: + raise ValueError(f"Unexpected masked drop mode: {self.mode}") + + if self.mode not in ["range"] and ( + type(image_features) is not list or self.mode in ["cls_only"] + ): + masked_features = torch.stack(masked_features, dim=0) + + return masked_features + + @property + def config(self): + return { + "mm_resampler_type": "masked_drop", + "mm_mask_drop_mode": self.mode, + "mm_mask_drop_skip_percentage": self.skip_percentage, + "mm_mask_drop_ratio": self.ratio, + "mm_mask_drop_ratio_upper": self.ratio_upper, + "mm_mask_drop_ratio_lower": self.ratio_lower, + } + + def random_masking(self, x, len_keep): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore diff --git a/egogpt/model/multimodal_resampler/perceiver.py b/egogpt/model/multimodal_resampler/perceiver.py new file mode 100755 index 0000000000000000000000000000000000000000..a81a85fcf33ebe38ababa8222c75d88927bc542f --- /dev/null +++ b/egogpt/model/multimodal_resampler/perceiver.py @@ -0,0 +1,172 @@ +""" +Taken from https://github.com/lucidrains/flamingo-pytorch +""" + +import torch +from einops import rearrange, repeat + +try: + from einops_exts import rearrange_many +except: + pass + +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... 
i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResamplerModule(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if exists(max_num_frames) + else None + ) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if exists(max_num_media) + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult) + if ff_mult > 0 + else nn.Identity(), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange( + x, "b T F v d -> b T (F v) d" + ) # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class PerceiverResampler(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_perceiver_depth + self.num_latents = model_args.mm_perceiver_latents + self.ff_mult = model_args.mm_perceiver_ff_mult + self.pretrained = model_args.mm_perceiver_pretrained + + self.perceiver = PerceiverResamplerModule( + dim=vision_tower.hidden_size, + depth=self.depth, + num_latents=self.num_latents, + ff_mult=self.ff_mult, + ) + + if self.pretrained is not None: + self.load_state_dict(torch.load(self.pretrained)) + + def forward(self, image_features, *args, **kwargs): + return self.perceiver(image_features[:, None, None]).squeeze(1) + + @property + def config(self): + return { + "mm_resampler_type": "perceiver", + "mm_perceiver_depth": self.depth, + "mm_perceiver_latents": self.num_latents, + "mm_perceiver_ff_mult": self.ff_mult, + "mm_perceiver_pretrained": self.pretrained, + } diff --git a/egogpt/model/multimodal_resampler/qformer.py b/egogpt/model/multimodal_resampler/qformer.py new file mode 100755 index 0000000000000000000000000000000000000000..c4124917ba67863d6e4ef15b50042472bd70d554 --- /dev/null +++ b/egogpt/model/multimodal_resampler/qformer.py @@ -0,0 +1,1281 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import Tensor, device, dtype, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.file_utils import ModelOutput +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden 
size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
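A small standalone sketch (with made-up sizes, not the model's actual config) of what `transpose_for_scores` and the matmul that follows do to the tensor shapes:

```python
import math
import torch

# Hypothetical sizes; BERT-base uses hidden=768, heads=12, head_dim=64.
batch, seq_len, num_heads, head_dim = 2, 16, 12, 64
hidden = num_heads * head_dim

x = torch.randn(batch, seq_len, hidden)

def transpose_for_scores(t):
    # (B, L, H*D) -> (B, H, L, D), mirroring BertSelfAttention above
    return t.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

q = transpose_for_scores(x)
k = transpose_for_scores(x)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
print(scores.shape)  # torch.Size([2, 12, 16, 16]): one (L, L) score map per head
```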
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
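The relative-position branch above can be hard to parse, so here is a toy example (illustrative sizes, not the real config) of how the pairwise distance matrix indexes the `distance_embedding` table and how the einsum contributes one extra score per query/key pair:

```python
import torch
import torch.nn as nn

# Toy sizes, assumed for illustration only.
batch, heads, seq_len, head_dim = 1, 2, 5, 8
max_pos = 16

distance_embedding = nn.Embedding(2 * max_pos - 1, head_dim)

pos = torch.arange(seq_len)
distance = pos.view(-1, 1) - pos.view(1, -1)          # (L, L), values in [-(L-1), L-1]
rel_emb = distance_embedding(distance + max_pos - 1)  # (L, L, D), shifted to non-negative indices

query = torch.randn(batch, heads, seq_len, head_dim)
# Same contraction as the "relative_key" branch above.
rel_scores = torch.einsum("bhld,lrd->bhlr", query, rel_emb)
print(rel_scores.shape)  # torch.Size([1, 2, 5, 5])
```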
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = 
self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + 
) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. 
+ """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. 
+ # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
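To make the `-10000.0` mask arithmetic concrete, a toy example (shapes are assumptions, not the model's) of how a 0/1 padding mask becomes the additive mask that is summed into the attention scores:

```python
import torch

# 1 = attend, 0 = padding; two toy sequences of length 5.
attention_mask = torch.tensor([[1., 1., 1., 0., 0.],
                               [1., 1., 1., 1., 1.]])

extended = attention_mask[:, None, None, :]   # broadcastable to (B, H, Lq, Lk)
extended = (1.0 - extended) * -10000.0        # 0 where allowed, -10000 where masked

scores = torch.zeros(2, 1, 5, 5)              # pretend raw attention scores
probs = (scores + extended).softmax(dim=-1)
print(probs[0, 0, 0])                         # padded keys receive ~0 probability

# The decoder path additionally applies a causal mask; the seq_ids comparison
# above is equivalent to a lower-triangular matrix:
causal = torch.tril(torch.ones(5, 5))
```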
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + 
return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, 
config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qformer(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_qformer_depth + self.num_latents = model_args.mm_qformer_latents + self.pretrained = model_args.mm_qformer_pretrained + + self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer( + vision_tower.hidden_size, self.depth, self.num_latents + ) + + if self.pretrained is not None: + pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"] + pretrained_dict = { + k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj") + } + self.load_state_dict(pretrained_dict) + + def build_Qformer(self, vision_width, cross_attention_freq, num_query_token): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + Qformer.cls = None + 
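For orientation, a shape-level sketch of what this resampler does: a fixed set of learned query tokens cross-attends to a variable number of vision tokens, so the output length is always the number of query tokens. The stand-in below uses `nn.MultiheadAttention` with assumed sizes; it is not the BERT-based Q-Former built here, only the shape behaviour.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration: 729 SigLIP patches at 1152 dims in,
# 32 query tokens at BERT width 768 out.
batch, num_patches, vision_dim = 2, 729, 1152
num_query_tokens, hidden = 32, 768

query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden))
vision_proj = nn.Linear(vision_dim, hidden)  # stands in for the encoder_width handling
cross_attn = nn.MultiheadAttention(hidden, num_heads=12, batch_first=True)

image_features = torch.randn(batch, num_patches, vision_dim)
queries = query_tokens.expand(batch, -1, -1)
keys_values = vision_proj(image_features)
out, _ = cross_attn(queries, keys_values, keys_values)
print(out.shape)  # torch.Size([2, 32, 768]): fixed length, like query_output.last_hidden_state
```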
Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + return Qformer, query_tokens, nn.LayerNorm(vision_width) + + def forward(self, image_features, *args, **kwargs): + x = self.ln_vision(image_features) + image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device) + + query_tokens = self.query_tokens.expand(x.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=x, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return query_output.last_hidden_state + + @property + def hidden_size(self): + return 768 + + @property + def config(self): + return { + "mm_resampler_type": "qformer", + "mm_qformer_depth": self.depth, + "mm_qformer_latents": self.num_latents, + "mm_qformer_pretrained": self.pretrained, + } diff --git a/egogpt/model/multimodal_resampler/spatial_pool.py b/egogpt/model/multimodal_resampler/spatial_pool.py new file mode 100755 index 0000000000000000000000000000000000000000..f508afd3dbe7d7e508755efce617ab79e6c57229 --- /dev/null +++ b/egogpt/model/multimodal_resampler/spatial_pool.py @@ -0,0 +1,57 @@ +import math + +import torch +import torch.nn as nn + + +class SpatialPool(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.mode = model_args.mm_spatial_pool_mode + self.stride = model_args.mm_spatial_pool_stride + self.out_channels = getattr( + model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size + ) + + if self.mode == "average": + self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "max": + self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "conv": + self.pool = nn.Conv2d( + in_channels=vision_tower.hidden_size, + out_channels=self.out_channels, + kernel_size=self.stride, + stride=self.stride, + ) + else: + raise ValueError(f"Unknown pooling mode: {self.pool}.") + + def forward(self, image_features, images, *args, **kwargs): + ori_W = int( + math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]) + ) + ori_H = int(ori_W * images.shape[2] // images.shape[3]) + + B, _, F = image_features.shape + + image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute( + 0, 3, 1, 2 + ) + image_features_spatial_pool = self.pool(image_features_spatial) + + return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() + + @property + def config(self): + return { + "mm_resampler_type": "spatial_pool", + "mm_spatial_pool_stride": self.stride, + "mm_spatial_pool_mode": self.mode, + "mm_spatial_pool_out_channels": self.out_channels, + } + + @property + def hidden_size(self): + return self.out_channels diff --git a/egogpt/model/speech_encoder/__pycache__/audio.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0758dc8971bc4e46c9c378704e2a7a1cd3170af0 Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/audio.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/builder.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d49f9d1cf746f479c7045a8f10ae814815726e4a Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/builder.cpython-310.pyc differ diff --git 
a/egogpt/model/speech_encoder/__pycache__/decoding.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/decoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b14214c8a2c22ae87c6d6219d3bf5697fd58c81 Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/decoding.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/model.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffb4e62b53f59c84a499be407ae1e02b9bef437 Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/model.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14b3ff8dd478507c15e39b97496162de8c95db82 Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/timing.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/timing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e18800e02d7a37a15e981d4d902992ae13550d4b Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/timing.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/tokenizer.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6dd77cead7893888d1ae497aa24733084e1014a Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/transcribe.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/transcribe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe77b0c70295b2ad866d464afaf4be78736087b Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/transcribe.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/__pycache__/utils.cpython-310.pyc b/egogpt/model/speech_encoder/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe94df0bc4265122842e3083ed87a56121d4f0ac Binary files /dev/null and b/egogpt/model/speech_encoder/__pycache__/utils.cpython-310.pyc differ diff --git a/egogpt/model/speech_encoder/audio.py b/egogpt/model/speech_encoder/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6c66ad9dc1e508030710c40981d4a7afc93d7f --- /dev/null +++ b/egogpt/model/speech_encoder/audio.py @@ -0,0 +1,157 @@ +import os +from functools import lru_cache +from subprocess import CalledProcessError, run +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: str, sr: int = 
SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # fmt: off + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", "0", + "-i", file, + "-f", "s16le", + "-ac", "1", + "-acodec", "pcm_s16le", + "-ar", str(sr), + "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), + ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + with np.load(filters_path, allow_pickle=False) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int = 80, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/egogpt/model/speech_encoder/builder.py b/egogpt/model/speech_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..26a78aef1c039336eae00ed538904b5e20e88dfe --- /dev/null +++ b/egogpt/model/speech_encoder/builder.py @@ -0,0 +1,9 @@ +from .speech_encoder import WhisperWrappedEncoder + + +def build_speech_encoder(config): + speech_encoder_type = getattr(config, "speech_encoder_type", None) + if "whisper" in speech_encoder_type.lower(): + return WhisperWrappedEncoder(config) + + raise ValueError(f"Unknown speech encoder: {speech_encoder_type}") diff --git a/egogpt/model/speech_encoder/decoding.py b/egogpt/model/speech_encoder/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..49485d0090fc0b8bf7b052d0b9b031cd95a8a28f --- /dev/null +++ b/egogpt/model/speech_encoder/decoding.py @@ -0,0 +1,826 @@ +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import Categorical + +from .audio import CHUNK_LENGTH +from .tokenizer import Tokenizer, get_tokenizer +from .utils import compression_ratio + +if TYPE_CHECKING: + from .model import Whisper + + +@torch.no_grad() +def detect_language( + model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None +) -> Tuple[Tensor, List[dict]]: + """ + Detect the spoken language in the audio, and return them as list of strings, along with the ids + of the most probable language tokens and the probability 
distribution over all language tokens. + This is performed outside the main decode loop in order to not interfere with kv-caching. + + Returns + ------- + language_tokens : Tensor, shape = (n_audio,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]], length = n_audio + list of dictionaries containing the probability distribution over all languages. + """ + if tokenizer is None: + tokenizer = get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) + if ( + tokenizer.language is None + or tokenizer.language_token not in tokenizer.sot_sequence + ): + raise ValueError( + "This model doesn't have language tokens so it can't perform lang id" + ) + + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] + logits = model.logits(x, mel)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) + mask[list(tokenizer.all_language_tokens)] = False + logits[:, mask] = -np.inf + language_tokens = logits.argmax(dim=-1) + language_token_probs = logits.softmax(dim=-1).cpu() + language_probs = [ + { + c: language_token_probs[i, j].item() + for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) + } + for i in range(n_audio) + ] + + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + return language_tokens, language_probs + + +@dataclass(frozen=True) +class DecodingOptions: + # whether to perform X->X "transcribe" or X->English "translate" + task: str = "transcribe" + + # language that the audio is in; uses detected language if None + language: Optional[str] = None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = None # maximum number of tokens to sample + best_of: Optional[int] = None # number of independent sample trajectories, if t > 0 + beam_size: Optional[int] = None # number of beams in beam search, if t == 0 + patience: Optional[float] = None # patience in beam search (arxiv:2204.05424) + + # "alpha" in Google NMT, or None for length norm, when ranking generations + # to select which to return among the beams or best-of-N samples + length_penalty: Optional[float] = None + + # text or tokens to feed as the prompt or the prefix; for more info: + # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt: Optional[Union[str, List[int]]] = None # for the previous context + prefix: Optional[Union[str, List[int]]] = None # to prefix the current context + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + suppress_blank: bool = True # this will suppress blank outputs + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 + + # implementation details + fp16: bool = True # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: Tensor + language: str 
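+ # The fields below default to empty/NaN placeholders and are populated by DecodingTask.run().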
+ language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + """Perform a forward pass on the decoder and return per-token logits""" + raise NotImplementedError + + def rearrange_kv_cache(self, source_indices) -> None: + """Update the key-value cache according to the updated beams""" + raise NotImplementedError + + def cleanup_caching(self) -> None: + """Clean up any resources or hooks after decoding is finished""" + pass + + +class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length + self.kv_cache = {} + self.hooks = [] + + key_modules = [block.attn.key for block in self.model.decoder.blocks] + value_modules = [block.attn.value for block in self.model.decoder.blocks] + self.kv_modules = key_modules + value_modules + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + if not self.kv_cache: + self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() + + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) + + def cleanup_caching(self): + for hook in self.hooks: + hook.remove() + + self.kv_cache = {} + self.hooks = [] + + def rearrange_kv_cache(self, source_indices): + if source_indices != list(range(len(source_indices))): + for module in self.kv_modules: + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = self.kv_cache[module][source_indices].detach() + + +class SequenceRanker: + def rank( + self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]] + ) -> List[int]: + """ + Given a list of groups of samples and their cumulative log probabilities, + return the indices of the samples in each group to select as the final result + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """ + Select the sample with the highest log probabilities, penalized using either + a simple length normalization or Google NMT paper's length penalty + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + penalty = length + else: + # from the Google NMT paper + penalty = ((5 + length) / 6) ** self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + def reset(self): + """Initialize any stateful variables for decoding a new sequence""" + + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: + """Specify how to select the next token, based on the current trace and logits + + Parameters + ---------- + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence 
tokens + + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : Tensor, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Tensor, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + """ + raise NotImplementedError + + def finalize( + self, tokens: Tensor, sum_logprobs: Tensor + ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: + """Finalize search and return the final candidate sequences + + Parameters + ---------- + tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : Tensor, shape = (n_audio, n_group) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[Tensor]], length = n_audio + sequence of Tensors containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = n_audio + sequence of cumulative log probabilities corresponding to the above + + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: + if self.temperature == 0: + next_tokens = logits.argmax(dim=-1) + else: + next_tokens = Categorical(logits=logits / self.temperature).sample() + + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + + next_tokens[tokens[:, -1] == self.eot] = self.eot + tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) + + completed = (tokens[:, -1] == self.eot).all() + return tokens, completed + + def finalize(self, tokens: Tensor, sum_logprobs: Tensor): + # make sure each sequence has at least one EOT token at the end + tokens = F.pad(tokens, (0, 1), value=self.eot) + return tokens, sum_logprobs.tolist() + + +class BeamSearchDecoder(TokenDecoder): + def __init__( + self, + beam_size: int, + eot: int, + inference: Inference, + patience: Optional[float] = None, + ): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert ( + self.max_candidates > 0 + ), f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + n_audio = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: # for the first update + self.finished_sequences = [{} for _ in range(n_audio)] + + logprobs = F.log_softmax(logits.float(), dim=-1) + next_tokens, source_indices, finished_sequences = [], [], [] + for i in range(n_audio): + scores, sources, finished = {}, {}, {} + + # STEP 1: calculate the cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tokens[idx].tolist() + for logprob, token in 
zip(*logprobs[idx].topk(self.beam_size + 1)): + new_logprob = (sum_logprobs[idx] + logprob).item() + sequence = tuple(prefix + [token.item()]) + scores[sequence] = new_logprob + sources[sequence] = idx + + # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + sum_logprobs[len(next_tokens)] = scores[sequence] + next_tokens.append(sequence) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + tokens = torch.tensor(next_tokens, device=tokens.device) + self.inference.rearrange_kv_cache(source_indices) + + # add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip( + self.finished_sequences, finished_sequences + ): + for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break # the candidate list is full + previously_finished[seq] = newly_finished[seq] + + # mark as completed if all audio has enough number of samples + completed = all( + len(sequences) >= self.max_candidates + for sequences in self.finished_sequences + ) + return tokens, completed + + def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): + # collect all finished sequences, including patience, and add unfinished ones if not enough + sum_logprobs = sum_logprobs.cpu() + for i, sequences in enumerate(self.finished_sequences): + if ( + len(sequences) < self.beam_size + ): # when not enough sequences are finished + for j in list(np.argsort(sum_logprobs[i]))[::-1]: + sequence = preceding_tokens[i, j].tolist() + [self.eot] + sequences[tuple(sequence)] = sum_logprobs[i][j].item() + if len(sequences) >= self.beam_size: + break + + tokens: List[List[Tensor]] = [ + [torch.tensor(seq) for seq in sequences.keys()] + for sequences in self.finished_sequences + ] + sum_logprobs: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens, sum_logprobs + + +class LogitFilter: + def apply(self, logits: Tensor, tokens: Tensor) -> None: + """Apply any filtering or masking to logits in-place + + Parameters + ---------- + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + def __init__(self, tokenizer: Tokenizer, sample_begin: int): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + + def apply(self, logits: Tensor, tokens: Tensor): + if tokens.shape[1] == self.sample_begin: + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf + + +class SuppressTokens(LogitFilter): + def __init__(self, suppress_tokens: Sequence[int]): + self.suppress_tokens = list(suppress_tokens) + + def apply(self, logits: Tensor, tokens: Tensor): + logits[:, self.suppress_tokens] = -np.inf + + +class ApplyTimestampRules(LogitFilter): + def __init__( + self, + tokenizer: Tokenizer, + sample_begin: int, + max_initial_timestamp_index: Optional[int], + ): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index 
= max_initial_timestamp_index + + def apply(self, logits: Tensor, tokens: Tensor): + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + logits[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + sampled_tokens = tokens[k, self.sample_begin :] + seq = [t for t in sampled_tokens.tolist()] + last_was_timestamp = ( + len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin + ) + penultimate_was_timestamp = ( + len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + ) + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + logits[k, self.tokenizer.timestamp_begin :] = -np.inf + else: # cannot be normal text tokens + logits[k, : self.tokenizer.eot] = -np.inf + + timestamps = sampled_tokens[ + sampled_tokens.ge(self.tokenizer.timestamp_begin) + ] + if timestamps.numel() > 0: + # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last + # also force each segment to have a nonzero length, to prevent infinite looping + if last_was_timestamp and not penultimate_was_timestamp: + timestamp_last = timestamps[-1] + else: + timestamp_last = timestamps[-1] + 1 + logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf + + if tokens.shape[1] == self.sample_begin: + # suppress generating non-timestamp tokens at the beginning + logits[:, : self.tokenizer.timestamp_begin] = -np.inf + + # apply the `max_initial_timestamp` option + if self.max_initial_timestamp_index is not None: + last_allowed = ( + self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + ) + logits[:, last_allowed + 1 :] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = F.log_softmax(logits.float(), dim=-1) + for k in range(tokens.shape[0]): + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( + dim=-1 + ) + max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() + if timestamp_logprob > max_text_token_logprob: + logits[k, : self.tokenizer.timestamp_begin] = -np.inf + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model: "Whisper", options: DecodingOptions): + self.model = model + + language = options.language or "en" + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, + ) + self.tokenizer: Tokenizer = tokenizer + self.options: DecodingOptions = self._verify_options(options) + + self.n_group: int = options.beam_size or options.best_of or 1 + self.n_ctx: int = model.dims.n_text_ctx + self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + + self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + if self.options.without_timestamps: + self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(tokenizer.sot) + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = PyTorchInference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = 
MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + self.decoder = BeamSearchDecoder( + options.beam_size, tokenizer.eot, self.inference, options.patience + ) + else: + self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) + if self.options.suppress_tokens: + self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) + if not options.without_timestamps: + precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + max_initial_timestamp_index = None + if options.max_initial_timestamp: + max_initial_timestamp_index = round( + self.options.max_initial_timestamp / precision + ) + self.logit_filters.append( + ApplyTimestampRules( + tokenizer, self.sample_begin, max_initial_timestamp_index + ) + ) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of is not None: + raise ValueError("best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not ( + 0 <= options.length_penalty <= 1 + ): + raise ValueError("length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + + if prefix := self.options.prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip()) + if isinstance(prefix, str) + else prefix + ) + if self.sample_len is not None: + max_prefix_len = self.n_ctx // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt := self.options.prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip()) + if isinstance(prompt, str) + else prompt + ) + tokens = ( + [self.tokenizer.sot_prev] + + prompt_tokens[-(self.n_ctx // 2 - 1) :] + + tokens + ) + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend( + [ + self.tokenizer.transcribe, + self.tokenizer.translate, + self.tokenizer.sot, + self.tokenizer.sot_prev, + self.tokenizer.sot_lm, + ] + ) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: Tensor): + if self.options.fp16: + mel = mel.half() + + if mel.shape[-2:] == ( + self.model.dims.n_audio_ctx, + self.model.dims.n_audio_state, + ): + # 
encoded audio features are given; skip audio encoding + audio_features = mel + else: + audio_features = self.model.encoder(mel) + + if audio_features.dtype != ( + torch.float16 if self.options.fp16 else torch.float32 + ): + return TypeError( + f"audio_features has an incorrect dtype: {audio_features.dtype}" + ) + + return audio_features + + def _detect_language(self, audio_features: Tensor, tokens: Tensor): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language( + audio_features, self.tokenizer + ) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + tokens[:, self.sot_index + 1] = lang_tokens # write language tokens + + return languages, lang_probs + + def _main_loop(self, audio_features: Tensor, tokens: Tensor): + n_batch = tokens.shape[0] + sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) + no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + + if ( + i == 0 and self.tokenizer.no_speech is not None + ): # save no_speech_probs + probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + + if completed or tokens.shape[-1] > self.n_ctx: + break + finally: + self.inference.cleanup_caching() + + return tokens, sum_logprobs, no_speech_probs + + @torch.no_grad() + def run(self, mel: Tensor) -> List[DecodingResult]: + self.decoder.reset() + tokenizer: Tokenizer = self.tokenizer + n_audio: int = mel.shape[0] + + audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass + tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) + + # detect language if requested, overwriting the language token + languages, language_probs = self._detect_language(audio_features, tokens) + if self.options.task == "lang_id": + return [ + DecodingResult( + audio_features=features, language=language, language_probs=probs + ) + for features, language, probs in zip( + audio_features, languages, language_probs + ) + ] + + # repeat text tensors by the group size, for beam search or best-of-n sampling + tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) + + # call the main sampling loop + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + + # reshape the tensors to have (n_audio, n_group) as the first two dimensions + audio_features = audio_features[:: self.n_group] + no_speech_probs = no_speech_probs[:: self.n_group] + assert audio_features.shape[0] == len(no_speech_probs) == n_audio + + tokens = tokens.reshape(n_audio, self.n_group, -1) + sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens: List[List[Tensor]] = [ + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] + 
for s in tokens + ] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] + + sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[float] = [ + lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs) + ] + + fields = ( + texts, + languages, + tokens, + audio_features, + avg_logprobs, + no_speech_probs, + ) + if len(set(map(len, fields))) != 1: + raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") + + return [ + DecodingResult( + audio_features=features, + language=language, + tokens=tokens, + text=text, + avg_logprob=avg_logprob, + no_speech_prob=no_speech_prob, + temperature=self.options.temperature, + compression_ratio=compression_ratio(text), + ) + for text, language, tokens, features, avg_logprob, no_speech_prob in zip( + *fields + ) + ] + + +@torch.no_grad() +def decode( + model: "Whisper", + mel: Tensor, + options: DecodingOptions = DecodingOptions(), + **kwargs, +) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). + + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) + A tensor containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + if single := mel.ndim == 2: + mel = mel.unsqueeze(0) + + if kwargs: + options = replace(options, **kwargs) + + result = DecodingTask(model, options).run(mel) + + return result[0] if single else result diff --git a/egogpt/model/speech_encoder/model.py b/egogpt/model/speech_encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3b93b8b8bd752e77529a254782c94601b4e7502a --- /dev/null +++ b/egogpt/model/speech_encoder/model.py @@ -0,0 +1,345 @@ +import base64 +import gzip +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Dict, Iterable, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function +from .transcribe import transcribe as transcribe_function + +try: + from torch.nn.functional import scaled_dot_product_attention + + SDPA_AVAILABLE = True +except (ImportError, RuntimeError, OSError): + scaled_dot_product_attention = None + SDPA_AVAILABLE = False + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x).type(x.dtype) # Choiszt fix + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: + return 
super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +@contextmanager +def disable_sdpa(): + prev_state = MultiHeadAttention.use_sdpa + try: + MultiHeadAttention.use_sdpa = False + yield + finally: + MultiHeadAttention.use_sdpa = prev_state + + +class MultiHeadAttention(nn.Module): + use_sdpa = True + + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa: + a = scaled_dot_product_attention( + q, k, v, is_causal=mask is not None and n_ctx > 1 + ) + out = a.permute(0, 2, 1, 3).flatten(start_dim=2) + qk = None + else: + qk = (q * scale) @ (k * scale).transpose(-1, -2) + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + qk = qk.detach() + + return out, qk + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = ( + MultiHeadAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + 
self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = ( + self.token_embedding(x) + + self.positional_embedding[offset : offset + x.shape[-1]] + ) + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. 
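+ # The default mask marks every head in the upper half of the decoder; it is
+ # registered as a sparse, non-persistent buffer so it follows the module across
+ # devices but is not written to checkpoints.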
+ all_heads = torch.zeros( + self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool + ) + all_heads[self.dims.n_text_layer // 2 :] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + + def set_alignment_heads(self, dump: bytes): + array = np.frombuffer( + gzip.decompress(base64.b85decode(dump)), dtype=bool + ).copy() + mask = torch.from_numpy(array).reshape( + self.dims.n_text_layer, self.dims.n_text_head + ) + self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder(tokens, audio_features) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return next(self.parameters()).device + + @property + def is_multilingual(self): + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.dims.n_text_ctx: + # save as-is, for the first token or cross attention + cache[module] = output + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + detect_language = detect_language_function + transcribe = transcribe_function + decode = decode_function diff --git a/egogpt/model/speech_encoder/speech_encoder.py b/egogpt/model/speech_encoder/speech_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e37ee3e0a5c4d914c9a04684da30b0cfa61784 --- /dev/null +++ b/egogpt/model/speech_encoder/speech_encoder.py @@ -0,0 +1,198 @@ +# Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py + +import types + +import deepspeed +import torch +import torch.nn as nn +import torch.nn.functional as F + +from egogpt.utils import rank0_print + +from .model import ModelDimensions, Whisper + + +def load_zero_partitions( + model, + state_dict, + is_deepspeed_zero3_enabled, + pretrained_model_path, + ignore_mismatched_sizes=False, +): + """ + adept from pytorch lightning and transformers + with deepspeed.zero.Init(): + model = MyModel() + state_dict = torch.load(model_path, map_location="cpu") + load_zero_partitions(model, prefix="") + """ + + # because zero3 puts placeholders in model params, this context + # 
manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + loaded_keys = list(state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape + != model_state_dict[model_key].shape + ): + mismatched_keys.append( + ( + checkpoint_key, + state_dict[checkpoint_key].shape, + model_state_dict[model_key].shape, + ) + ) + del state_dict[checkpoint_key] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters( + list(module.parameters(recurse=False)), modifier_rank=0 + ): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + load(model_to_load, prefix=start_prefix) + del state_dict + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + raise RuntimeError( + f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}" + ) + if len(unexpected_keys) > 0: + rank0_print( + f"Some weights of the model checkpoint at {pretrained_model_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." 
+ ) + else: + rank0_print( + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" + ) + if len(missing_keys) > 0: + rank0_print( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + rank0_print( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + rank0_print( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + +class WhisperWrappedEncoder(nn.Module): + def __init__(self, config, delay_load=False): + super().__init__() + + self.is_loaded = False + self.speech_encoder_name = config.speech_encoder + + if not delay_load: + rank0_print(f"Loading speech encoder: {self.speech_encoder_name}") + self.load_model(config) + + def load_model(self, model_config): + if self.is_loaded: + print( + "{} is already loaded, `load_model` called again, skipping.".format( + self.speech_encoder_name + ) + ) + return + + def replace_layer_norm(module): + from whisper.model import LayerNorm + + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm( + child.normalized_shape, + eps=child.eps, + elementwise_affine=child.elementwise_affine, + ) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + # import whisper + # self.encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder + checkpoint = torch.load(self.speech_encoder_name, map_location="cpu") + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + deepspeed3_enabled = True + # print(deepspeed3_enabled) + load_zero_partitions( + model, + checkpoint["model_state_dict"], + deepspeed3_enabled, + self.speech_encoder_name, + ) + self.encoder = model.encoder + replace_layer_norm(self.encoder) + self.encoder.requires_grad_(False) + + self.is_loaded = True + + def forward(self, audio): + return self.encoder(audio) diff --git a/egogpt/model/speech_encoder/timing.py b/egogpt/model/speech_encoder/timing.py new file mode 100644 index 0000000000000000000000000000000000000000..e5634142bddd914c4b949bbd8113e17224a19f87 --- /dev/null +++ b/egogpt/model/speech_encoder/timing.py @@ -0,0 +1,388 @@ +import itertools +import subprocess +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, List + +import numba +import numpy as np +import torch +import torch.nn.functional as F + +from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND +from .tokenizer import Tokenizer + +if TYPE_CHECKING: + 
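+ # Imported only for type annotations; importing .model at runtime would risk a
+ # circular import, since the model module pulls in the transcription code that
+ # uses this module.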
from .model import Whisper + + +def median_filter(x: torch.Tensor, filter_width: int): + """Apply a median filter of width `filter_width` along the last dimension of `x`""" + pad_width = filter_width // 2 + if x.shape[-1] <= pad_width: + # F.pad requires the padding width to be smaller than the input dimension + return x + + if (ndim := x.ndim) <= 2: + # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D + x = x[None, None, :] + + assert ( + filter_width > 0 and filter_width % 2 == 1 + ), "`filter_width` should be an odd number" + + result = None + x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") + if x.is_cuda: + try: + from .triton_ops import median_filter_cuda + + result = median_filter_cuda(x, filter_width) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower median kernel implementation..." + ) + + if result is None: + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] + + if ndim <= 2: + result = result[0, 0] + + return result + + +@numba.jit(nopython=True) +def backtrace(trace: np.ndarray): + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + result = [] + while i > 0 or j > 0: + result.append((i - 1, j - 1)) + + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise ValueError("Unexpected trace[i, j]") + + result = np.array(result) + return result[::-1, :].T + + +@numba.jit(nopython=True, parallel=True) +def dtw_cpu(x: np.ndarray): + N, M = x.shape + cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf + trace = -np.ones((N + 1, M + 1), dtype=np.float32) + + cost[0, 0] = 0 + for j in range(1, M + 1): + for i in range(1, N + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = x[i - 1, j - 1] + c + trace[i, j] = t + + return backtrace(trace) + + +def dtw_cuda(x, BLOCK_SIZE=1024): + from .triton_ops import dtw_kernel + + M, N = x.shape + assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" + + x_skew = ( + F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + ) + x_skew = x_skew.T.contiguous() + cost = torch.ones(N + M + 2, M + 2) * np.inf + cost[0, 0] = 0 + cost = cost.cuda() + trace = torch.zeros_like(cost, dtype=torch.int32) + + dtw_kernel[(1,)]( + cost, + trace, + x_skew, + x_skew.stride(0), + cost.stride(0), + trace.stride(0), + N, + M, + BLOCK_SIZE=BLOCK_SIZE, + ) + + trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ + :, : N + 1 + ] + return backtrace(trace.cpu().numpy()) + + +def dtw(x: torch.Tensor) -> np.ndarray: + if x.is_cuda: + try: + return dtw_cuda(x) + except (RuntimeError, subprocess.CalledProcessError): + warnings.warn( + "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " + "falling back to a slower DTW implementation..." 
+ ) + + return dtw_cpu(x.double().cpu().numpy()) + + +@dataclass +class WordTiming: + word: str + tokens: List[int] + start: float + end: float + probability: float + + +def find_alignment( + model: "Whisper", + tokenizer: Tokenizer, + text_tokens: List[int], + mel: torch.Tensor, + num_frames: int, + *, + medfilt_width: int = 7, + qk_scale: float = 1.0, +) -> List[WordTiming]: + if len(text_tokens) == 0: + return [] + + tokens = torch.tensor( + [ + *tokenizer.sot_sequence, + tokenizer.no_timestamps, + *text_tokens, + tokenizer.eot, + ] + ).to(model.device) + + # install hooks on the cross attention layers to retrieve the attention weights + QKs = [None] * model.dims.n_text_layer + hooks = [ + block.cross_attn.register_forward_hook( + lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) + ) + for i, block in enumerate(model.decoder.blocks) + ] + + from .model import disable_sdpa + + with torch.no_grad(), disable_sdpa(): + logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] + sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] + token_probs = sampled_logits.softmax(dim=-1) + text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] + text_token_probs = text_token_probs.tolist() + + for hook in hooks: + hook.remove() + + # heads * tokens * frames + weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) + weights = weights[:, :, : num_frames // 2] + weights = (weights * qk_scale).softmax(dim=-1) + std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = (weights - mean) / std + weights = median_filter(weights, medfilt_width) + + matrix = weights.mean(axis=0) + matrix = matrix[len(tokenizer.sot_sequence) : -1] + text_indices, time_indices = dtw(-matrix) + + words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / TOKENS_PER_SECOND + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return [ + WordTiming(word, tokens, start, end, probability) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + + +def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous.word.startswith(" ") and previous.word.strip() in prepended: + # prepend it to the following word + following.word = previous.word + following.word + following.tokens = previous.tokens + following.tokens + previous.word = "" + previous.tokens = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous.word.endswith(" ") and following.word in appended: + # append it to the previous word + previous.word = previous.word + 
following.word + previous.tokens = previous.tokens + following.tokens + following.word = "" + following.tokens = [] + else: + i = j + j += 1 + + +def add_word_timestamps( + *, + segments: List[dict], + model: "Whisper", + tokenizer: Tokenizer, + mel: torch.Tensor, + num_frames: int, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + last_speech_timestamp: float, + **kwargs, +): + if len(segments) == 0: + return + + text_tokens_per_segment = [ + [token for token in segment["tokens"] if token < tokenizer.eot] + for segment in segments + ] + + text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) + alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) + word_durations = np.array([t.end - t.start for t in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i].end - alignment[i].start > max_duration: + if alignment[i].word in sentence_end_marks: + alignment[i].end = alignment[i].start + max_duration + elif alignment[i - 1].word in sentence_end_marks: + alignment[i].start = alignment[i].end - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + + time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE + word_index = 0 + + for segment, text_tokens in zip(segments, text_tokens_per_segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignment) and saved_tokens < len(text_tokens): + timing = alignment[word_index] + + if timing.word: + words.append( + dict( + word=timing.word, + start=round(time_offset + timing.start, 2), + end=round(time_offset + timing.end, 2), + probability=timing.probability, + ) + ) + + saved_tokens += len(timing.tokens) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. 
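+            # (mirror of the start-timestamp adjustment above: clamp the last word's end only
+            # when it runs more than 0.5s past the segment end; otherwise trust the word-level
+            # timing and extend the segment end instead)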
+ if ( + segment["end"] > words[-1]["start"] + and segment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) + else: + segment["end"] = words[-1]["end"] + + last_speech_timestamp = segment["end"] + + segment["words"] = words diff --git a/egogpt/model/speech_encoder/tokenizer.py b/egogpt/model/speech_encoder/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2af837570b1e6e87457179f07d63d046cbe0d4f3 --- /dev/null +++ b/egogpt/model/speech_encoder/tokenizer.py @@ -0,0 +1,395 @@ +import base64 +import os +import string +from dataclasses import dataclass, field +from functools import cached_property, lru_cache +from typing import Dict, List, Optional, Tuple + +import tiktoken + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", + "mandarin": "zh", +} + + +@dataclass +class Tokenizer: + """A thin wrapper around `tiktoken` providing quick access to special tokens""" + + encoding: tiktoken.Encoding + num_languages: int + language: Optional[str] = None + task: Optional[str] = None + sot_sequence: Tuple[int] = () + special_tokens: Dict[str, int] = field(default_factory=dict) + + def __post_init__(self): + for special in self.encoding.special_tokens_set: + special_token = 
self.encoding.encode_single_token(special) + self.special_tokens[special] = special_token + + sot: int = self.special_tokens["<|startoftranscript|>"] + translate: int = self.special_tokens["<|translate|>"] + transcribe: int = self.special_tokens["<|transcribe|>"] + + langs = tuple(LANGUAGES.keys())[: self.num_languages] + sot_sequence = [sot] + if self.language is not None: + sot_sequence.append(sot + 1 + langs.index(self.language)) + if self.task is not None: + task_token: int = transcribe if self.task == "transcribe" else translate + sot_sequence.append(task_token) + + self.sot_sequence = tuple(sot_sequence) + + def encode(self, text, **kwargs): + return self.encoding.encode(text, **kwargs) + + def decode(self, token_ids: List[int], **kwargs) -> str: + token_ids = [t for t in token_ids if t < self.timestamp_begin] + return self.encoding.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str: + """ + Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. + This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". + """ + return self.encoding.decode(token_ids, **kwargs) + + @cached_property + def eot(self) -> int: + return self.encoding.eot_token + + @cached_property + def transcribe(self) -> int: + return self.special_tokens["<|transcribe|>"] + + @cached_property + def translate(self) -> int: + return self.special_tokens["<|translate|>"] + + @cached_property + def sot(self) -> int: + return self.special_tokens["<|startoftranscript|>"] + + @cached_property + def sot_lm(self) -> int: + return self.special_tokens["<|startoflm|>"] + + @cached_property + def sot_prev(self) -> int: + return self.special_tokens["<|startofprev|>"] + + @cached_property + def no_speech(self) -> int: + return self.special_tokens["<|nospeech|>"] + + @cached_property + def no_timestamps(self) -> int: + return self.special_tokens["<|notimestamps|>"] + + @cached_property + def timestamp_begin(self) -> int: + return self.special_tokens["<|0.00|>"] + + @cached_property + def language_token(self) -> int: + """Returns the token id corresponding to the value of the `language` field""" + if self.language is None: + raise ValueError("This tokenizer does not have language token configured") + + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): + return token + + raise KeyError(f"Language {language} not found in tokenizer.") + + @cached_property + def all_language_tokens(self) -> Tuple[int]: + result = [] + for token, token_id in self.special_tokens.items(): + if token.strip("<|>") in LANGUAGES: + result.append(token_id) + return tuple(result)[: self.num_languages] + + @cached_property + def all_language_codes(self) -> Tuple[str]: + return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) + + @cached_property + def sot_sequence_including_notimestamps(self) -> Tuple[int]: + return tuple(list(self.sot_sequence) + [self.no_timestamps]) + + @cached_property + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech + annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 
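+
+        The returned ids are intended to be added to the decoder's token-suppression list;
+        a `suppress_tokens` value of "-1" is typically expanded to include them.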
+ """ + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ( + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + ) + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. + miscellaneous = set("♩♪♫♬♭♮♯") + assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) + + # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} + for symbol in symbols + list(miscellaneous): + for tokens in [ + self.encoding.encode(symbol), + self.encoding.encode(" " + symbol), + ]: + if len(tokens) == 1 or symbol in miscellaneous: + result.add(tokens[0]) + + return tuple(sorted(result)) + + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + + return self.split_tokens_on_spaces(tokens) + + def split_tokens_on_unicode(self, tokens: List[int]): + decoded_full = self.decode_with_timestamps(tokens) + replacement_char = "\ufffd" + + words = [] + word_tokens = [] + current_tokens = [] + unicode_offset = 0 + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + + if ( + replacement_char not in decoded + or decoded_full[unicode_offset + decoded.index(replacement_char)] + == replacement_char + ): + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + unicode_offset += len(decoded) + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + + +@lru_cache(maxsize=None) +def get_encoding(name: str = "gpt2", num_languages: int = 99): + vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") + ranks = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in open(vocab_path) if line) + } + n_vocab = len(ranks) + special_tokens = {} + + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + + for token in specials: + special_tokens[token] = n_vocab + n_vocab += 1 + + return tiktoken.Encoding( + name=os.path.basename(vocab_path), + explicit_n_vocab=n_vocab, + pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| 
?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + num_languages: int = 99, + language: Optional[str] = None, + task: Optional[str] = None, # Literal["transcribe", "translate", None] +) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise ValueError(f"Unsupported language: {language}") + + if multilingual: + encoding_name = "multilingual" + language = language or "en" + task = task or "transcribe" + else: + encoding_name = "gpt2" + language = None + task = None + + encoding = get_encoding(name=encoding_name, num_languages=num_languages) + + return Tokenizer( + encoding=encoding, num_languages=num_languages, language=language, task=task + ) diff --git a/egogpt/model/speech_encoder/transcribe.py b/egogpt/model/speech_encoder/transcribe.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1240bd6af301ac95d9c2cab4d4a13a00daf5fb --- /dev/null +++ b/egogpt/model/speech_encoder/transcribe.py @@ -0,0 +1,605 @@ +import argparse +import os +import traceback +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +import torch +import tqdm + +from .audio import ( + FRAMES_PER_SECOND, + HOP_LENGTH, + N_FRAMES, + N_SAMPLES, + SAMPLE_RATE, + log_mel_spectrogram, + pad_or_trim, +) +from .decoding import DecodingOptions, DecodingResult +from .timing import add_word_timestamps +from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer +from .utils import ( + exact_div, + format_timestamp, + get_end, + get_writer, + make_safe, + optional_float, + optional_int, + str2bool, +) + +if TYPE_CHECKING: + from .model import Whisper + + +def transcribe( + model: "Whisper", + audio: Union[str, np.ndarray, torch.Tensor], + *, + verbose: Optional[bool] = None, + temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + **decode_options, +): + """ + Transcribe an audio file using Whisper + + Parameters + ---------- + model: Whisper + The Whisper model instance + + audio: Union[str, np.ndarray, torch.Tensor] + The path to the audio file to open, or the audio waveform + + verbose: bool + Whether to display the text being decoded to the console. If True, displays all the details, + If False, displays minimal details. If None, does not display anything + + temperature: Union[float, Tuple[float, ...]] + Temperature for sampling. It can be a tuple of temperatures, which will be successively used + upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 
+ + compression_ratio_threshold: float + If the gzip compression ratio is above this value, treat as failed + + logprob_threshold: float + If the average log probability over sampled tokens is below this value, treat as failed + + no_speech_threshold: float + If the no_speech probability is higher than this value AND the average log probability + over sampled tokens is below `logprob_threshold`, consider the segment as silent + + condition_on_previous_text: bool + if True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone to + getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and dynamic time warping, + and include the timestamps for each word in each segment. + + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the previous word + + initial_prompt: Optional[str] + Optional text to provide as a prompt for the first window. This can be used to provide, or + "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns + to make it more likely to predict those word correctly. + + decode_options: dict + Keyword arguments to construct `DecodingOptions` instances + + clip_timestamps: Union[str, List[float]] + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. + The last end timestamp defaults to the end of the file. + + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold (in seconds) + when a possible hallucination is detected + + Returns + ------- + A dictionary containing the resulting text ("text") and segment-level details ("segments"), and + the spoken language ("language"), which is detected when `decode_options["language"]` is None. + """ + dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 + if model.device == torch.device("cpu"): + if torch.cuda.is_available(): + warnings.warn("Performing inference on CPU when CUDA is available") + if dtype == torch.float16: + warnings.warn("FP16 is not supported on CPU; using FP32 instead") + dtype = torch.float32 + + if dtype == torch.float32: + decode_options["fp16"] = False + + # Pad 30-seconds of silence to the input audio, for slicing + mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) + content_frames = mel.shape[-1] - N_FRAMES + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + + if decode_options.get("language", None) is None: + if not model.is_multilingual: + decode_options["language"] = "en" + else: + if verbose: + print( + "Detecting language using up to the first 30 seconds. 
Use `--language` to specify the language" + ) + mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) + _, probs = model.detect_language(mel_segment) + decode_options["language"] = max(probs, key=probs.get) + if verbose is not None: + print( + f"Detected language: {LANGUAGES[decode_options['language']].title()}" + ) + + language: str = decode_options["language"] + task: str = decode_options.get("task", "transcribe") + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) + + if isinstance(clip_timestamps, str): + clip_timestamps = [ + float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) + ] + seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] + if len(seek_points) == 0: + seek_points.append(0) + if len(seek_points) % 2 == 1: + seek_points.append(content_frames) + seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) + + punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + + if word_timestamps and task == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + + def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: + temperatures = ( + [temperature] if isinstance(temperature, (int, float)) else temperature + ) + decode_result = None + + for t in temperatures: + kwargs = {**decode_options} + if t > 0: + # disable beam_size and patience when t > 0 + kwargs.pop("beam_size", None) + kwargs.pop("patience", None) + else: + # disable best_of when t == 0 + kwargs.pop("best_of", None) + + options = DecodingOptions(**kwargs, temperature=t) + decode_result = model.decode(segment, options) + + needs_fallback = False + if ( + compression_ratio_threshold is not None + and decode_result.compression_ratio > compression_ratio_threshold + ): + needs_fallback = True # too repetitive + if ( + logprob_threshold is not None + and decode_result.avg_logprob < logprob_threshold + ): + needs_fallback = True # average log probability is too low + if ( + no_speech_threshold is not None + and decode_result.no_speech_prob > no_speech_threshold + ): + needs_fallback = False # silence + if not needs_fallback: + break + + return decode_result + + clip_idx = 0 + seek = seek_clips[clip_idx][0] + input_stride = exact_div( + N_FRAMES, model.dims.n_audio_ctx + ) # mel frames per output token: 2 + time_precision = ( + input_stride * HOP_LENGTH / SAMPLE_RATE + ) # time per output token: 0.02 (seconds) + all_tokens = [] + all_segments = [] + prompt_reset_since = 0 + + if initial_prompt is not None: + initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) + all_tokens.extend(initial_prompt_tokens) + else: + initial_prompt_tokens = [] + + def new_segment( + *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult + ): + tokens = tokens.tolist() + text_tokens = [token for token in tokens if token < tokenizer.eot] + return { + "seek": seek, + "start": start, + "end": end, + "text": tokenizer.decode(text_tokens), + "tokens": tokens, + "temperature": result.temperature, + "avg_logprob": result.avg_logprob, + "compression_ratio": result.compression_ratio, + "no_speech_prob": result.no_speech_prob, + } + + # show the progress bar when verbose is False (if True, transcribed text will be printed) + with tqdm.tqdm( + total=content_frames, unit="frames", disable=verbose is not False + ) as pbar: + last_speech_timestamp = 0.0 + # NOTE: This loop is obscurely flattened to make the diff readable. 
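+        # (each clip in `seek_clips` is processed window by window, advancing `seek` past
+        # every decoded segment)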
+ # A later commit should turn this into a simpler nested loop. + # for seek_clip_start, seek_clip_end in seek_clips: + # while seek < seek_clip_end + while clip_idx < len(seek_clips): + seek_clip_start, seek_clip_end = seek_clips[clip_idx] + if seek < seek_clip_start: + seek = seek_clip_start + if seek >= seek_clip_end: + clip_idx += 1 + if clip_idx < len(seek_clips): + seek = seek_clips[clip_idx][0] + continue + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) + mel_segment = mel[:, seek : seek + segment_size] + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) + + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) + tokens = torch.tensor(result.tokens) + + if no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > no_speech_threshold + if ( + logprob_threshold is not None + and result.avg_logprob > logprob_threshold + ): + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False + + if should_skip: + seek += segment_size # fast-forward to the next segment boundary + continue + + previous_seek = seek + current_segments = [] + + # anomalous words are very long/short/improbable + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [w for w in segment["words"] if w["word"] not in punctuation] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + + timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] + + consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + consecutive.add_(1) + if len(consecutive) > 0: + # if the output contains two consecutive timestamp tokens + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_pos = ( + sliced_tokens[0].item() - tokenizer.timestamp_begin + ) + end_timestamp_pos = ( + sliced_tokens[-1].item() - tokenizer.timestamp_begin + ) + current_segments.append( + new_segment( + start=time_offset + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, + tokens=sliced_tokens, + result=result, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. 
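+                    # i.e. the tokens end in a lone timestamp rather than a consecutive
+                    # <|t1|><|t2|> pair, so skip past the whole window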
+ seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_pos = ( + tokens[last_slice - 1].item() - tokenizer.timestamp_begin + ) + seek += last_timestamp_pos * input_stride + else: + duration = segment_duration + timestamps = tokens[timestamp_tokens.nonzero().flatten()] + if ( + len(timestamps) > 0 + and timestamps[-1].item() != tokenizer.timestamp_begin + ): + # no consecutive timestamps but it has a timestamp; use the last one. + last_timestamp_pos = ( + timestamps[-1].item() - tokenizer.timestamp_begin + ) + duration = last_timestamp_pos * time_precision + + current_segments.append( + new_segment( + start=time_offset, + end=time_offset + duration, + tokens=tokens, + result=result, + ) + ) + seek += segment_size + + if word_timestamps: + add_word_timestamps( + segments=current_segments, + model=model, + tokenizer=tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, + ) + + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and last_word_end > time_offset: + seek = round(last_word_end * FRAMES_PER_SECOND) + + # skip silence before possible hallucinations + if hallucination_silence_threshold is not None: + threshold = hallucination_silence_threshold + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and last_word_end > time_offset: + remaining_duration = window_end_time - last_word_end + if remaining_duration > threshold: + seek = round(last_word_end * FRAMES_PER_SECOND) + else: + seek = previous_seek + segment_size + + # if first segment might be a hallucination, skip leading silence + first_segment = next_words_segment(current_segments) + if first_segment is not None and is_segment_anomaly(first_segment): + gap = first_segment["start"] - time_offset + if gap > threshold: + seek = previous_seek + round(gap * FRAMES_PER_SECOND) + continue + + # skip silence before any possible hallucination that is surrounded + # by silence or more hallucinations + hal_last_end = last_speech_timestamp + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if is_segment_anomaly(segment): + next_segment = next_words_segment( + current_segments[si + 1 :] + ) + if next_segment is not None: + hal_next_start = next_segment["words"][0]["start"] + else: + hal_next_start = time_offset + segment_duration + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or is_segment_anomaly(next_segment) + or window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + seek = round( + max(time_offset + 1, segment["start"]) + * FRAMES_PER_SECOND + ) + if content_duration - segment["end"] < threshold: + seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] + + last_word_end = get_end(current_segments) + if last_word_end is not None: + last_speech_timestamp = last_word_end + + if verbose: + for segment in current_segments: + start, end, text = segment["start"], segment["end"], segment["text"] + line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" + print(make_safe(line)) + + # if a segment is instantaneous or does not 
contain text, clear it + for i, segment in enumerate(current_segments): + if segment["start"] == segment["end"] or segment["text"].strip() == "": + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + + all_segments.extend( + [ + {"id": i, **segment} + for i, segment in enumerate( + current_segments, start=len(all_segments) + ) + ] + ) + all_tokens.extend( + [token for segment in current_segments for token in segment["tokens"]] + ) + + if not condition_on_previous_text or result.temperature > 0.5: + # do not feed the prompt tokens if a high temperature was used + prompt_reset_since = len(all_tokens) + + # update progress bar + pbar.update(min(content_frames, seek) - previous_seek) + + return dict( + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), + segments=all_segments, + language=language, + ) + + +def cli(): + from . import available_models + + def valid_model_name(name): + if name in available_models() or os.path.exists(name): + return name + raise ValueError( + f"model should be one of {available_models()} or path to a model checkpoint" + ) + + # fmt: off + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") + parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use") + parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") + + parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") + parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") + + parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") + parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") + parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") + parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") + parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") + + parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters 
except common punctuations") + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") + parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") + parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") + + parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") + parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") + parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") + parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") + parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") + parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") + parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") + parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") + parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") + parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") + parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") + parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") + parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... 
timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") + parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") + # fmt: on + + args = parser.parse_args().__dict__ + model_name: str = args.pop("model") + model_dir: str = args.pop("model_dir") + output_dir: str = args.pop("output_dir") + output_format: str = args.pop("output_format") + device: str = args.pop("device") + os.makedirs(output_dir, exist_ok=True) + + if model_name.endswith(".en") and args["language"] not in {"en", "English"}: + if args["language"] is not None: + warnings.warn( + f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead." + ) + args["language"] = "en" + + temperature = args.pop("temperature") + if (increment := args.pop("temperature_increment_on_fallback")) is not None: + temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) + else: + temperature = [temperature] + + if (threads := args.pop("threads")) > 0: + torch.set_num_threads(threads) + + from . import load_model + + model = load_model(model_name, device=device, download_root=model_dir) + + writer = get_writer(output_format, output_dir) + word_options = [ + "highlight_words", + "max_line_count", + "max_line_width", + "max_words_per_line", + ] + if not args["word_timestamps"]: + for option in word_options: + if args[option]: + parser.error(f"--{option} requires --word_timestamps True") + if args["max_line_count"] and not args["max_line_width"]: + warnings.warn("--max_line_count has no effect without --max_line_width") + if args["max_words_per_line"] and args["max_line_width"]: + warnings.warn("--max_words_per_line has no effect with --max_line_width") + writer_args = {arg: args.pop(arg) for arg in word_options} + for audio_path in args.pop("audio"): + try: + result = transcribe(model, audio_path, temperature=temperature, **args) + writer(result, audio_path, **writer_args) + except Exception as e: + traceback.print_exc() + print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") + + +if __name__ == "__main__": + cli() diff --git a/egogpt/model/speech_encoder/utils.py b/egogpt/model/speech_encoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9b138626edba89eeebb5e5a30507489058b79e --- /dev/null +++ b/egogpt/model/speech_encoder/utils.py @@ -0,0 +1,316 @@ +import json +import os +import re +import sys +import zlib +from typing import Callable, List, Optional, TextIO + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 
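+        # e.g. under a cp1252 console encoding, "♪ lyrics ♪" round-trips to "? lyrics ?"
+        # instead of raising UnicodeEncodeError when printed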
+ return string.encode(system_encoding, errors="replace").decode(system_encoding) + +else: + + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + +def get_start(segments: List[dict]) -> Optional[float]: + return next( + (w["start"] for s in segments for w in s["words"]), + segments[0]["start"] if segments else None, + ) + + +def get_end(segments: List[dict]) -> Optional[float]: + return next( + (w["end"] for s in reversed(segments) for w in reversed(s["words"])), + segments[-1]["end"] if segments else None, + ) + + +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__( + self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + ): + audio_basename = os.path.basename(audio_path) + audio_basename = os.path.splitext(audio_basename)[0] + output_path = os.path.join( + self.output_dir, audio_basename + "." 
+ self.extension + ) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f, options=options, **kwargs) + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for segment in result["segments"]: + print(segment["text"].strip(), file=file, flush=True) + + +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str + + def iterate_result( + self, + result: dict, + options: Optional[dict] = None, + *, + max_line_width: Optional[int] = None, + max_line_count: Optional[int] = None, + highlight_words: bool = False, + max_words_per_line: Optional[int] = None, + ): + options = options or {} + max_line_width = max_line_width or options.get("max_line_width") + max_line_count = max_line_count or options.get("max_line_count") + highlight_words = highlight_words or options.get("highlight_words", False) + max_words_per_line = max_words_per_line or options.get("max_words_per_line") + preserve_segments = max_line_count is None or max_line_width is None + max_line_width = max_line_width or 1000 + max_words_per_line = max_words_per_line or 1000 + + def iterate_subtitles(): + line_len = 0 + line_count = 1 + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: List[dict] = [] + last: float = get_start(result["segments"]) or 0.0 + for segment in result["segments"]: + chunk_index = 0 + words_count = max_words_per_line + while chunk_index < len(segment["words"]): + remaining_words = len(segment["words"]) - chunk_index + if max_words_per_line > len(segment["words"]) - chunk_index: + words_count = remaining_words + for i, original_timing in enumerate( + segment["words"][chunk_index : chunk_index + words_count] + ): + timing = original_timing.copy() + long_pause = ( + not preserve_segments and timing["start"] - last > 3.0 + ) + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if ( + line_len > 0 + and has_room + and not long_pause + and not seg_break + ): + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + chunk_index += max_words_per_line + if len(subtitle) > 0: + yield subtitle + + if len(result["segments"]) > 0 and "words" in result["segments"][0]: + for subtitle in iterate_subtitles(): + subtitle_start = self.format_timestamp(subtitle[0]["start"]) + subtitle_end = self.format_timestamp(subtitle[-1]["end"]) + subtitle_text = "".join([word["word"] for word in subtitle]) + if highlight_words: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text + + yield start, end, "".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word 
+ for j, word in enumerate(all_words) + ] + ) + last = end + else: + yield subtitle_start, subtitle_end, subtitle_text + else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") + yield segment_start, segment_end, segment_text + + def format_timestamp(self, seconds: float): + return format_timestamp( + seconds=seconds, + always_include_hours=self.always_include_hours, + decimal_marker=self.decimal_marker, + ) + + +class WriteVTT(SubtitlesWriter): + extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = "." + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + print("WEBVTT\n", file=file) + for start, end, text in self.iterate_result(result, options, **kwargs): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteSRT(SubtitlesWriter): + extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = "," + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options, **kwargs), start=1 + ): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 
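+
+    Each line has the form: start-time-in-integer-ms <tab> end-time-in-integer-ms <tab> segment-text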
+ """ + + extension: str = "tsv" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + json.dump(result, file) + + +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "tsv": WriteTSV, + "json": WriteJSON, + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all( + result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for writer in all_writers: + writer(result, file, options, **kwargs) + + return write_all + + return writers[output_format](output_dir) diff --git a/egogpt/model/speech_projector/__pycache__/builder.cpython-310.pyc b/egogpt/model/speech_projector/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10b1d6b68888cc81202bd198765fc5e7ec108a6b Binary files /dev/null and b/egogpt/model/speech_projector/__pycache__/builder.cpython-310.pyc differ diff --git a/egogpt/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc b/egogpt/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835ea5f74668561da2a038661085d2b6b0742146 Binary files /dev/null and b/egogpt/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc differ diff --git a/egogpt/model/speech_projector/builder.py b/egogpt/model/speech_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..bf96772ad93b158028fa2dfa02b20fce19f743ed --- /dev/null +++ b/egogpt/model/speech_projector/builder.py @@ -0,0 +1,9 @@ +from .speech_projector import EncoderProjectorConcat + + +def build_speech_projector(config): + projector_type = getattr(config, "speech_projector_type", "linear") + if projector_type == "linear": + return EncoderProjectorConcat(config) + + raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/egogpt/model/speech_projector/speech_projector.py b/egogpt/model/speech_projector/speech_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..2338d976a3525c8ddf6860fa4e6b7bc8c643bfe1 --- /dev/null +++ b/egogpt/model/speech_projector/speech_projector.py @@ -0,0 +1,30 @@ +# Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/projector.py + + +import torch +import torch.nn as nn + + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.speech_encoder_ds_rate + self.encoder_dim = config.speech_encoder_hidden_size + self.llm_dim = config.hidden_size + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.hidden_size) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = 
x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x diff --git a/egogpt/train/__pycache__/llava_trainer.cpython-310.pyc b/egogpt/train/__pycache__/llava_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c3c092c68b8ce25a98a71bf7fbaefbb846e3207 Binary files /dev/null and b/egogpt/train/__pycache__/llava_trainer.cpython-310.pyc differ diff --git a/egogpt/train/llava_trainer.py b/egogpt/train/llava_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a12896ab066a5a090b737a9957cf6d31d7698e9 --- /dev/null +++ b/egogpt/train/llava_trainer.py @@ -0,0 +1,584 @@ +import importlib.metadata +import os +from typing import List, Optional + +import torch +import torch.nn as nn +from packaging import version +from peft import PeftModel +from torch.utils.data import Sampler +from transformers import Trainer +from transformers.trainer import ( + ALL_LAYERNORM_LAYERS, + get_parameter_names, + has_length, + is_sagemaker_mp_enabled, + logger, +) +from transformers.trainer_pt_utils import get_dataloader_sampler +from transformers.trainer_pt_utils import ( + get_length_grouped_indices as get_length_grouped_indices_hf, +) +from transformers.trainer_pt_utils import get_model_param_count, get_parameter_names +from transformers.trainer_utils import ( + HPSearchBackend, + TrainOutput, + has_length, + speed_metrics, +) +from transformers.training_args import ParallelMode +from transformers.utils import ( + is_accelerate_available, + is_peft_available, + is_sagemaker_mp_enabled, + is_torch_xla_available, +) + +TIME_STAMP = os.environ.get("TIME_STAMP", "default_value") +BYTENAS = os.environ.get("BYTENAS", "vl-research") + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, "no ignore status") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = { + k: t + for k, t in named_params + if any(key_match in k for key_match in keys_to_match) + } + to_return = { + k: maybe_zero_3(v, ignore_status=True, name=k).cpu() + for k, v in to_return.items() + } + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_variable_length_grouped_indices( + lengths, batch_size, world_size, megabatch_mult=8, generator=None +): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
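+    # sort samples by length (descending), slice into megabatches, shuffle inside each
+    # megabatch, regroup into per-step batches, and shuffle the batch order, so every
+    # batch still contains samples of roughly similar length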
+ indices = torch.randperm(len(lengths), generator=generator) + sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True) + megabatch_size = world_size * batch_size * megabatch_mult + megabatches = [ + sorted_indices[i : i + megabatch_size] + for i in range(0, len(lengths), megabatch_size) + ] + megabatches = [ + sorted(megabatch, key=lambda i: indices[i], reverse=True) + for megabatch in megabatches + ] + shuffled_indices = [i for megabatch in megabatches for i in megabatch] + world_batch_size = world_size * batch_size + batches = [ + shuffled_indices[i : i + world_batch_size] + for i in range(0, len(lengths), world_batch_size) + ] + batch_indices = torch.randperm(len(batches), generator=generator) + batches = [batches[i] for i in batch_indices] + + return [i for batch in batches for i in batch] + + +def get_modality_length_grouped_indices( + lengths, batch_size, world_size, generator=None +): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." + if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices( + lengths, batch_size, world_size, generator=generator + ) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [ + mm_indices[i] + for i in get_length_grouped_indices( + mm_lengths, batch_size, world_size, generator=None + ) + ] + lang_shuffle = [ + lang_indices[i] + for i in get_length_grouped_indices( + lang_lengths, batch_size, world_size, generator=None + ) + ] + megabatch_size = world_size * batch_size + mm_megabatches = [ + mm_shuffle[i : i + megabatch_size] + for i in range(0, len(mm_shuffle), megabatch_size) + ] + lang_megabatches = [ + lang_shuffle[i : i + megabatch_size] + for i in range(0, len(lang_shuffle), megabatch_size) + ] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices( + lengths, batch_size, world_size, generator=None, merge=True +): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. 
To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [ + indices[i : i + megabatch_size].tolist() + for i in range(0, len(lengths), megabatch_size) + ] + megabatches = [ + sorted(megabatch, key=lambda i: lengths[i], reverse=True) + for megabatch in megabatches + ] + megabatches = [ + split_to_even_chunks(megabatch, lengths, world_size) + for megabatch in megabatches + ] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_length_grouped_indices_auto_single( + lengths, batch_size, world_size, generator=None +): + indices = get_length_grouped_indices_hf( + lengths, batch_size * world_size, generator=generator + ) + + megabatch_size = world_size * batch_size + megabatches = [ + indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size) + ] + megabatches = [ + sorted(megabatch, key=lambda i: lengths[i], reverse=True) + for megabatch in megabatches + ] + megabatches = [ + split_to_even_chunks(megabatch, lengths, world_size) + for megabatch in megabatches + ] + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + batch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in batch_indices] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_modality_length_grouped_indices_auto( + lengths, batch_size, world_size, generator=None +): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." 
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices_auto_single( + lengths, batch_size, world_size, generator=generator + ) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [ + mm_indices[i] + for i in get_length_grouped_indices_auto_single( + mm_lengths, batch_size, world_size, generator=None + ) + ] + lang_shuffle = [ + lang_indices[i] + for i in get_length_grouped_indices_auto_single( + lang_lengths, batch_size, world_size, generator=None + ) + ] + megabatch_size = world_size * batch_size + mm_megabatches = [ + mm_shuffle[i : i + megabatch_size] + for i in range(0, len(mm_shuffle), megabatch_size) + ] + lang_megabatches = [ + lang_shuffle[i : i + megabatch_size] + for i in range(0, len(lang_shuffle), megabatch_size) + ] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + variable_length: bool = False, + group_by_modality: bool = False, + group_by_modality_auto: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.variable_length = variable_length + self.group_by_modality = group_by_modality + self.group_by_modality_auto = group_by_modality_auto + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.variable_length: + assert ( + not self.group_by_modality + ), "Variable length grouping is not supported with modality grouping." 
+ indices = get_variable_length_grouped_indices( + self.lengths, self.batch_size, self.world_size, generator=self.generator + ) + else: + if self.group_by_modality: + indices = get_modality_length_grouped_indices( + self.lengths, + self.batch_size, + self.world_size, + generator=self.generator, + ) + elif self.group_by_modality_auto: + indices = get_modality_length_grouped_indices_auto( + self.lengths, + self.batch_size, + self.world_size, + generator=self.generator, + ) + else: + indices = get_length_grouped_indices_auto_single( + self.lengths, + self.batch_size, + self.world_size, + generator=self.generator, + ) + return iter(indices) + + +def _is_peft_model(model): + if is_peft_available(): + classes_to_check = (PeftModel,) if is_peft_available() else () + # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 + if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): + from peft import PeftMixedModel + + classes_to_check = (*classes_to_check, PeftMixedModel) + return isinstance(model, classes_to_check) + return False + + +TRAINER_STATE_NAME = "trainer_state.json" + + +class LLaVATrainer(Trainer): + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_length: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size + * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + ) + elif self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size + * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + group_by_modality=True, + ) + elif self.args.group_by_modality_length_auto: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size + * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + group_by_modality_auto=True, + ) + elif self.args.group_by_varlen: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps + # world_size=self.args.world_size, + world_size=self.args.world_size + * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + variable_length=True, + ) + else: + return super()._get_train_sampler() + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + lr_mapper = {} + if self.args.speech_projector_lr is not None: + lr_mapper["speech_projector"] = self.args.speech_projector_lr + + if len(lr_mapper) > 0: + special_lr_parameters = [ + name + for name, _ in opt_model.named_parameters() + if any(module_keyword in name for module_keyword in lr_mapper) + ] + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in opt_model.named_parameters() + if ( + n in decay_parameters + and n not in special_lr_parameters + and p.requires_grad + ) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p + for n, p in opt_model.named_parameters() + if ( + n not in decay_parameters + and n not in special_lr_parameters + and p.requires_grad + ) + ], + "weight_decay": 0.0, + }, + ] + for module_keyword, lr in lr_mapper.items(): + module_parameters = [ + name + for name, _ in opt_model.named_parameters() + if module_keyword in name + ] + optimizer_grouped_parameters.extend( + [ + { + "params": [ + p + for n, p in opt_model.named_parameters() + if ( + n in decay_parameters + and n in module_parameters + and p.requires_grad + ) + ], + "weight_decay": self.args.weight_decay, + "lr": lr, + }, + { + "params": [ + p + for n, p in opt_model.named_parameters() + if ( + n not in decay_parameters + and n in module_parameters + and p.requires_grad + ) + ], + "weight_decay": 0.0, + "lr": lr, + }, + ] + ) + else: + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in opt_model.named_parameters() + if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p + for n, p in opt_model.named_parameters() + if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args + ) + + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, **optimizer_kwargs + ) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum( + { + p.data_ptr(): p.numel() for p in module.parameters() + }.values() + ) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override( + module, "weight", {"optim_bits": 32} + ) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ["speech_projector"] + if getattr(self.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3( + 
self.model.named_parameters(), keys_to_match + ) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save( + weight_to_save, os.path.join(output_dir, f"speech_projector.bin") + ) + else: + print("self.is_local_process_zero()", self.is_local_process_zero()) + super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + pass + super(LLaVATrainer, self)._save(output_dir, state_dict) + else: + super(LLaVATrainer, self)._save(output_dir, state_dict) diff --git a/egogpt/train/train_audio.py b/egogpt/train/train_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..2d82a5ddf6eab2eaafd1d375c8ed2fac1b5eff6b --- /dev/null +++ b/egogpt/train/train_audio.py @@ -0,0 +1,1502 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import base64 +import copy +import glob +import io +import json +import logging +import math +import os +import pathlib +import pickle +import random +import re +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + +import numpy as np +import soundfile as sf +import tokenizers +import torch +import transformers +import whisper +from packaging import version +from PIL import Image +from safetensors.torch import load_file as safetensor_load_file +from scipy.signal import resample +from torch.utils.data import Dataset + +from egogpt import conversation as conversation_lib +from egogpt.constants import ( + DEFAULT_IMAGE_TOKEN, + DEFAULT_SPEECH_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, + SPEECH_TOKEN_INDEX, +) +from egogpt.mm_utils import ( + process_anyres_image, + process_highres_image, + process_highres_image_crop_split, +) +from egogpt.model import * +from egogpt.train.llava_trainer import LLaVATrainer +from egogpt.utils import process_video_with_decord, process_video_with_decord_byframe + +local_rank = None +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse( + "0.14" +) + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + tune_speech_generator_only: bool = field(default=False) + speech_encoder: Optional[str] = field(default=None) + unfreeze_mm_speech_encoder: bool = field(default=False) + mm_vision_select_layer: Optional[int] = field( + default=-1 + ) # default to the last layer + pretrain_speech_projector: Optional[str] = field(default=None) + 
speech_projector_type: Optional[str] = field(default="linear") + speech_encoder_type: Optional[str] = field(default="whisper") + speech_encoder_config: Optional[str] = field( + default="models/speech_encoder/large-v3.pt" + ) + speech_encoder_ds_rate: Optional[int] = field(default=5) + speech_encoder_hidden_size: Optional[int] = field(default=1280) + tune_mm_mlp_adapter: bool = field(default=False) + tune_mm_vision_resampler: bool = field(default=False) + vision_tower: Optional[str] = field(default=None) + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_language_model: bool = field(default=False) + mm_vision_select_layer: Optional[int] = field( + default=-1 + ) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default="linear") + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default="flat") + mm_vision_select_feature: Optional[str] = field(default="patch") + mm_resampler_type: Optional[str] = field(default=None) + mm_mask_drop_mode: str = field(default="fixed") + mm_mask_drop_skip_percentage: float = field(default=0.0) + mm_mask_drop_ratio: float = field(default=0.25) + mm_mask_drop_ratio_upper: Optional[float] = field(default=None) + mm_mask_drop_ratio_lower: Optional[float] = field(default=None) + mm_spatial_pool_stride: Optional[int] = field(default=None) + mm_spatial_pool_mode: str = field(default="bilinear") + mm_spatial_pool_out_channels: Optional[int] = field(default=None) + mm_perceiver_depth: Optional[int] = field(default=3) + mm_perceiver_latents: Optional[int] = field(default=32) + mm_perceiver_ff_mult: Optional[float] = field(default=4) + mm_perceiver_pretrained: Optional[str] = field(default=None) + mm_qformer_depth: Optional[int] = field(default=3) + mm_qformer_latents: Optional[int] = field(default=32) + mm_qformer_pretrained: Optional[str] = field(default=None) + rope_scaling_factor: Optional[float] = field(default=None) + rope_scaling_type: Optional[str] = field(default=None) + + s2: Optional[bool] = field(default=False) + s2_scales: Optional[str] = field(default="336,672,1008") + + use_pos_skipping: Optional[bool] = field(default=False) + pos_skipping_range: Optional[int] = field(default=4096) + + mm_newline_position: Optional[str] = field(default="grid") + delay_load: Optional[bool] = field(default=True) + delay_load_audio: Optional[bool] = field(default=True) + add_faster_video: Optional[bool] = field(default=False) + faster_token_stride: Optional[int] = field(default=10) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_aspect_ratio: str = "square" + image_grid_pinpoints: Optional[str] = field(default=None) + image_crop_resolution: Optional[int] = field(default=None) + image_split_resolution: Optional[int] = field(default=None) + video_fps: Optional[int] = field(default=1) + frames_upbound: Optional[int] = field(default=100) + force_sample: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, 
+ metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={ + "help": "Compress the quantization statistics through double quantization." + }, + ) + quant_type: str = field( + default="nf4", + metadata={ + "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." + }, + ) + bits: int = field(default=16, metadata={"help": "How many bits to use."}) + lora_enable: bool = field(default=False) + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + speech_projector_lr: Optional[float] = None + gradient_checkpointing: bool = field(default=True) + mm_speech_encoder_lr: Optional[float] = None + diffusion_head_lr: Optional[float] = None + group_by_varlen: bool = field(default=False) + group_by_modality_length: bool = field(default=False) + group_by_modality_length_auto: bool = field(default=False) + min_lr_ratio: float = field(default=0.0) + sample_independently: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mm_projector_lr: Optional[float] = None + mm_vision_tower_lr: Optional[float] = None + freeze_mm_vision_resampler: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning( + f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}" + ) + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = { + k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() + } + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = { + k: t + for k, t in named_params + if any(key_match in k for key_match in keys_to_match) + } + to_return = { + k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() + } + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["speech_projector", "speech_encoder"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in 
multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ["speech_projector"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3( + trainer.model.named_parameters(), keys_to_match + ) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + speech_projector_folder = os.path.join( + parent_folder, "speech_projector" + ) + os.makedirs(speech_projector_folder, exist_ok=True) + torch.save( + weight_to_save, + os.path.join(speech_projector_folder, f"{current_folder}.bin"), + ) + else: + torch.save( + weight_to_save, os.path.join(output_dir, f"speech_projector.bin") + ) + return + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = "unknown" + sentence["value"] = ( + BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + ) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def tokenizer_speech_token( + prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None +): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if ( + len(prompt_chunks) > 0 + and len(prompt_chunks[0]) > 0 + and prompt_chunks[0][0] == tokenizer.bos_token_id + ): + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == "pt": + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f"Unsupported tensor type: {return_tensors}") + return input_ids + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + # Add speech and image special tokens to the beginning of the conversation + for source in sources: + for sentence in source: + if DEFAULT_SPEECH_TOKEN in sentence["value"]: + sentence["value"] = ( + sentence["value"].replace(DEFAULT_SPEECH_TOKEN, "").strip() + ) + sentence["value"] 
= DEFAULT_SPEECH_TOKEN + "\n" + sentence["value"] + sentence["value"] = sentence["value"].strip() + if DEFAULT_IMAGE_TOKEN in sentence["value"]: + sentence["value"] = ( + sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + ) + sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] + sentence["value"] = sentence["value"].strip() + return sources + + +def preprocess_llama_2( + sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack( + [ + tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") + for prompt in conversations + ], + dim=0, + ) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_3( + sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + assert len(source) == 2, "now only support single-turn conversation" + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack( + [ + tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") + for prompt in conversations + ], + dim=0, + ) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3 + + # Mask targets + sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n" + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + parts = conversation.split(sep) + parts[0] += sep + + if has_speech: + conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 + else: + conversation_len = len(tokenizer(conversation).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += conversation_len + target[cur_len:] = IGNORE_INDEX + + # if cur_len < tokenizer.model_max_length: + # if cur_len != total_len: + # target[:] = IGNORE_INDEX + # print( + # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ # f" (ignored)" + # ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack( + [ + tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") + for prompt in conversations + ], + dim=0, + ) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + if conv.sep_style == conversation_lib.SeparatorStyle.TWO: + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = ( + len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + ) + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + elif conv.sep_style == conversation_lib.SeparatorStyle.QWEN2: + # Mask targets + sep = "<|im_start|>assistant\n" + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + raw_rounds = conversation.split("<|im_end|>\n") + cur_len = 0 + rounds = [] + now_str = "" + for rou in raw_rounds: + if len(rou) > 0: + rou = rou + "<|im_end|>\n" + if rou.startswith("<|endoftext|>"): + rounds[-1] = rounds[-1] + "<|endoftext|>" + rou = rou.replace("<|endoftext|>", "") + if len(rou.strip()) == 0: + continue + if "<|im_start|>assistant\n" in rou: + now_str += rou + rounds.append(now_str) + now_str = "" + else: + now_str += rou + + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = ( + len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + ) + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + try: + is_legacy = tokenizer.legacy + except: + is_legacy = True + + if i != 0 and not is_legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch for QWEN2: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_SPEECH_TOKEN in source[0]["value"] + source[0]["value"] = DEFAULT_SPEECH_TOKEN + conversation = ( + source[0]["value"] + + source[1]["value"] + + conversation_lib.default_conversation.sep + ) + conversations.append(conversation) + # tokenize conversations + input_ids = [ + tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") + for prompt in conversations + ] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_speech_token(source[0]["value"], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess_qwen( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False, + has_image: bool = False, + max_len=2048, + system_message: str = "You are a helpful assistant.", +) -> Dict: + def split_text(text, keywords): + pattern = "(" + "|".join(map(re.escape, keywords)) + ")" + parts = re.split(pattern, text) + parts = [part for part in parts if part] + return parts + + roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + + # im_start, im_end = tokenizer.additional_special_tokens_ids + + im_start = tokenizer("<|im_start|>").input_ids[0] + im_end = tokenizer("<|im_end|>").input_ids[0] + nl_tokens = tokenizer("\n").input_ids + _system = tokenizer("system").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + system = ( + [im_start] + + _system + + 
tokenizer(system_message).input_ids + + [im_end] + + nl_tokens + ) + input_id += system + target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens + assert len(input_id) == len(target) + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + splited_sentence = split_text(sentence["value"], ["", ""]) + _input_id = [] + for part in splited_sentence: + _input_id += tokenizer(role).input_ids + nl_tokens # add prefix + if "" == part: + _input_id += [SPEECH_TOKEN_INDEX] + elif "" == part: + _input_id += [IMAGE_TOKEN_INDEX] + else: + _input_id += tokenizer(part).input_ids + _input_id += [im_end] + nl_tokens # add suffix + input_id += _input_id + if role == "<|im_start|>user": + _target = ( + [im_start] + + [IGNORE_INDEX] * (len(_input_id) - 3) + + [im_end] + + nl_tokens + ) + elif role == "<|im_start|>assistant": + _target = ( + [im_start] + + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + + [im_end] + + nl_tokens + ) + else: + raise NotImplementedError + target += _target + assert len(input_id) == len(target) + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False, + has_image: bool = False, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
+ """ + if ( + conversation_lib.default_conversation.sep_style + == conversation_lib.SeparatorStyle.PLAIN + ): + return preprocess_plain(sources, tokenizer, has_image=has_image) + if ( + conversation_lib.default_conversation.sep_style + == conversation_lib.SeparatorStyle.LLAMA_2 + ): + return preprocess_llama_2( + sources, tokenizer, has_speech=has_speech, has_image=has_image + ) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1( + sources, tokenizer, has_speech=has_speech, has_image=has_image + ) + if ( + conversation_lib.default_conversation.sep_style + == conversation_lib.SeparatorStyle.LLAMA_3 + ): + return preprocess_llama_3( + sources, tokenizer, has_speech=has_speech, has_image=has_image + ) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen( + sources, tokenizer, has_speech=has_speech, has_image=has_image + ) + raise NotImplementedError + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + ): + super(LazySupervisedDataset, self).__init__() + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + self.mel_size = 128 + + def __len__(self): + return len(self.list_data_dict) + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum( + len(conv["value"].split()) for conv in sample["conversations"] + ) + assert cur_len > 0, f"Conversation length is 0 for {sample}" + if "image" in sample or "video" in sample or self.data_args.early_mix_text: + length_list.append(cur_len) + else: + length_list.append(-cur_len) + return length_list + + def process_audio(self, audio_file, start_frame=None, end_frame=None, fps=20): + speech, sample_rate = sf.read(audio_file) + if start_frame is not None and end_frame is not None: + start_sample = start_frame * sample_rate // fps + end_sample = end_frame * sample_rate // fps + speech = speech[start_sample:end_sample] + if sample_rate != 16000: + target_length = int(len(speech) * 16000 / sample_rate) + speech = resample(speech, target_length) + if speech.ndim > 1: + speech = np.mean(speech, axis=1) + speech = whisper.pad_or_trim(speech.astype(np.float32)) + speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0) + speech_length = torch.LongTensor([speech.shape[0]]) + return speech, speech_length + + def process_image(self, image_file, overwrite_image_aspect_ratio=None): + processor = self.data_args.image_processor + # print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n") + try: + image = Image.open(image_file).convert("RGB") + except Exception as exn: + print(f"Failed to open image {image_file}. 
Exception:", exn) + raise exn + + image_size = image.size + image_aspect_ratio = self.data_args.image_aspect_ratio + if overwrite_image_aspect_ratio is not None: + image_aspect_ratio = overwrite_image_aspect_ratio + if image_aspect_ratio == "highres": + image = process_highres_image( + image, + self.data_args.image_processor, + self.data_args.image_grid_pinpoints, + ) + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + image = process_anyres_image( + image, + self.data_args.image_processor, + self.data_args.image_grid_pinpoints, + ) + elif image_aspect_ratio == "crop_split": + image = process_highres_image_crop_split(image, self.data_args) + elif image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square( + image, tuple(int(x * 255) for x in processor.image_mean) + ) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + return image, image_size, "image" + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + while True: + try: + sample = self._get_item(i) + # print("process sample",i) + break + except Exception as e: + while True: + try: + i += 1 + random_index = i % len(self.list_data_dict) + sample = self._get_item(random_index) + # print("something error, process sample",random_index) + break + except Exception as e: + # random_index = random.randint(0, len(self.list_data_dict) - 1) + continue + return sample + + def _get_item(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if "image" in sources[0]: + image_file = self.list_data_dict[i]["image"] + if type(image_file) is list: + image = [self.process_image(f) for f in image_file] + # Handling multi images + # overwrite to process with simple pad + if len(image_file) > 1: + image = [self.process_image(f, "pad") for f in image_file] + image = [[im[0], im[1], "image"] for im in image] + else: + image = [self.process_image(image_file)] + + if "video" or "audio" in sources[0]: + if "video" in sources[0]: + video_file = self.list_data_dict[i]["video"] + # video_folder = self.data_args.video_folder + # video_file = os.path.join(video_folder, video_file) + if not os.path.exists(video_file): + print("File {} not exist!".format(video_file)) + if "start_frame" in self.list_data_dict[i]: + start_frame = self.list_data_dict[i]["start_frame"] + end_frame = self.list_data_dict[i]["end_frame"] + if self.list_data_dict[i].get( + "current_observation_frame", None + ): # Customized for egoplan data + current_observation_frame = self.list_data_dict[i][ + "current_observation_frame" + ] + else: + current_observation_frame = None + video = process_video_with_decord_byframe( + video_file, + start_frame, + end_frame, + self.data_args, + current_observation_frame, + ) + else: + ( + video, + video_time, + frame_time, + num_frames, + ) = process_video_with_decord(video_file, self.data_args) + processor = self.data_args.image_processor + processed_video = 
processor.preprocess(video, return_tensors="pt")[ + "pixel_values" + ] + image = [(processed_video, video[0].size, "video")] + + if "audio" in sources[0]: + audio_file = self.list_data_dict[i]["audio"] + # audio_folder = self.data_args.audio_folder + # audio_file = os.path.join(audio_folder, audio_file) + try: + if "start_frame" in self.list_data_dict[i]: + start_frame = self.list_data_dict[i]["start_frame"] + end_frame = self.list_data_dict[i]["end_frame"] + else: + start_frame = None + end_frame = None + audio, audio_length = self.process_audio( + audio_file, start_frame, end_frame + ) + except Exception as e: + print("audio error", e) + audio = [torch.zeros(3000, 128)] + audio_length = torch.tensor([3000]) + audio = [audio] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), self.data_args + ) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + has_speech = "audio" in self.list_data_dict[i] + has_image = ("image" in self.list_data_dict[i]) or ( + "video" in self.list_data_dict[i] + ) + data_dict = preprocess( + sources, self.tokenizer, has_speech=has_speech, has_image=has_image + ) + if isinstance(i, int): + data_dict = dict( + input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0] + ) + + if "image" or "video" in self.list_data_dict[i]: + data_dict["image"] = image + # audio exist in the data + if "audio" in self.list_data_dict[i]: + data_dict["speech"] = audio + data_dict["speech_lengths"] = audio_length + else: # if no audio, add a dummy audio + data_dict["speech"] = [torch.zeros(3000, 128)] + data_dict["speech_lengths"] = torch.tensor([3000]) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=batch_first, padding_value=padding_value + ) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [instance[key] for instance in instances] for key in ("input_ids", "labels") + ) + input_ids = [ + _input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids + ] + labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels] + if self.tokenizer.pad_token_id is None: + if "qwen" in self.tokenizer.name_or_path.lower(): + # print("Setting pad token to bos token for qwen model.") + self.tokenizer.pad_token_id = 151643 + else: + self.tokenizer.pad_token_id = ( + self.tokenizer.eos_token_id + ) # FIXME: this could only be triggered for llama3 model. 
+ input_ids = self.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + if "speech" in instances[0]: + speeches = [instance["speech"] for instance in instances] + speeches_lengths = [instance["speech_lengths"] for instance in instances] + batch["speech"] = [au for audio_list in speeches for au in audio_list] + + batch["speech_lengths"] = [ + au for audio_list in speeches_lengths for au in audio_list + ] + batch["speech_lengths"] = torch.stack(batch["speech_lengths"]) + + if all( + x is not None and x.shape == speeches[0][0].shape + for x in batch["speech"] + ): + batch["speech"] = torch.stack(batch["speech"]) + + if "image" in instances[0]: + images = [instance["image"] for instance in instances] + + batch["image_sizes"] = [im[1] for im_list in images for im in im_list] + batch["modalities"] = [im[2] for im_list in images for im in im_list] + images = [im[0] for im_list in images for im in im_list] + + # if all(x is not None and x.shape == images[0].shape for x in images): + # Image: (N, P, C, H, W) + # Video: (N, F, C, H, W) + # batch["images"] = torch.stack(images) + # else: + batch["images"] = images + return batch + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset( + tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict( + train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator + ) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + if "qwen" in model_args.model_name_or_path.lower(): + model = EgoGPTQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation="flash_attention_2", + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + ) + else: + model = EgoGPTLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation="flash_attention_2", + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + ) + + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + training_args.ddp_find_unused_parameters = True + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + 
bias=training_args.lora_bias, + task_type="CAUSAL_LM", + use_dora=True, + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + model.to(dtype=compute_dtype, device=training_args.device) + + if "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + ) + else: + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + if model_args.version == "v0": + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token="[PAD]"), + tokenizer=tokenizer, + model=model, + ) + elif model_args.version == "v0.5": + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.unk_token + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[ + model_args.version + ] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates[ + "vicuna_v1" + ] + + model.get_model().initialize_speech_modules( + model_args=model_args, fsdp=training_args.fsdp + ) + speech_encoder = model.get_speech_encoder() + speech_encoder.to( + dtype=torch.bfloat16 if training_args.bf16 else torch.float16, + device=training_args.device, + ) + + if model_args.vision_tower is not None: + model.get_model().initialize_vision_modules( + model_args=model_args, fsdp=training_args.fsdp + ) + # import pdb;pdb.set_trace() + vision_tower = model.get_vision_tower() + vision_tower.to( + dtype=torch.bfloat16 if training_args.bf16 else torch.float16, + device=training_args.device, + ) + + data_args.image_processor = vision_tower.image_processor + model.config.image_aspect_ratio = data_args.image_aspect_ratio + if data_args.image_grid_pinpoints is not None: + if ( + isinstance(data_args.image_grid_pinpoints, str) + and "x" in data_args.image_grid_pinpoints + ): + try: + patch_size = data_args.image_processor.size[0] + except Exception as e: + patch_size = data_args.image_processor.size["shortest_edge"] + + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + data_args.image_grid_pinpoints = [ + [dim * patch_size for dim in pair] for pair in grid_pinpoints + ] + elif isinstance(data_args.image_grid_pinpoints, str): + data_args.image_grid_pinpoints = ast.literal_eval( + data_args.image_grid_pinpoints + ) + + model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + model.config.image_crop_resolution = data_args.image_crop_resolution + model.config.image_split_resolution = 
data_args.image_split_resolution + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + model.config.mm_newline_position = model_args.mm_newline_position + model.config.add_faster_video = model_args.add_faster_video + model.config.faster_token_stride = model_args.faster_token_stride + model.config.mm_spatial_pool_stride = model_args.mm_spatial_pool_stride + + data_args.is_multimodal = True + + model.config.tune_mm_mlp_adapter = ( + training_args.tune_mm_mlp_adapter + ) = model_args.tune_mm_mlp_adapter + if model_args.tune_mm_mlp_adapter: + model.requires_grad_(False) + if model_args.tune_mm_mlp_adapter: + for p in model.get_model().speech_projector.parameters(): + p.requires_grad = True + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().speech_projector.parameters(): + p.requires_grad = False + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler + if training_args.freeze_mm_vision_resampler: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = False + model.config.unfreeze_mm_speech_encoder = model_args.unfreeze_mm_speech_encoder + if model_args.unfreeze_mm_speech_encoder: + speech_encoder.requires_grad_(True) + + model.config.mm_use_im_start_end = ( + data_args.mm_use_im_start_end + ) = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr + model.config.speech_projector_lr = training_args.speech_projector_lr + model.config.mm_speech_encoder_lr = training_args.mm_speech_encoder_lr + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + training_args.use_im_start_end = model_args.mm_use_im_start_end + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + # test_data = data_module['train_dataset'].__getitem__(0) + trainer = LLaVATrainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save( + non_lora_state_dict, + os.path.join(training_args.output_dir, "non_lora_trainables.bin"), + ) + else: + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir + ) + + +if __name__ == "__main__": + import torch + + print("number of gpus", torch.cuda.device_count()) + train() diff --git a/egogpt/utils.py b/egogpt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..596597a234527f2a568196200c8daffa1ed1fa2f --- /dev/null +++ b/egogpt/utils.py @@ -0,0 +1,372 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. 
Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import logging.handlers +import os +import sys + +import cv2 +import numpy as np +import torch +import torch.distributed as dist +import transformers + +from egogpt.constants import LOGDIR + +try: + import av + from decord import VideoReader, cpu +except ImportError: + print("Please install pyav to use video processing functions.") + + +server_error_msg = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +moderation_msg = ( + "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." +) + +handler = None + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="UTF-8" + ) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +def process_video_with_decord(video_file, data_args): + vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + avg_fps = round(vr.get_avg_fps() / data_args.video_fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + frame_time = [i / avg_fps for i in frame_idx] + + if data_args.frames_upbound > 0: + if len(frame_idx) > data_args.frames_upbound or data_args.force_sample: + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, data_args.frames_upbound, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i / vr.get_avg_fps() for i in frame_idx] + frames = vr.get_batch(frame_idx).asnumpy() + # resized_frames = np.array([cv2.resize(frame, (384, 384)) for frame in frames]) + # video = resized_frames + video = frames + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + + num_frames_to_sample = num_frames = len(frame_idx) + # https://github.com/dmlc/decord/issues/208 + vr.seek(0) + return video, video_time, frame_time, num_frames_to_sample 
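+
+# Example (illustrative sketch): how process_video_with_decord is typically driven.
+# The path "demo.mp4" and the SimpleNamespace values are assumptions made for the
+# example only; the function itself just needs data_args to expose video_fps,
+# frames_upbound and force_sample.
+#
+#   from types import SimpleNamespace
+#   demo_args = SimpleNamespace(video_fps=1, frames_upbound=32, force_sample=True)
+#   video, video_time, frame_time, n = process_video_with_decord("demo.mp4", demo_args)
+#   # video: (n, H, W, 3) uint8 frames, video_time: clip length in seconds,
+#   # frame_time: comma-joined "0.00s,..." timestamps, n: number of sampled frames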
+ + +def process_video_with_decord_byframe( + video_file, start_frame, end_frame, data_args, current_observation_frame=None +): + try: + vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + selected_frame = min(total_frame_num - 1, end_frame) + avg_fps = round(vr.get_avg_fps() / data_args.video_fps) + frame_idx = [i for i in range(start_frame, selected_frame, avg_fps)] + if data_args.frames_upbound > 0: + if len(frame_idx) > data_args.frames_upbound: + uniform_sampled_frames = np.linspace( + start_frame, selected_frame, data_args.frames_upbound, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + if current_observation_frame: + frame_idx.append(current_observation_frame) + video = vr.get_batch(frame_idx).asnumpy() + # https://github.com/dmlc/decord/issues/208 + vr.seek(0) + except: + raise SyntaxError("Video processing error") + return video + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = "" + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning( + f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}" + ) + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = { + k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() + } + 
return to_return + + +def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match): + to_return = { + k: t + for k, t in named_params + if any(key_match in k for key_match in keys_to_match) + } + to_return = { + k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() + } + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + speech_keywords = ["speech_projector", "speech_encoder"] + for name, module in model.named_modules(): + if any(speech_keyword in name for speech_keyword in speech_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def rank0_print(*args): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(f"Rank {dist.get_rank()}: ", *args) + else: + print(*args) + + +def rank_print(*args): + if dist.is_initialized(): + print(f"Rank {dist.get_rank()}: ", *args) + else: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_speech_projector", False): + # Only save projector + keys_to_match = ["speech_projector"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_speech_projector_state_maybe_zero_3( + trainer.model.named_parameters(), keys_to_match + ) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + speech_projector_folder = os.path.join( + parent_folder, "speech_projector" + ) + os.makedirs(speech_projector_folder, exist_ok=True) + torch.save( + weight_to_save, + os.path.join(speech_projector_folder, f"{current_folder}.bin"), + ) + else: + torch.save( + weight_to_save, os.path.join(output_dir, f"speech_projector.bin") + ) + return + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. 
+ """ + url = "https://api.openai.com/v1/moderations" + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], + } + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"