diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..297854639ea5acf91aee4f106d27e2acad7f9169 --- /dev/null +++ b/app.py @@ -0,0 +1,382 @@ +import os +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['REGIONAL_POOL'] = '2x' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['SKIP_LOAD_VIT'] = '1' + + +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy.editor as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper + +# import subprocess +# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) + +import sys +sys.path.append('./ola/CosyVoice/') +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.utils import disable_torch_init +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token +from ola.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image_genli +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN +# from ola.CosyVoice.cosyvoice.cli.cosyvoice import CosyVoice + +model_path = "/mnt/lzy/ola-model/Ola-7b" +tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None) +model = model.to('cuda').eval() +model = model.bfloat16() + +# tts_model = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True) +# OUTPUT_SPEECH = False + +USE_SPEECH=False + +title_markdown = """ +
+Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment
+
+Project Page | Github | Huggingface | Paper
+""" + +bibtext = """ +### Citation +``` +@article{liu2025ola, +title={Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment}, +author={Liu, Zuyan and Dong, Yuhao and Wang, Jiahui and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming}, +journal={arXiv preprint arXiv:2502.04328}, +year={2025} +} +``` +""" +cur_dir = os.path.dirname(os.path.abspath(__file__)) + + +def load_audio(audio_file_name): + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + +def extract_audio(videos_file_path): + my_clip = mp.VideoFileClip(videos_file_path) + return my_clip.audio + +def ola_inference(multimodal, audio_path): + visual, text = multimodal["files"][0], multimodal["text"] + if visual.endswith("image2.png"): + modality = "video" + visual = f"{cur_dir}/case/case1.mp4" + if visual.endswith(".mp4"): + modality = "video" + else: + modality = "image" + + # input audio and video, do not parse audio in the video, else parse audio in the video + if audio_path: + USE_SPEECH = True + elif modality == "video": + USE_SPEECH = True + else: + USE_SPEECH = False + + speechs = [] + speech_lengths = [] + speech_wavs = [] + speech_chunks = [] + if modality == "video": + vr = VideoReader(visual, ctx=cpu(0)) + total_frame_num = len(vr) + fps = round(vr.get_avg_fps()) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + spare_frames = vr.get_batch(frame_idx).asnumpy() + video = [Image.fromarray(frame) for frame in spare_frames] + else: + image = [Image.open(visual)] + image_sizes = [image[0].size] + + if USE_SPEECH and audio_path: + audio_path = audio_path + speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) + print('load audio') + elif USE_SPEECH and not audio_path: + # parse audio in the video + audio = extract_audio(visual) + audio.write_audiofile("./video_audio.wav") + video_audio_path = './video_audio.wav' + speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) + else: 
+ speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')] + speech_lengths = [torch.LongTensor([3000]).to('cuda')] + speech_wavs = [torch.zeros([1, 480000]).to('cuda')] + speech_chunks = [torch.LongTensor([1]).to('cuda')] + + conv_mode = "qwen_1_5" + if text: + qs = text + else: + qs = '' + if USE_SPEECH and audio_path: + qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n' + elif USE_SPEECH: + qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs + else: + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv = conv_templates[conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + if USE_SPEECH and audio_path: + input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') + elif USE_SPEECH: + input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') + else: + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') + + if modality == "video": + video_processed = [] + for idx, frame in enumerate(video): + image_processor.do_resize = False + image_processor.do_center_crop = False + frame = process_anyres_video(frame, image_processor) + + if frame_idx is not None and idx in frame_idx: + video_processed.append(frame.unsqueeze(0)) + elif frame_idx is None: + video_processed.append(frame.unsqueeze(0)) + + if frame_idx is None: + frame_idx = np.arange(0, len(video_processed), dtype=int).tolist() + + video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda") + video_processed = (video_processed, video_processed) + + video_data = (video_processed, (384, 384), "video") + else: + image_processor.do_resize = False + image_processor.do_center_crop = False + image_tensor, image_highres_tensor = [], [] + for visual in image: + image_tensor_, image_highres_tensor_ = process_anyres_highres_image_genli(visual, image_processor) + image_tensor.append(image_tensor_) + image_highres_tensor.append(image_highres_tensor_) + if all(x.shape == image_tensor[0].shape for x in image_tensor): + image_tensor = torch.stack(image_tensor, dim=0) + if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor): + image_highres_tensor = torch.stack(image_highres_tensor, dim=0) + if type(image_tensor) is list: + image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor] + else: + image_tensor = image_tensor.bfloat16().to("cuda") + if type(image_highres_tensor) is list: + image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor] + else: + image_highres_tensor = image_highres_tensor.bfloat16().to("cuda") + + pad_token_ids = 151643 + + attention_masks = input_ids.ne(pad_token_ids).long().to('cuda') + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + gen_kwargs = {} + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0.2 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + with torch.inference_mode(): + if modality == "video": + output_ids = model.generate( + inputs=input_ids, + images=video_data[0][0], + images_highres=video_data[0][1], + 
modalities=video_data[2], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + else: + output_ids = model.generate( + inputs=input_ids, + images=image_tensor, + images_highres=image_highres_tensor, + image_sizes=image_sizes, + modalities=['image'], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + + # if OUTPUT_SPEECH: + # voice_all = [] + # for i, j in enumerate(cosyvoice.inference_sft('Visual data comes in various forms, ranging from small icons of just a few pixels to long videos spanning hours. Existing multi-modal LLMs usually standardize these diverse visual inputs to a fixed resolution for visual encoders and yield similar numbers of tokens for LLMs. This approach is non-optimal for multimodal understanding and inefficient for processing inputs with long and short visual contents. To solve the problem, we propose Oryx, a unified multimodal architecture for the spatial-temporal understanding of images, videos, and multi-view 3D scenes. Oryx offers an on-demand solution to seamlessly and efficiently process visual inputs with arbitrary spatial sizes and temporal lengths through two core innovations: 1) a pre-trained OryxViT model that can encode images at any resolution into LLM-friendly visual representations; 2) a dynamic compressor module that supports 1x to 16x compression on visual tokens by request. These design features enable Oryx to accommodate extremely long visual contexts, such as videos, with lower resolution and high compression while maintaining high recognition precision for tasks like document understanding with native resolution and no compression. Beyond the architectural improvements, enhanced data curation and specialized training on long-context retrieval and spatial-aware data help Oryx achieve strong capabilities in image, video, and 3D multimodal understanding simultaneously. 
', 'θ‹±ζ–‡ε₯³', stream=False)): + # voice_all.append(j['tts_speech']) + # voice_all = torch.cat(voice_all, dim=1) + # torchaudio.save('sft.wav', voice_all, 22050) + # return outputs, "sft.wav" + # else: + return outputs, None + +# Define input and output for the Gradio interface +demo = gr.Interface( + fn=ola_inference, + inputs=[gr.MultimodalTextbox(file_types=[".mp4", "image"],placeholder="Enter message or upload file..."), gr.Audio(type="filepath")], + outputs=["text", "audio"], + # examples=[ + # { + # "files":[f"{cur_dir}/case/image2.png"], + # "text":"Describe what is happening in this video in detail.", + # }, + # { + # "files":[f"{cur_dir}/case/image.png"], + # "text":"Describe this icon.", + # }, + # ], + title="Ola Demo", + description=title_markdown, + article=bibtext, +) + +# textbox = gr.Textbox( +# show_label=False, placeholder="Enter text and press ENTER", container=False, max_lines=100 +# ) +# with gr.Blocks( +# title="Oryx-7B", +# theme="finlaymacklon/smooth_slate", +# css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 50px}", +# fill_height=True +# ) as demo: +# html_header = "https://oryx-mllm.github.io/" +# gr.HTML(html_header) + +# with gr.Row(equal_height=True): +# with gr.Column(scale=3): +# with gr.Row(): +# video = gr.Video(label="Input Video", height=400) +# cur_dir = os.path.dirname(os.path.abspath(__file__)) +# with gr.Row(): +# gr.Examples( +# examples=[ +# [ +# f"{cur_dir}/case/case1.mp4", +# "Describe what is happening in this video in detail.", +# ], +# ], +# inputs=[video, textbox], +# ) + +# with gr.Column(scale=7): +# chatbot = gr.Chatbot(label="Oryx", bubble_full_width=False, height=660) +# with gr.Row(): +# with gr.Column(scale=8): +# textbox.render() +# with gr.Column(scale=1, min_width=50): +# submit_btn = gr.Button( +# value="Send", variant="primary", interactive=True +# ) +# # with gr.Row(elem_id="buttons") as button_row: +# # upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=True) +# # downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=True) +# # flag_btn = gr.Button(value="⚠️ Flag", interactive=True) +# # clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=True) + +# submit_btn.click( +# oryx_inference, +# [video, textbox], +# [chatbot, textbox, video], +# ) +# Launch the Gradio app +demo.launch(server_name="0.0.0.0",server_port=80) diff --git a/ola/CosyVoice b/ola/CosyVoice new file mode 160000 index 0000000000000000000000000000000000000000..027e1ccb82ce59bbc12f35a96e0f92625cf18369 --- /dev/null +++ b/ola/CosyVoice @@ -0,0 +1 @@ +Subproject commit 027e1ccb82ce59bbc12f35a96e0f92625cf18369 diff --git a/ola/__pycache__/arguments.cpython-310.pyc b/ola/__pycache__/arguments.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1668126a8929d21b4f2e020e08c77ff358ff07f5 Binary files /dev/null and b/ola/__pycache__/arguments.cpython-310.pyc differ diff --git a/ola/__pycache__/arguments.cpython-38.pyc b/ola/__pycache__/arguments.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9881382e5729d28354335677d1a25bdd3d8a3fbc Binary files /dev/null and b/ola/__pycache__/arguments.cpython-38.pyc differ diff --git a/ola/__pycache__/constants.cpython-310.pyc b/ola/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9012d51428717321c590fd2a2aa5ff30407f63fe Binary files /dev/null and b/ola/__pycache__/constants.cpython-310.pyc differ diff --git a/ola/__pycache__/constants.cpython-38.pyc 
b/ola/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c8f6f1d5b026a6a2518e6740f2c820b7355461 Binary files /dev/null and b/ola/__pycache__/constants.cpython-38.pyc differ diff --git a/ola/__pycache__/conversation.cpython-310.pyc b/ola/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d23d1750ba544eec54730b09e98e3aaed6bbc432 Binary files /dev/null and b/ola/__pycache__/conversation.cpython-310.pyc differ diff --git a/ola/__pycache__/conversation.cpython-38.pyc b/ola/__pycache__/conversation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b11153e67c4ec28a1e32c7efd41316f9edfc9f5 Binary files /dev/null and b/ola/__pycache__/conversation.cpython-38.pyc differ diff --git a/ola/__pycache__/mm_utils.cpython-310.pyc b/ola/__pycache__/mm_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4143ec1ae23c907a4c3efd27750cb15e1182dd16 Binary files /dev/null and b/ola/__pycache__/mm_utils.cpython-310.pyc differ diff --git a/ola/__pycache__/mm_utils.cpython-38.pyc b/ola/__pycache__/mm_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6a7bf02be74ca4a5258f039463013e2095022ac Binary files /dev/null and b/ola/__pycache__/mm_utils.cpython-38.pyc differ diff --git a/ola/__pycache__/utils.cpython-310.pyc b/ola/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b79a7b0216633caa9777d64478cfb3c78d6051a Binary files /dev/null and b/ola/__pycache__/utils.cpython-310.pyc differ diff --git a/ola/__pycache__/utils.cpython-38.pyc b/ola/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d8e461a77ad60257d41b4d4ff15cd7db1c2b448 Binary files /dev/null and b/ola/__pycache__/utils.cpython-38.pyc differ diff --git a/ola/arguments.py b/ola/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..199c5d5b7912fefbe3882a0b6f774e31a5f80cfc --- /dev/null +++ b/ola/arguments.py @@ -0,0 +1,65 @@ +import transformers + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_speech_projector: bool = field(default=False) + tune_speech_encoder: bool = field(default=False) + tune_speech_generator_only: bool = field(default=False) + speech_encoder_type: Optional[str] = field(default=None) + speech_encoder: Optional[str] = field(default=None) + pretrain_speech_projector: Optional[str] = field(default=None) + speech_projector_type: Optional[str] = field(default='linear') + speech_encoder_ds_rate: int = 5 + speech_encoder_hidden_size: int = 1280 + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + is_multimodal: bool = False + input_type: str = field(default="mel") + speech_normalize: bool = False + mel_size: int = 128 + has_tgt_units: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + freeze_speech_projector: bool = field(default=False) + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + speech_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) \ No newline at end of file diff --git a/ola/constants.py b/ola/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9b903d94f9122dc8b657383f8604555aad819400 --- /dev/null +++ b/ola/constants.py @@ -0,0 +1,14 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +SPEECH_TOKEN_INDEX = -200 +DEFAULT_SPEECH_TOKEN = "" +IMAGE_TOKEN_INDEX= -300 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" \ No newline at end of file diff --git a/ola/conversation.py b/ola/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..790011549b87bb0fad65125e62e9f0842d119dd9 --- /dev/null +++ b/ola/conversation.py @@ -0,0 +1,254 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Any, Union, Tuple +import base64 +from io import BytesIO +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + TWO = auto() + PLAIN = auto() + CHATML = auto() + LLAMA_2 = auto() + LLAMA_3 = auto() + QWEN2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.PLAIN + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + tokenizer_id: str = "" + tokenizer: Any = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + + if self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg + ret = "<|begin_of_text|>" + wrap_sys(self.system) + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += message.strip() + self.sep2 + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == 
self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + raise ValueError("Tuple not supported in CHATML") + message, images = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.QWEN2: + start = '<|im_start|>' + end = '<|im_end|>\n' + ret = start + 'system\n' + self.system + end + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + + if message.endswith('<|endoftext|>'): + message = message.replace('<|endoftext|>', '') + ret += start + role + "\n" + message + end + '<|endoftext|>' + else: + assert not '<|endoftext|>' in message, f"Invalid message: {message}" + ret += start + role + "\n" + message + end + else: + ret += start + role + "\n" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, speech = msg + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llama_3 = Conversation( + system="You are a helpful language and speech assistant. 
" "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("user", "assistant"), + version="llama_v3", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="", + sep2="<|eot_id|>" +) + + +conv_qwen_v1 = Conversation( + system="You are a helpful assistant.", + roles=("user", "assistant"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.QWEN2, +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +conv_qwen = Conversation( + system="""<|im_start|>system +You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +default_conversation = conv_llama_3 +conv_templates = { + "v1": conv_vicuna_v1, + "plain": conv_plain, + "llama_2": conv_llama_2, + "llama_3": conv_llama_3, + 'v1_qwen2': conv_qwen_v1, + "qwen_1_5": conv_qwen, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/ola/datasets/__init__.py b/ola/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/datasets/__pycache__/__init__.cpython-310.pyc b/ola/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d0663533deed3cca3f909ab73cb2e652b028fa Binary files /dev/null and b/ola/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/ola/datasets/__pycache__/__init__.cpython-38.pyc b/ola/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d34c4faa0cce83f6584731304021d4ffa7bffc Binary files /dev/null and b/ola/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/ola/datasets/__pycache__/preprocess.cpython-310.pyc b/ola/datasets/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6d7b6690ec7d0036b66fa0592cdce0cd9db73d9 Binary files /dev/null and b/ola/datasets/__pycache__/preprocess.cpython-310.pyc differ diff --git a/ola/datasets/__pycache__/preprocess.cpython-38.pyc b/ola/datasets/__pycache__/preprocess.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2f4d0cab219302d2e9ba01c05d2145b53a30f1 Binary files /dev/null and b/ola/datasets/__pycache__/preprocess.cpython-38.pyc differ diff --git a/ola/datasets/preprocess.py b/ola/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e60847af04e5ac012e0dea2d7785a22450c03d73 --- /dev/null +++ b/ola/datasets/preprocess.py @@ -0,0 +1,413 @@ +import copy +import torch +import transformers +import tokenizers + +from typing import Dict, Sequence + +from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX +from ola import conversation as conversation_lib +from ola.model import * +from ola.arguments import DataArguments +from ola.constants import SPEECH_TOKEN_INDEX + +from packaging import version + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for 
sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_SPEECH_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip() + sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + + return sources + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("\nUser's question in speech: \n")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + nl_tokens = tokenizer("\n").input_ids[0] + special_chunks = [image_token_index, nl_tokens] + special_chunks.extend(tokenizer("User's question in speech: ").input_ids) + special_chunks.extend([speech_token_idx, nl_tokens]) + + for x in 
insert_separator(prompt_chunks, special_chunks): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + assert len(source) == 2, "now only support single-turn conversation" + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3 + + # Mask targets + sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n" + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + parts = conversation.split(sep) + parts[0] += sep + + if has_speech: + conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 + else: + conversation_len = len(tokenizer(conversation).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += conversation_len + target[cur_len:] = IGNORE_INDEX + + # if cur_len < tokenizer.model_max_length: + # if cur_len != total_len: + # target[:] = IGNORE_INDEX + # print( + # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ # f" (ignored)" + # ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # FIXME: tokenizer bug + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_SPEECH_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_SPEECH_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. 
Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_speech=has_speech) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_speech=has_speech) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3: + return preprocess_llama_3(sources, tokenizer, has_speech=has_speech) + raise NotImplementedError \ No newline at end of file diff --git a/ola/mm_utils.py b/ola/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba63dc79423ddac3f80b197259ad58ed8e4c24b3 --- /dev/null +++ b/ola/mm_utils.py @@ -0,0 +1,272 @@ +from PIL import Image +import base64 +import math +import ast + +import torch +from transformers import StoppingCriteria +import os +import io + +if 'VIDEO_RESIZE' in os.environ: + # highresxpatch + VIDEO_RESIZE = os.environ['VIDEO_RESIZE'] + video_base, video_ps = VIDEO_RESIZE.split('x') + video_base = int(video_base) + video_ps = int(video_ps) + print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}") +else: + HIGHRES_BASE = None + +if 'HIGHRES_BASE' in os.environ: + # highresxpatch + HIGHRES_BASE = os.environ['HIGHRES_BASE'] + highres_base, highres_ps = HIGHRES_BASE.split('x') + highres_base = int(highres_base) + highres_ps = int(highres_ps) + print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}") +else: + HIGHRES_BASE = None + +if 'MAXRES' in os.environ: + # highresxpatch + MAXRES = int(os.environ['MAXRES']) + print(f"MAXRES is set as {MAXRES}") +else: + MAXRES = 1536 + +if 'MINRES' in os.environ: + # highresxpatch + MINRES = int(os.environ['MINRES']) + print(f"MINRES is set as {MINRES}") +else: + MINRES = 0 + +if 'VIDEO_MAXRES' in os.environ: + # highresxpatch + VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES']) + print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}") +else: + VIDEO_MAXRES = 1536 + +if 'VIDEO_MINRES' in os.environ: + # highresxpatch + VIDEO_MINRES = int(os.environ['VIDEO_MINRES']) + print(f"VIDEO_MINRES is set as {VIDEO_MINRES}") +else: + MINRES = 0 + +if 'PAD2STRIDE' in os.environ: + # highresxpatch + PAD2STRIDE = True + print(f"PAD2STRIDE is set") +else: + PAD2STRIDE = False + +if 'LOWRES_RESIZE' in os.environ: + LOWRES_RESIZE = os.environ['LOWRES_RESIZE'] + print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}") + if 'x' in LOWRES_RESIZE: + size, ps = LOWRES_RESIZE.split('x') + size = int(size) + ps = int(ps) + LOWRES_RESIZE = (size, ps) + else: + LOWRES_RESIZE = int(LOWRES_RESIZE) +else: + LOWRES_RESIZE = None + + +def pad_image(image, target_resolution, value=0): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. 
+ """ + original_width, original_height = image.size + target_width, target_height = target_resolution + # Create a new image with the target size and paste the resized image onto it + new_image = Image.new('RGB', (target_width, target_height), (value, value, value)) + paste_x = (target_width - original_width) // 2 + paste_y = (target_height - original_height) // 2 + new_image.paste(image, (paste_x, paste_y)) + return new_image + +def resize_images(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > MAXRES * MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = MAXRES * MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < MINRES * MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = MINRES * MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + + return image + +def resize_video(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > VIDEO_MAXRES * VIDEO_MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < VIDEO_MINRES * VIDEO_MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = VIDEO_MINRES * VIDEO_MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + + return image + +def process_anyres_video(image, processor): + if VIDEO_RESIZE is not None: + image = resize_video(image, patch_size=video_ps, base_size=video_base) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image.unsqueeze(0) + else: + raise ValueError("VIDEO_RESIZE is not set") + +def process_anyres_highres_image_genli(image, processor): + h, w = image.size + if h < 32 and w < 32: + min_size = min(h, w) + ratio = 64 / min_size + image = image.resize((int(h * ratio), int(w * ratio))) + elif h < 32: + ratio = 64 / h + image = image.resize((int(h * ratio), int(w * ratio))) + elif w < 32: + ratio = 64 / w + image = 
image.resize((int(h * ratio), int(w * ratio))) + if HIGHRES_BASE is not None: + image = resize_images(image, patch_size=highres_ps, base_size=highres_base) + + if LOWRES_RESIZE is not None: + image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0]) + else: + image_original_resize = image.resize((384, 384)) + + # image_patches = [image_original_resize] + [image_original_resize] + # image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + # for image_patch in image_patches] + image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0] + image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + # return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0) + return image_patches.unsqueeze(0), image_padded.unsqueeze(0) + +def read_image_patch(patch_info): + if 'img_path' in patch_info.keys(): + image = Image.open(patch_info['img_path']).convert('RGB') + else: + if 'image_encoing' in patch_info.keys(): + patch_info['image_encoding'] = patch_info['image_encoing'] + image_file_name = patch_info['patch'] + start_bytes = int(patch_info['start_num']) + file_size = int(patch_info['size']) + + with open(image_file_name, 'rb') as f: + f.seek(start_bytes) + if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64': + image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB") + else: + image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB") + return image + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, 3) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if output_ids[0, -keyword_id.shape[0]:] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False diff --git a/ola/model/__init__.py b/ola/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7599857d36358a738f92ab2b534918dbcd4a7b45 --- /dev/null +++ b/ola/model/__init__.py @@ -0,0 +1 @@ +from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen \ No newline at end of file diff --git a/ola/model/__pycache__/__init__.cpython-310.pyc b/ola/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49cb05a153f306d49577723619b7f7627e94cc66 Binary files /dev/null and b/ola/model/__pycache__/__init__.cpython-310.pyc 
differ diff --git a/ola/model/__pycache__/__init__.cpython-38.pyc b/ola/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0250c3ad8fb54f8d76d8b3b8123b0fb160312c4a Binary files /dev/null and b/ola/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/ola/model/__pycache__/builder.cpython-310.pyc b/ola/model/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b513d02ad535a91ce1dd8df753e63bc0283cf9c1 Binary files /dev/null and b/ola/model/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/__pycache__/builder.cpython-38.pyc b/ola/model/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18854f0e9f992f20b717d64a10524f379fb1f28c Binary files /dev/null and b/ola/model/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/__pycache__/ola_arch.cpython-310.pyc b/ola/model/__pycache__/ola_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4791853d1513029b1de4893143c3de9e3c48e70 Binary files /dev/null and b/ola/model/__pycache__/ola_arch.cpython-310.pyc differ diff --git a/ola/model/__pycache__/ola_arch.cpython-38.pyc b/ola/model/__pycache__/ola_arch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe83d978af687ff22b7cbda3379aa51230efc64b Binary files /dev/null and b/ola/model/__pycache__/ola_arch.cpython-38.pyc differ diff --git a/ola/model/builder.py b/ola/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..42bd9622ce1f2dc8c589207017c0e82dc1809edc --- /dev/null +++ b/ola/model/builder.py @@ -0,0 +1,91 @@ +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from ola.model import * +from ola.model.speech_encoder.builder import build_speech_encoder + +def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.bfloat16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + model_cls = OlaQwenForCausalLM + + # Load OmniSpeech model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." 
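+        # The LoRA checkpoint holds only adapter and non-LoRA trainable weights, so the base model is loaded first, the extra weights restored, and the adapters merged in.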
+ from ola.model.language_model.ola_qwen import OlaConfigQwen + lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading OmniSpeech from base model...') + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) + print('Loading additional OmniSpeech weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + print('Loading OmniSpeech from base model...') + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) + + speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu') + speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} + model.load_state_dict(speech_projector_weights, strict=False) + model = model.to(device=device) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = model_cls.from_pretrained( + model_path, + low_cpu_mem_usage=False, + **kwargs + ) + model = model.to(device=device) + + model.get_model().speech_encoder = build_speech_encoder(model.config) + model.get_model().speech_encoder.to(device=device, dtype=torch.float16) + + image_processor = None + model.resize_token_embeddings(len(tokenizer)) + vision_tower = model.get_vision_tower() + print("Loading vision tower...") + if not vision_tower.is_loaded: + vision_tower.load_model(device_map=device) + if device != "auto": + vision_tower.to(device="cuda", dtype=torch.bfloat16) + else: + vision_tower.to(device="cuda:0", dtype=torch.bfloat16) + image_processor = vision_tower.image_processor + print("Loading vision tower succeeded.") + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 16384 + + return tokenizer, model, image_processor, context_len diff --git a/ola/model/language_model/__pycache__/ola_qwen.cpython-310.pyc b/ola/model/language_model/__pycache__/ola_qwen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247f9a8d194f3f158d47453268ba2e657c9193e1 Binary files /dev/null and b/ola/model/language_model/__pycache__/ola_qwen.cpython-310.pyc differ diff --git a/ola/model/language_model/__pycache__/ola_qwen.cpython-38.pyc b/ola/model/language_model/__pycache__/ola_qwen.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc44a1de70c3fd2db5eb1871ea97893302c60815 Binary files /dev/null and b/ola/model/language_model/__pycache__/ola_qwen.cpython-38.pyc differ diff --git a/ola/model/language_model/ola_qwen.py 
b/ola/model/language_model/ola_qwen.py new file mode 100644 index 0000000000000000000000000000000000000000..fd88538c53603ef929abc3dee892e109e2cd0844 --- /dev/null +++ b/ola/model/language_model/ola_qwen.py @@ -0,0 +1,237 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM +from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM + + +class OlaConfigQwen(Qwen2Config): + model_type = "ola_qwen" + + +class OlaQwenModel(OlaMetaModel, Qwen2Model): + config_class = OlaConfigQwen + + def __init__(self, config: Qwen2Config): + super(OlaQwenModel, self).__init__(config) + + +class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM): + config_class = OlaConfigQwen + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + + config.rope_scaling = None + self.model = OlaQwenModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + speech_chunks: Optional[torch.LongTensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.FloatTensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_speech_vision_text( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + if labels is None: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + else: + return self.forward_llm_efficient( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + + def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, 
output_attentions, output_hidden_states, return_dict): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_dim = hidden_states.size(-1) + shift_labels = labels[..., 1:].contiguous().reshape(-1) + shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim) + assert shift_labels.size(0) == shift_hidden_states.size(0) + mask = shift_labels > -1 + assert mask.float().sum() > 0 + shift_labels = shift_labels[mask] + shift_hidden_states = shift_hidden_states[mask, :] + logits = self.lm_head(shift_hidden_states) + logits = logits.float() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits, shift_labels) + + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + speech_chunks: Optional[torch.Tensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.Tensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_speech_vision_text( + inputs, + position_ids, + attention_mask, + None, + None, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + speech_chunks = kwargs.pop("speech_chunks", None) + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if speech is not None: + inputs['speech'] = speech + inputs['speech_lengths'] = speech_lengths + inputs['speech_chunks'] = speech_chunks + if images is not None: + inputs["images"] = images + if image_sizes is not None: + inputs["image_sizes"] = 
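For context on `forward_llm_efficient` above: labels are shifted by one, positions whose label is `-1` or below (padding / IGNORE_INDEX) are dropped, and only the surviving hidden states are projected through `lm_head`, so a full `(batch, seq, vocab)` logits tensor is never materialised. A small self-contained sketch of the same shift-and-mask pattern with toy sizes (all names and sizes here are illustrative, not from the repo):

```python
import torch
import torch.nn as nn

B, T, H, V = 2, 6, 8, 32                       # toy batch, sequence, hidden, vocab sizes
hidden_states = torch.randn(B, T, H)
labels = torch.randint(0, V, (B, T))
labels[:, :3] = -100                           # e.g. prompt tokens marked as ignored
lm_head = nn.Linear(H, V, bias=False)

shift_labels = labels[..., 1:].reshape(-1)                # token t predicts token t+1
shift_hidden = hidden_states[..., :-1, :].reshape(-1, H)
mask = shift_labels > -1                                  # keep supervised positions only
logits = lm_head(shift_hidden[mask])                      # (num_supervised, V), not (B*T, V)
loss = nn.CrossEntropyLoss()(logits.float(), shift_labels[mask])
print(logits.shape, loss.item())
```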
image_sizes + return inputs + +AutoConfig.register("ola_qwen", OlaConfigQwen) +AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM) diff --git a/ola/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc b/ola/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b33c5fd51ecae22a138aae457f7c5e6516e8e06 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/builder.cpython-38.pyc b/ola/model/multimodal_encoder/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb373aeb3ec8245f5456f67b76c028b159f07493 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-310.pyc b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6405a5627a6a3f8e86547ff054cdd5124f601585 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-310.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-38.pyc b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3578787a12be9eef0043ec3468ca0d5a49ed0e14 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-38.pyc differ diff --git a/ola/model/multimodal_encoder/builder.py b/ola/model/multimodal_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..154a20b1732a9766d01c3b784691de8596b49120 --- /dev/null +++ b/ola/model/multimodal_encoder/builder.py @@ -0,0 +1,9 @@ +import os +from .oryx_vit import SigLIPViTAnysizeWrapper + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None)) + is_absolute_path_exists = os.path.exists(vision_tower) + print(f"Buiding OryxViTWrapper from {vision_tower}...") + # path = vision_tower.split(":")[1] + return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs) \ No newline at end of file diff --git a/ola/model/multimodal_encoder/oryx_vit.py b/ola/model/multimodal_encoder/oryx_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b3fccfde12c6fdf2d96606260238823132a1da --- /dev/null +++ b/ola/model/multimodal_encoder/oryx_vit.py @@ -0,0 +1,1126 @@ +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +from torch.utils.checkpoint import checkpoint +import torch +import torch.nn as nn +import torch.nn.functional as F +try: + from timm.layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, + ) + from timm.models._manipulate import checkpoint_seq, named_apply +except: + print('Wrong timm version') + +from flash_attn import flash_attn_func, flash_attn_varlen_func + +from typing import Optional + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed +import os +if 'LOAD_VISION_EARLY' in os.environ: + print("LOAD_VISION_EARLY is set") + LOAD_VISION_EARLY = 
True +else: + LOAD_VISION_EARLY = False + + +if 'SKIP_LOAD_VIT' in os.environ: + print("SKIP_LOAD_VIT is set") + SKIP_LOAD_VIT = True +else: + SKIP_LOAD_VIT = False + +if 'VIT_WITH_GRAD' in os.environ: + print("VIT_WITH_GRAD is set") + VIT_WITH_GRAD = True +else: + VIT_WITH_GRAD = False + + +if 'FIX_SIZE' in os.environ: + print("FIX_SIZE is set") + FIX_SIZE = True +else: + FIX_SIZE = False + + +if 'ANYRES_SPLIT' in os.environ: + ANYRES_SPLIT = int(os.environ['ANYRES_SPLIT']) + print(f"ANYRES_SPLIT is set as {ANYRES_SPLIT}") +else: + ANYRES_SPLIT = None + + +if 'FORCE_NO_DOWNSAMPLE' in os.environ: + print("FORCE_NO_DOWNSAMPLE is set") + FORCE_NO_DOWNSAMPLE = True +else: + FORCE_NO_DOWNSAMPLE = False + +if 'EVAL_72B' in os.environ: + print("EVAL_72B is set") + EVAL_72B = True +else: + EVAL_72B = False + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
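The reimplementation of `trunc_normal_` above exists because the timm original could not initialise bfloat16 tensors; the workaround is to initialise in float32 and copy the result back. A standalone restatement of that pattern using only `torch.nn.init` (the helper name is made up for illustration):

```python
import torch

def bf16_safe_trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # Initialise in float32, then copy back into the original (possibly bfloat16)
    # tensor, mirroring the cast-and-copy done by trunc_normal_ above.
    with torch.no_grad():
        tmp = tensor.float()
        torch.nn.init.trunc_normal_(tmp, mean=mean, std=std, a=a, b=b)
        tensor.copy_(tmp.to(tensor.dtype))
    return tensor

w = torch.empty(3, 5, dtype=torch.bfloat16)
bf16_safe_trunc_normal_(w, std=0.02)
print(w.dtype)  # stays torch.bfloat16
```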
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if cu_slens is not None: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item() + x = flash_attn_varlen_func( + q.squeeze(0), + k.squeeze(0), + v.squeeze(0), + cu_seqlens_q=cu_slens, + cu_seqlens_k=cu_slens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=self.scale, + causal=False, + ) + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + else: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + # if self.fused_attn: + # x = F.scaled_dot_product_attention( + # q, + # k, + # v, + # dropout_p=self.attn_drop.p if self.training else 0.0, + # ) + # else: + # q = q * self.scale + # attn = q @ k.transpose(-2, -1) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + + # x = x.transpose(1, 2).reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + 
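When `cu_slens` is given, `Attention.forward` above treats the batch as one packed sequence of variable-length per-image token runs and passes the cumulative offsets to `flash_attn_varlen_func`, so attention never crosses image boundaries. A CPU-only sketch of how such offsets are built and how the packed features are split back per image afterwards (toy lengths, no flash-attn call):

```python
import torch

# Per-image token counts, e.g. three images with different patch grids.
seq_lens = [576, 144, 300]
feats = [torch.randn(1, n, 1152) for n in seq_lens]

# Pack along the sequence dimension and build cumulative offsets,
# which is the role cu_slens plays in the attention call above.
packed = torch.cat(feats, dim=1)                 # (1, sum(seq_lens), C)
cu = [0]
for n in seq_lens:
    cu.append(cu[-1] + n)
cu_slens = torch.tensor(cu, dtype=torch.int32)
max_seqlen = int((cu_slens[1:] - cu_slens[:-1]).max())
print(cu_slens.tolist(), max_seqlen)             # [0, 576, 720, 1020] 576

# ... packed attention would run here; afterwards the per-image
# features are recovered by splitting at the same lengths:
unpacked = packed.split(seq_lens, dim=1)
print([tuple(u.shape) for u in unpacked])
```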
self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + strict_img_size: bool = False, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + add_patch2x2: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. 
+ no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + strict_img_size=strict_img_size, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + + + # deepspeed.zero.register_external_parameter(self, self.pos_embed) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.weight) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.bias) + # print(self.patch_embed.state_dict().keys()) + + + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + + + if add_patch2x2: + if add_patch2x2 == 'v2': + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim*2, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(embed_dim*2, embed_dim*4, 1) + ) + else: + mid_dim = embed_dim * 2 + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, mid_dim, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(mid_dim, mid_dim, 1) + ) + + else: + self.downsample = None + + + # self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # # Classifier Head + # if global_pool == "map": + # AttentionPoolLatent.init_weights = init_weights + # self.attn_pool = AttentionPoolLatent( + # self.embed_dim, + # num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # norm_layer=norm_layer, + # ) + # else: + # self.attn_pool = None + # self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + # self.head_drop = nn.Dropout(drop_rate) + # self.head = ( + # nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + # ) + + # if weight_init != "skip": + # self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def 
reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def rescale_positional_embedding(self, out_size): + h, w = out_size + pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5) + if (h, w) == (pos_embed_shape, pos_embed_shape): + return self.pos_embed + rescaled_positional_embedding = \ + self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2]) + pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape) + pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w) + rescaled_positional_embedding[0] = pe_2d.T.contiguous() + return rescaled_positional_embedding + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
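`rescale_positional_embedding` above lets the fixed learned grid serve arbitrary input resolutions by bilinearly resizing it. A standalone sketch of the same resize with made-up sizes (a 24x24 grid stretched to 20x36):

```python
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 24 * 24, 1152)                      # (1, N, C) learned grid
h, w = 20, 36                                                  # target patch grid
side = int(pos_embed.shape[1] ** 0.5)                          # 24
pe_2d = pos_embed[0].T.reshape(1, -1, side, side)              # (1, C, 24, 24)
pe_2d = F.interpolate(pe_2d, (h, w), mode="bilinear", align_corners=False)
rescaled = pe_2d.reshape(pos_embed.shape[2], h * w).T[None]    # (1, h*w, C)
print(rescaled.shape)                                          # torch.Size([1, 720, 1152])
```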
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features_list(self, x_list): + x_all = [] + image_sizes = [] + for x in x_list: + if EVAL_72B: + x = x.to('cuda:0') + bs, _, h, w = x.shape + + # fix patch size=14 in datasets + pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0] + pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + bs, _, h, w = x.shape + + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + x_all.append(x) + image_sizes.append((h, w)) + + slen = [xi.size(1) for xi in x_all] + x = torch.cat(x_all, dim=1) + + cu_indices = [0, ] + for i in slen: + cu_indices.append(cu_indices[-1] + i) + + cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device) + for idx, blk in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, cu_slens, use_reentrant=True) + else: + x = blk(x, cu_slens=cu_slens) + feats = x.split(slen, dim=1) #[(1, slen, c)] + + if self.downsample is not None: + new_feats = [] + new_sizes = [] + for f, s in zip(feats, image_sizes): + h, w = s + b, n, c = f.size() + f = f.reshape(b, h, w, c).permute(0, 3, 1, 2) + f = self.downsample(f) + b, c, h, w = f.size() + f = f.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats.append(f) + new_sizes.append((h, w)) + return new_feats, new_sizes + + + return feats, image_sizes + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + if EVAL_72B: + x = x.to('cuda:0') + bs, _, h, w = x.shape + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + if self.downsample is not None: + b, n, c = x.size() + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.downsample(x) + b, c, h, w = x.size() + x = x.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats = x + new_sizes = (h, w) + return new_feats, new_sizes + + return x, (h, w) + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.norm(x) + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x, cal_attn_pool=False): + if type(x) is list: 
+ x, image_sizes = self.forward_features_list(x) + return x, image_sizes, None + else: + x, image_sizes = self.forward_features(x) + return x, image_sizes, None + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 384, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'): + # interpolate position embedding + orig_size = 24 + new_size = 128 + pos_tokens = model.pos_embed + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True) + return model + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + path: str = "", + gradient_checkpointing: bool = False, + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + + + if 'patch2x2' or 'patch4x4' in path: + add_patch2x2 = True + else: + add_patch2x2 = False + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + add_patch2x2 = 'v2' + + if FORCE_NO_DOWNSAMPLE: + add_patch2x2 = False + + model = VisionTransformer( + img_size=2048, + patch_size=16, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + dynamic_img_pad=False, + strict_img_size=False, + ignore_head=kwargs.get("ignore_head", False), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + add_patch2x2=add_patch2x2 + ) + + if not SKIP_LOAD_VIT: + if path is not None and os.path.exists(path): + ckpt = path + else: + raise ValueError(f"Model checkpoint not found at {path}") + state_dict = torch.load(ckpt, map_location="cpu") + print('loading vision backbone from', path) + + if 'genli' in path: + new_sd = {} + for k in state_dict.keys(): + if k.startswith('base_model.model.model.vision_tower.vision_tower.'): + new_k = 
k.replace('base_model.model.model.vision_tower.vision_tower.', '') + new_sd[new_k] = state_dict[k] + + if add_patch2x2: + if k.startswith('base_model.model.model.mm_projector.proj'): + new_k = k.replace('base_model.model.model.mm_projector.proj', 'downsample') + new_sd[new_k] = state_dict[k] + + elif 'distill' in path: + new_sd = {} + state_dict = state_dict['model'] + for k in state_dict.keys(): + if k.startswith('vision_tower.'): + new_k = k.replace('vision_tower.', '') + new_sd[new_k] = state_dict[k] + else: + raise NotImplementedError + msg = model.load_state_dict(new_sd, strict=False) + print(msg) + + else: + print("#### Skip loading vision backbone") + + if gradient_checkpointing: + model.set_grad_checkpointing(True) + return model + +from transformers import CLIPImageProcessor +import torch.distributed as dist + +class SigLIPViTAnysizeWrapper(nn.Module): + def __init__(self, vision_tower, path, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.args = args + self.path = path + + self.select_layer = -1 + if self.select_layer < -1: self.select_layer += 1 + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + self.output_dim = 1152 + if not FORCE_NO_DOWNSAMPLE: + if 'patch2x2' or 'patch4x4' in path: + self.output_dim = 1152*2 + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + self.output_dim = 1152*4 + + if not delay_load or LOAD_VISION_EARLY: + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + if self.args.mm_projector_type == "conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp_woconv": + self.image_processor.crop_size['height'] = 384 + self.image_processor.crop_size['width'] = 384 + self.image_processor.size['shortest_edge'] = 384 + print("Resizeing clip processor to 384...") + self.image_processor.image_mean = [0.5, 0.5, 0.5] + self.image_processor.image_std = [0.5, 0.5, 0.5] + print("Loading vision model...") + if VIT_WITH_GRAD: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=True) + self.vision_tower.train() + else: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=False) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode = True): + self.training = mode + + if self.is_loaded and not VIT_WITH_GRAD: + self.vision_tower.eval() + + def split_images(self, images, split_res=512, base_size=32): + split_images = [] + sub_images_info = [] + for image in images: + now_sub_images = [] + _, c, h, w = image.shape + if h * w <= split_res * split_res: + split_images.append(image) + sub_images_info.append( + ( + 1, 1, 1, h // base_size, w // base_size, [(0, h // base_size, 0, w // base_size)] + ) + ) + continue + nsplit_h = math.ceil(h / split_res) + nsplit_w = math.ceil(w / split_res) + sub_h = int(h / nsplit_h / base_size) * base_size + sub_w = 
int(w / nsplit_w / base_size) * base_size + crop_infos = [] + for i in range(nsplit_h): + for j in range(nsplit_w): + begin_h = i * sub_h + begin_w = j * sub_w + + if i == nsplit_h - 1: + end_h = h + else: + end_h = (i + 1) * sub_h + + if j == nsplit_w - 1: + end_w = w + else: + end_w = (j + 1) * sub_w + + assert (end_h - begin_h) % base_size == 0 and (end_w - begin_w) % base_size == 0 + + sub_image = image[:, :, begin_h:end_h, begin_w:end_w] + now_sub_images.append(sub_image) + crop_infos.append( + (begin_h // base_size, end_h // base_size, begin_w // base_size, end_w // base_size) + ) + + split_images += now_sub_images + sub_images_info.append( + ( + len(now_sub_images), nsplit_h, nsplit_w, h // base_size, w // base_size, crop_infos + ) + ) + + return split_images, sub_images_info + + + def unsplit_images(self, features, sizes, sub_images_info): + new_features = [] + for feature, size in zip(features, sizes): + h, w = size + new_features.append( + feature.reshape(1, h, w, -1) + ) + + fused_images = [] + images_sizes = [] + sub_count = 0 + for n_split, nsplit_h, nsplit_w, total_h, total_w, crop_infos in sub_images_info: + sub_features = new_features[sub_count:sub_count+n_split] + sub_count += n_split + + total_feature = new_features[0].new_zeros(1, total_h, total_w, self.hidden_size) + for feature, (begin_h, end_h, begin_w, end_w) in zip(sub_features, crop_infos): + total_feature[:, begin_h:end_h, begin_w:end_w] += feature + + fused_images.append(total_feature.reshape(1, total_h * total_w, self.hidden_size)) + images_sizes.append((total_h, total_w)) + + return fused_images, images_sizes + + + + def forward_func(self, images, force_fix_size=False, cal_attn_pool=False): + if type(images) is list: + xs = [x.to(self.dtype) for x in images] + image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool) + image_features = [x.to(images[0].dtype) for x in image_features] + + else: + image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool) + image_features = image_forward_outs.to(images.dtype) + + return image_features, img_size, cls_token + + def forward(self, images, cal_attn_pool=False): + if VIT_WITH_GRAD: + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + else: + with torch.no_grad(): + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + + + @property + def dummy_feature(self): + return torch.zeros(1, 1152, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.pos_embed.dtype + + @property + def device(self): + return self.vision_tower.pos_embed.device + + @property + def hidden_size(self): + return self.output_dim + + @property + def config(self): + return type('LLaVAConfigWrapper', (), { + # 'image_size': 224, + 'patch_size': 16, + })() diff --git a/ola/model/multimodal_projector/__pycache__/builder.cpython-310.pyc b/ola/model/multimodal_projector/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5604aad078d773d17b059ea8eff653c9846f96f3 Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/multimodal_projector/__pycache__/builder.cpython-38.pyc b/ola/model/multimodal_projector/__pycache__/builder.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..531bc218749820adfac35c4bb257ef3ec25e510f Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..086e7b4b733b608d2f7b72d566aab474fb45e675 Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc differ diff --git a/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-38.pyc b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77329e2216d0da0900681ab3a9d8437739a2135b Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-38.pyc differ diff --git a/ola/model/multimodal_projector/builder.py b/ola/model/multimodal_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f995ab8133f5c268e952a3fce98242669b3c7370 --- /dev/null +++ b/ola/model/multimodal_projector/builder.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import re + +import math + +from .pooler_projector import NormalizedDwPooler +import os +import math + + +if 'REGIONAL_POOL' in os.environ: + REGIONAL_POOL = os.environ['REGIONAL_POOL'] +else: + REGIONAL_POOL = '2x' +print(f"REGIONAL_POOL is set as {REGIONAL_POOL}") + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + +class OlaMLP(nn.Module): + def __init__(self, in_channels, out_channels, twoview=False): + super().__init__() + + self.proj1 = nn.Linear(in_channels, out_channels) + self.proj2 = nn.Linear(out_channels, out_channels) + self.act = nn.GELU() + self.pooler = NormalizedDwPooler(out_channels) + + embed_std = 1 / math.sqrt(out_channels) + self.image_newline = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_begin = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_end = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + if twoview: + self.image_sep = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'): + + if modalities in ['image', 'text']: + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x = self.proj1(x) + x = self.pooler(x, forward_type=REGIONAL_POOL) + x = self.act(x) + x = self.proj2(x) + + + b, h, w, c = x.shape + x = torch.cat([ + x, + self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype) + ], dim=2) + x = x.reshape(b, -1, c) + + if x2 is not None: + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type=REGIONAL_POOL) + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c2 = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype) + ], dim=2) + x2 = 
x2.reshape(b, -1, c) + sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype) + x = torch.cat([x, sep, x2], dim=1) + + begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + x = torch.cat([begin, x, end], dim=1) + return x + elif modalities in ['video']: + # x2 is the true feature, ignore x + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x1 = self.proj1(x) + x1 = self.pooler(x1, forward_type=REGIONAL_POOL) + x1 = self.proj2(x1).mean() * 0.0 + + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type=REGIONAL_POOL) + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype) + ], dim=2) + + x2 = x2.reshape(b2, -1, c) + + sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype) + x2 = torch.cat([x2, sep], dim=1) + + x2 = x2.flatten(0, 1) + + begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype) + end = self.image_end.reshape(1, -1).expand(1, c).to(dtype) + x2 = torch.cat([begin, x2, end], dim=0) + x2 = x2.unsqueeze(0) + return x2 + else: + raise ValueError(f'Unknown modalities: {modalities}') + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + elif projector_type == 'ola_mlp': + return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth = int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/ola/model/multimodal_projector/pooler_projector.py b/ola/model/multimodal_projector/pooler_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cbc8485c59eca0c186b96baa0c718f1ea2721c --- /dev/null +++ b/ola/model/multimodal_projector/pooler_projector.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from transformers.models.clip.modeling_clip import CLIPVisionModel +import os + +if 'NORMALIZE_POOL' in os.environ: + NORMALIZE_POOL = bool(int(os.environ['NORMALIZE_POOL'])) + print(f'NORMALIZE_POOL: {NORMALIZE_POOL}') +else: + NORMALIZE_POOL = True + + +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d( + config.mm_hidden_size, 
config.hidden_size, + kernel_size=2, stride=2 + ) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": 'pooler'} + + +class NormalizedDwPooler(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.predictor = nn.Sequential( + nn.Linear(dim*2, dim), + nn.GELU(), + nn.Linear(dim, dim), + ) + + def forward(self, x, forward_type='2x'): + B, H, W, C = x.shape + + if forward_type == '2x': + new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + elif forward_type == '1x': + new_x = x.reshape(B, H, W, 1, C) + fused_x = torch.cat([new_x, new_x], dim=-1) + elif forward_type == '4x': + new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + + score = self.predictor(fused_x) + normalized_score = F.softmax(score, dim=-2) + new_x = (new_x * normalized_score).sum(dim=-2) + return new_x diff --git a/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc b/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2138654bbfe0ad7f5d1e9c6818858e070a2bb373 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc b/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6adaf0218cb3206cebc0b7d3aa8f0d5d4ab66a7 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14cc5b34f7a5b9bcea16d2147254d5bc07efab61 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128cd5465170605a1ccd8a9580f707c215b69ec4 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc differ diff --git a/ola/model/multimodal_resampler/builder.py b/ola/model/multimodal_resampler/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..91d77fc4cdf217fa84fa19e7faf08638407d8ea7 --- /dev/null +++ b/ola/model/multimodal_resampler/builder.py @@ -0,0 +1,24 @@ +import torch + +from .perceiver import DynamicCompressor + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + +def 
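`NormalizedDwPooler` above replaces plain average pooling with a learned, softmax-normalised weighting over each 2x2 (or 4x4) window. A standalone sketch of the core reshaping and weighted pooling, with random scores standing in for `self.predictor` (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

B, H, W, C = 1, 4, 4, 8
x = torch.randn(B, H, W, C)

# Group each 2x2 window together, as the forward_type='2x' branch does.
windows = (x.reshape(B, H // 2, 2, W // 2, 2, C)
             .permute(0, 1, 3, 2, 4, 5)
             .reshape(B, H // 2, W // 2, 4, C))

scores = torch.randn(B, H // 2, W // 2, 4, 1)   # stand-in for self.predictor(fused_x)
weights = F.softmax(scores, dim=-2)             # normalise over the 4 window positions
pooled = (windows * weights).sum(dim=-2)        # (B, H/2, W/2, C) weighted pool
print(pooled.shape)
```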
build_vision_resampler(model_args, delay_load=False, **kwargs): + # import pdb;pdb.set_trace() + resampler_type = getattr(model_args, 'mm_resampler_type', None) + if resampler_type == 'dynamic_compressor': + return DynamicCompressor(model_args, **kwargs) + elif resampler_type is None: + return IdentityMap() + else: + raise ValueError(f'Unknown resampler type: {resampler_type}') diff --git a/ola/model/multimodal_resampler/perceiver.py b/ola/model/multimodal_resampler/perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..e481559ce2e92eeebf891cc6bd8458e74e7eb051 --- /dev/null +++ b/ola/model/multimodal_resampler/perceiver.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +import os +if 'EVAL_LARGE' in os.environ: + print("EVAL_LARGE is set") + EVAL_LARGE = True +else: + EVAL_LARGE = False + +class DynamicCompressor(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.out_channels = vision_tower.hidden_size + self.mid_channel = 256 + + self.vlm_query_projector = nn.Linear(self.out_channels, self.mid_channel) + self.vlm_key_projector = nn.Linear(self.out_channels, self.mid_channel) + + def downsample(self, x): + return F.avg_pool2d(x, 2, 2) + + def downsample_4(self, x): + return F.avg_pool2d(x, 4, 4) + + def forward(self, image_features, forward_type, image_size=None): + if image_size is None: + ori_W = int(math.sqrt(image_features.shape[1])) + ori_H = int(ori_W) + else: + ori_H, ori_W = image_size + T, N, C = image_features.shape + image_features = image_features.view(T, ori_H, ori_W, C).permute(0, 3, 1, 2) # T, C, H, W + + if forward_type == 'video': + image_features_pool = self.downsample(image_features) + image_feature_attn = image_features.reshape(T, C, ori_H // 2, 2, ori_W // 2, 2).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 2 * ori_W // 2, 4, C) + new_image_size = (ori_H // 2, ori_W // 2) + elif forward_type == 'image' or forward_type == 'text': + image_features_pool = image_features + image_feature_attn = image_features.reshape(T, C, ori_H, 1, ori_W, 1).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H * ori_W, 1, C) + new_image_size = (ori_H, ori_W) + elif forward_type == 'video_long': + image_features_pool = self.downsample_4(image_features) + image_feature_attn = image_features.reshape(T, C, ori_H // 4, 4, ori_W // 4, 4).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 4 * ori_W // 4, 16, C) + new_image_size = (ori_H // 4, ori_W // 4) + else: + raise NotImplementedError + + image_features_pool = image_features_pool.flatten(2).permute(0, 2, 1) # T, H*W, C + new_t, new_p, _ = image_features_pool.shape + + if EVAL_LARGE: + image_features_pool = image_features_pool.to(self.vlm_query_projector.weight.device) + image_feature_attn = image_feature_attn.to(self.vlm_key_projector.weight.device) + + image_query = self.vlm_query_projector(image_features_pool).reshape(new_t*new_p, self.mid_channel) + image_key = self.vlm_key_projector(image_feature_attn).reshape(new_t*new_p, -1, self.mid_channel) + + image_value = image_feature_attn.reshape(new_t*new_p, -1, self.out_channels) + # import pdb;pdb.set_trace() + + image_attn = image_query[:,None] @ (image_key.transpose(-1,-2) / (image_key.shape[-1]**0.5)) + image_attn = image_attn.nan_to_num() + attn_feat = (image_attn.softmax(-1) @ image_value).mean(1).reshape(new_t, new_p, C) + + image_features_pool = image_features_pool + attn_feat + + return image_features_pool, new_image_size + + @property + def config(self): + return { + 
'mm_resampler_type': 'dynamic_compressor', + 'mm_out_channels': self.out_channels, + } + + @property + def hidden_size(self): + return self.out_channels diff --git a/ola/model/ola_arch.py b/ola/model/ola_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5aa80d90c24c0301e9f94ee9902079b92e3c23 --- /dev/null +++ b/ola/model/ola_arch.py @@ -0,0 +1,418 @@ +from abc import ABC, abstractmethod + +import torch + +from .speech_encoder.builder import build_speech_encoder +from .speech_projector.builder import build_speech_projector +from ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX +from ola.utils import lengths_to_padding_mask + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_resampler.builder import build_vision_resampler +from .multimodal_projector.builder import build_vision_projector + +from ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +class OlaMetaModel: + + def __init__(self, config): + super(OlaMetaModel, self).__init__(config) + + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + + if hasattr(config, "mm_vision_tower"): + self.vision_tower = build_vision_tower(config, delay_load=True) + self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + + def get_speech_encoder(self): + speech_encoder = getattr(self, 'speech_encoder', None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) + self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') + self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) + self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + + if getattr(self, 'speech_projector', None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + print('Loading pretrain speech projector weights') + + msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False) + print(msg) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + 
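# The nested get_w helper used above for the speech projector (and again below for the
# mm projector) remaps checkpoint keys: it keeps every entry whose key contains the
# keyword and strips everything up to and including "<keyword>.". A minimal,
# self-contained sketch with a toy state dict (tensor shapes and values are placeholders):
import torch

def get_w(weights, keyword):
    return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

toy_ckpt = {
    'model.speech_projector.linear1.weight': torch.zeros(2, 2),
    'model.speech_projector.linear1.bias': torch.zeros(2),
    'model.embed_tokens.weight': torch.zeros(4, 2),  # dropped: keyword not in this key
}
print(sorted(get_w(toy_ckpt, 'speech_projector')))
# ['linear1.bias', 'linear1.weight'] -> ready for load_state_dict(..., strict=False)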
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size) + + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + else: + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + print('Loading pretrain mm projector weights') + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False) + print(incompatible_keys) + +class OlaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def encode_speech(self, speech, speech_lengths, speech_wav): + # import pdb; pdb.set_trace() + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav) + speech_lengths = (speech_lengths + 1) // 2 + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f'Unknown speech projector: {speech_projector_type}') + # speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] + return encoder_outs + + def prepare_inputs_labels_for_speech_vision_text( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, 
images_highres=None + ): + speech_encoder = self.get_speech_encoder() + vision_tower = self.get_vision_tower() + + if speech_encoder is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if vision_tower is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + # encode speech + if not isinstance(speech, list): + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + speech_features = [] + for idx in range(len(speech)): + speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])) + + # encode vision + if isinstance(modalities, str): + modalities = [modalities] + + video_idx_in_batch = [] + for modal in range(len(modalities)): + if 'video' in modalities[modal]: + video_idx_in_batch.append(modal) + + # Fix training with deepspeed zero3 + num_modality = len(modalities) + # try: + # world_size = dist.get_world_size() + # tensor_in = torch.zeros(1, dtype=torch.int64, device=images[0].device).fill_(num_modality) + # tensor_out = torch.zeros(world_size, dtype=torch.int64, device=images[0].device) + # dist.all_gather_into_tensor(tensor_out, tensor_in) + # max_num_modality = tensor_out.max().item() + # except: + # max_num_modality = num_modality + aimg = images[-1] + lowres_img = [] + for idx, img_feat in enumerate(images): + if idx in video_idx_in_batch: + img_feat = aimg.new(1, 3, 128, 128).fill_(0) + lowres_img.append(img_feat) + + # Fix training with deepspeed zero3 + # if max_num_modality > num_modality: + # for _ in range(max_num_modality - num_modality): + # lowres_img.append(aimg.new(1, 3, 64, 64).fill_(0)) + # images_highres.append(aimg.new(1, 3, 64, 64).fill_(0)) + # modalities.append('image') + lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img) + highres_img_features = [] + highres_img_sizes = [] + for idx, img_feat in enumerate(images_highres): + if img_feat.ndim == 5: + img_feat = img_feat.squeeze(1) + highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat) + highres_img_features.append(highres_img_feature) + highres_img_sizes.append(highres_img_size) + image_features = [] + for idx in range(len(modalities)): + img_feat = self.get_model().mm_projector(lowres_img_features[idx], + lowres_img_sizes[idx], + highres_img_features[idx], + highres_img_sizes[idx], + modalities[idx]) + image_features.append(img_feat.flatten(0, 1)) + + # if max_num_modality > num_modality: + # image_features = image_features[:num_modality] + # modalities = modalities[:num_modality] + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + 
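# The two list comprehensions above turn the padded (batch, seq_len) tensors into lists of
# variable-length 1-D tensors via boolean indexing with the attention mask; the loop below
# then splices the projected image/speech features in at the placeholder token positions.
# A minimal sketch of the unpadding step with made-up token ids (0 = pad):
import torch

padded_ids = torch.tensor([[11, 12, 13, 0],
                           [21, 22, 0, 0]])
mask = padded_ids.ne(0)                      # plays the role of attention_mask.bool()
unpadded = [ids[m] for ids, m in zip(padded_ids, mask)]
print([t.tolist() for t in unpadded])        # [[11, 12, 13], [21, 22]]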
new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + + num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_images_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]] + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image)) + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + if i < num_speech_images: + if i < num_images: + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_images == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0) + cur_image_idx += 1 + + if num_speech == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, 
dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/ola/model/speech_encoder/__pycache__/builder.cpython-310.pyc b/ola/model/speech_encoder/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b02014c4fe99c2a493d667fc6139144d9b7439c Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/__pycache__/builder.cpython-38.pyc b/ola/model/speech_encoder/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..929df7adeef4ac05860b00b89215f007764ef684 Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba86db536972672e30794df17a055ce33beff05 Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-38.pyc b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e740e65508082ae2d7e393f70849045ddc71fc Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/BEATs.py b/ola/model/speech_encoder/beats/BEATs.py new file mode 100644 index 0000000000000000000000000000000000000000..9bbf36a82bc86705f851d479bd8acb5237c9e425 --- /dev/null +++ b/ola/model/speech_encoder/beats/BEATs.py @@ -0,0 +1,182 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply 
deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. + self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + feature_only=False, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = 
self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + if not feature_only and self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask \ No newline at end of file diff --git a/ola/model/speech_encoder/beats/Tokenizers.py b/ola/model/speech_encoder/beats/Tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..597c8902493b38689136b7153f114842c8fd66a3 --- /dev/null +++ b/ola/model/speech_encoder/beats/Tokenizers.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) +from .quantizer import ( + NormEMAVectorQuantizer, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenizersConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + 
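# forward_padding_mask (defined in BEATs above and duplicated in Tokenizers below) collapses a
# per-sample padding mask into a per-frame mask: trim the remainder, reshape to
# (batch, n_frames, samples_per_frame), and mark a frame as padding only if all of its samples
# are padding. A small self-contained sketch with illustrative sizes (not taken from the code):
import torch

B, n_samples, n_frames = 2, 16000, 50        # 320 samples per frame in this toy setup
padding_mask = torch.zeros(B, n_samples, dtype=torch.bool)
padding_mask[1, 8000:] = True                # second clip: the last half is padding
extra = n_samples % n_frames                 # 0 here, so nothing is trimmed
frame_mask = padding_mask.view(B, n_frames, -1).all(-1)
print(frame_mask.sum(dim=1))                 # tensor([ 0, 25]): 25 fully-padded frames in clip 2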
self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # quantizer + self.quant_n: int = 1024 # codebook number in quantizer + self.quant_dim: int = 256 # codebook dimension in quantizer + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class Tokenizers(nn.Module): + def __init__( + self, + cfg: TokenizersConfig, + ) -> None: + super().__init__() + logger.info(f"Tokenizers Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.quantize = NormEMAVectorQuantizer( + n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, + ) + self.quant_n = cfg.quant_n + self.quantize_layer = nn.Sequential( + nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), + nn.Tanh(), + nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize + ) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_labels( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + quantize_input = self.quantize_layer(x) + quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) + + return embed_ind diff --git a/ola/model/speech_encoder/beats/__init__.py b/ola/model/speech_encoder/beats/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94376c3689a9e92f4525a843f0ac988d274c4ba3 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60f224b8561f7d5f1b353617e8cd7333c517215a Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9beb2f573eb0c6c104f6cc14e6b431be226bddc5 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98ac0a5133f4d84dd41f659f366bca6018097ec4 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36069f9616f9fc29a194619564ae6913819a1d4e Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b246688189857fd09dd839c319ae11206b8cab0d Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd53d27f2f3476b6c400204e4d2176c35e07e2a2 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2327bab2db3b6be7920bd6f133d2d9987364c020 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e041b12f844521ec9a7b0be68423dec99486c91 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..93d8b8f5392233e6e0e8216bc0b8e003988a975e Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/backbone.py b/ola/model/speech_encoder/beats/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6ba72a47bb537e025d32580ab0b530a34ca4fb --- /dev/null +++ b/ola/model/speech_encoder/beats/backbone.py @@ -0,0 +1,782 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + 
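# This branch applies DeepNorm-style scaling: with N encoder layers, the residual input is
# multiplied by alpha = (2N)^(1/4) inside each post-LN block (deep_norm_alpha in the layer
# class below), while the value/output projections and the FFN weights are initialized with
# gain beta = (8N)^(-1/4) and q/k keep gain 1. A quick sanity check of the two constants for
# the default encoder_layers = 12 (only relevant when deep_norm is enabled in the config):
import math

N = 12
alpha = math.pow(2 * N, 1 / 4)   # ~2.213, scales the residual input
beta = math.pow(8 * N, -1 / 4)   # ~0.319, init gain for v_proj/out_proj/fc1/fc2
print(round(alpha, 3), round(beta, 3))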
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + 
else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 
+ else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). 
+ attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" 
in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 
1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) \ No newline at end of file diff --git a/ola/model/speech_encoder/beats/kaldi.py b/ola/model/speech_encoder/beats/kaldi.py new file mode 100644 index 0000000000000000000000000000000000000000..f97fa85308e785af571bfdc912d6f12bb092e10e --- /dev/null +++ b/ola/model/speech_encoder/beats/kaldi.py @@ -0,0 +1,813 @@ +import math +from typing import Tuple + +import torch +# import torchaudio +from torch import Tensor + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + "vtln_warp_mel_freq", +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, +) -> Tensor: + r"""Returns a window function with the given type and size""" + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return ( + blackman_coeff + - 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( + window_size, len(waveform) + ) + assert 0 < window_shift, 
"`window_shift` must be greater than 0" + assert padded_window_size % 2 == 0, ( + "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`" + ) + assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( + 0 + ) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze( + 0 + ) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + 
subtract_mean: bool = False, + window_type: str = POVEY, +) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor, +) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). 
+ If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" + assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor, +) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, +) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert ( + (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq) + ), "Bad values in options: low-freq {} and high-freq {} vs. 
nyquist {}".format(low_freq, high_freq, nyquist) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ( + (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) + ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. 
(Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = get_mel_banks( + num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp + ) + mel_energies = mel_energies.to(device=device, dtype=dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. 
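+ # i.e. lifter(i) = 1 + (Q / 2) * sin(pi * i / Q) for i = 0..num_ceps-1, with Q = cepstral_lifter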
+ i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. 
If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : (num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) 
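+ # the orthonormal DCT above weights C0 by sqrt(1/num_mel_bins) whereas HTK uses sqrt(2/num_mel_bins),
+ # hence the extra sqrt(2) factor applied to the C0 column below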
+ energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/ola/model/speech_encoder/beats/modules.py b/ola/model/speech_encoder/beats/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..18e2d2066b93139acc9427f0edcdd96b12769f25 --- /dev/null +++ b/ola/model/speech_encoder/beats/modules.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + 
elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/ola/model/speech_encoder/beats/quantizer.py b/ola/model/speech_encoder/beats/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..704be4c357bce7ee425ea2b6737b536333a5a63c --- /dev/null +++ b/ola/model/speech_encoder/beats/quantizer.py @@ -0,0 +1,215 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# 
Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on VQGAN code bases +# https://github.com/CompVis/taming-transformers +# --------------------------------------------------------' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed + +try: + from einops import rearrange, repeat +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + # self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + 
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + # z = rearrange(z, 'b c h w -> b h w c') + # z = z.transpose(1, 2) + z = l2norm(z) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) 
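+ # accumulate per-code sums of the (L2-normalized) encoder outputs; codes with no assignments
+ # keep their previous embedding via the torch.where below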
+ + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + # z_q = rearrange(z_q, 'b h w c -> b c h w') + # z_q = z_q.transpose(1, 2) + return z_q, loss, encoding_indices \ No newline at end of file diff --git a/ola/model/speech_encoder/builder.py b/ola/model/speech_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b1fcb068ed79c1c90f1f0866bb72326f140de9 --- /dev/null +++ b/ola/model/speech_encoder/builder.py @@ -0,0 +1,11 @@ +from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder + + +def build_speech_encoder(config): + speech_encoder_type = getattr(config, 'speech_encoder_type', None) + if "whisper" in speech_encoder_type.lower(): + return WhisperWrappedEncoder.load(config) + elif "dual" in speech_encoder_type.lower(): + return DualWrappedEncoder(config) + + raise ValueError(f'Unknown speech encoder: {speech_encoder_type}') diff --git a/ola/model/speech_encoder/speech_encoder.py b/ola/model/speech_encoder/speech_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..458dfca622df564e9fe889ccba3745ae71238219 --- /dev/null +++ b/ola/model/speech_encoder/speech_encoder.py @@ -0,0 +1,74 @@ +import types +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import WhisperFeatureExtractor +import whisper + +from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs + +class WhisperWrappedEncoder: + + @classmethod + def load(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder + replace_layer_norm(encoder) + return encoder + +class DualWrappedEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.whisper_model = self.load_whisper(config) + self.beats_model = self.load_beats(config) + + def load_whisper(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder + replace_layer_norm(encoder) + return encoder + + def load_beats(cls, model_config): + beats_path = model_config.music_encoder + print("Loading BEATs Model") + beats_ckpt = torch.load(beats_path, map_location='cpu') + beats_cfg = 
BEATsConfig(beats_ckpt['cfg']) + beats = BEATs(beats_cfg) + beats.load_state_dict(beats_ckpt['model']) + return beats + + def forward(self, x, raw_wav=None, audio_padding_mask=None): + with torch.no_grad(): + self.beats_model = self.beats_model.float() + speech_embeds = self.whisper_model(x) + audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True) + if audio_embeds.size(1) < speech_embeds.size(1): + audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))) + elif audio_embeds.size(1) > speech_embeds.size(1): + speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))) + speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) + speech_embeds = speech_embeds.to(torch.bfloat16) + return speech_embeds \ No newline at end of file diff --git a/ola/model/speech_generator/builder.py b/ola/model/speech_generator/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c93ef5963ee03ca2d90bb23abc6797aae2c7a6aa --- /dev/null +++ b/ola/model/speech_generator/builder.py @@ -0,0 +1,15 @@ +from .speech_generator import SpeechGeneratorCTC, SpeechGeneratorCTCQwen, SpeechGeneratorCEQwen, SpeechGeneratorCosyVoice + + +def build_speech_generator(config): + generator_type = getattr(config, 'speech_generator_type', 'ctc') + if generator_type == 'ctc': + return SpeechGeneratorCTC(config) + elif generator_type == 'ctc_qwen': + return SpeechGeneratorCTCQwen(config) + elif generator_type == 'ce_qwen': + return SpeechGeneratorCEQwen(config) + elif generator_type == 'cosy_qwen': + return SpeechGeneratorCosyVoice(config) + + raise ValueError(f'Unknown generator type: {generator_type}') diff --git a/ola/model/speech_generator/generation.py b/ola/model/speech_generator/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea10763ba7ce72cfab4dbfb97f499a3344a4a74 --- /dev/null +++ b/ola/model/speech_generator/generation.py @@ -0,0 +1,612 @@ +import copy +import torch +import inspect +import warnings +import numpy as np +import torch.nn as nn +from typing import Optional, Union, List, Callable +import torch.distributed as dist + +from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import ( + GenerationConfig, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, + GenerateOutput, + GenerationMixin, + GenerateEncoderDecoderOutput, + GenerateDecoderOnlyOutput, + GenerateNonBeamOutput, + is_deepspeed_zero3_enabled, + is_torchdynamo_compiling, + NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, + is_hqq_available, + QuantizedCacheConfig, + is_quanto_available, + DynamicCache, + EncoderDecoderCache, + logging +) +# from transformers.generation.stopping_criteria import validate_stopping_criteria + +logger = logging.get_logger(__name__) + + +class GenerationWithCTC(GenerationMixin): + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + streamer_unit: Optional["BaseStreamer"] = None, + streaming_unit_gen = False, + negative_prompt_ids: 
Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: + synced_gpus = True + else: + synced_gpus = False + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config._pad_token_tensor is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + # 4. Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. 
Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + use_dynamic_cache_by_default = False + if "mamba" in self.__class__.__name__.lower(): + cache_name = "cache_params" + else: + cache_name = "past_key_values" + if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + generation_config.max_length, + model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + past = model_kwargs.get(cache_name, None) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + if past is None: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(past) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(past) + ) + use_dynamic_cache_by_default = True + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if (streamer is not None or streamer_unit is not None) and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + + # 8. prepare distribution pre_processing samplers + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 9. prepare stopping criteria + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + ) + # 10. go into different generation modes + + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. prepare logits warper + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + if streaming_unit_gen: + return self._sample_streaming_unit( + input_ids, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + streamer_unit=streamer_unit, + **model_kwargs, + ) + else: + return self._sample( + input_ids, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + else: + raise NotImplementedError + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample: + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + 
is_encoder_decoder=self.config.is_encoder_decoder, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _sample_streaming_unit( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + streamer_unit: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + generated_units = torch.tensor([]) + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample: + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # speechgen + hidden_states = torch.cat([decoder_hidden_states[0][-1][:, -1:, :]] + [decoder_hidden_states[i][-1] for i in range(1, len(decoder_hidden_states))], dim=1) + ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0)) + cur_units = ctc_postprocess(ctc_pred, blank=self.model.config.unit_vocab_size) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * 
unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + if streamer_unit is not None: + for i in range(len(generated_units), len(cur_units)): + streamer_unit.put(cur_units[i].unsqueeze(0)) + generated_units = cur_units + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + +def ctc_postprocess(tokens, blank): + _toks = tokens.squeeze(0).tolist() + deduplicated_toks = [v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]] + hyp = torch.tensor([v for v in deduplicated_toks if v != blank]) + return hyp \ No newline at end of file diff --git a/ola/model/speech_generator/speech_generator.py b/ola/model/speech_generator/speech_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..75a93a5332805ab01e2f37c1a59d5a8c25ef268a --- /dev/null +++ b/ola/model/speech_generator/speech_generator.py @@ -0,0 +1,468 @@ +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer +from omni_speech.constants import IGNORE_INDEX + +torch.autograd.set_detect_anomaly(True) + +try: + import sys + sys.path.append('/mnt/lzy/LLaMA-Omni/CosyVoice/') + from cosyvoice.cli.cosyvoice import CosyVoice +except: + print('CosyVoice not found') + +import os +if 'SPEECH_GEN_CONV_KERNEL' in os.environ: + SPEECH_GEN_CONV_KERNEL = int(os.environ['SPEECH_GEN_CONV_KERNEL']) + print(f'Using SPEECH_GEN_CONV_KERNEL={SPEECH_GEN_CONV_KERNEL}') +else: + SPEECH_GEN_CONV_KERNEL = -1 + +if 'DISTILL_EMBEDDING' in os.environ: + DISTILL_EMBEDDING = True + print(f'DISTILL_EMBEDDING is set.') +else: + DISTILL_EMBEDDING = False + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def _uniform_assignment(src_lens, tgt_lens): + tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device) + ratio = tgt_lens / src_lens + index_t = (tgt_indices / 
ratio.view(-1, 1)).long() + return index_t + + +class SpeechGeneratorCTC(nn.Module): + def __init__(self, config): + super().__init__() + n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(","))) + _config = copy.deepcopy(config) + _config.hidden_size = n_dims + _config.num_hidden_layers = n_layers + _config.num_attention_heads = n_heads + _config.num_key_value_heads = n_heads + _config.intermediate_size = n_inter_dims + _config._attn_implementation = "flash_attention_2" + self.upsample_factor = config.ctc_upsample_factor + self.input_proj = nn.Linear(config.hidden_size, n_dims) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)] + ) + self.unit_vocab_size = config.unit_vocab_size + self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) + + def upsample(self, reps, tgt_units=None): + src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) + up_lens = src_lens * self.upsample_factor + if tgt_units is not None: + tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + up_lens = torch.max(up_lens, tgt_lens) + reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) + padding_mask = lengths_to_padding_mask(up_lens) + mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill( + padding_mask, 0 + ) + copied_reps = torch.gather( + reps, + 1, + mapped_inputs.unsqueeze(-1).expand( + *mapped_inputs.size(), reps.size(-1) + ), + ) + copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) + position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device) + return copied_reps, ~padding_mask, position_ids + + def forward(self, tgt_reps, labels, tgt_units): + tgt_label_reps = [] + for tgt_rep, label in zip(tgt_reps, labels): + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units) + hidden_states = self.input_proj(hidden_states) + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_lens = attention_mask.long().sum(dim=-1) + ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens) + ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask) + ctc_loss = F.ctc_loss( + ctc_lprobs.transpose(0, 1), + ctc_tgt_flat, + ctc_lens, + ctc_tgt_lens, + reduction="sum", + zero_infinity=True, + blank=self.unit_vocab_size + ) + ctc_loss /= ctc_tgt_lens.sum().item() + return ctc_loss + + def predict(self, tgt_reps): + hidden_states, attention_mask, position_ids = self.upsample([tgt_reps]) + hidden_states = self.input_proj(hidden_states) + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size) + return ctc_pred + + +class SpeechGeneratorCTCQwen(nn.Module): + def __init__(self, config): + super().__init__() + n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(","))) 
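+        # NOTE: ctc_decoder_config is assumed to be a string of the form
+        # "(n_layers,n_dims,n_heads,n_inter_dims,n_kv_heads)", e.g. "(2,896,14,4864,2)"
+        # (example values only); the [1:-1] slice strips the parentheses before the
+        # comma split, and the five integers size the small Qwen2 decoder stack below.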
+ _config = copy.deepcopy(config) + _config.hidden_size = n_dims + _config.num_hidden_layers = n_layers + _config.num_attention_heads = n_heads + _config.num_key_value_heads = n_kv_heads + _config.intermediate_size = n_inter_dims + _config._attn_implementation = "flash_attention_2" + self.upsample_factor = config.ctc_upsample_factor + self.input_proj = nn.Linear(config.hidden_size, n_dims) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)] + ) + self.unit_vocab_size = config.unit_vocab_size + self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) + + if SPEECH_GEN_CONV_KERNEL > 0: + self.temporal_conv = nn.Conv1d(n_dims, n_dims, SPEECH_GEN_CONV_KERNEL, padding=0) + self.learnable_pad_left = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims)) + self.learnable_pad_right = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims)) + # self.conv_layer_id = n_layers // 2 # Insert temporal conv layer in the middle of the decoder layers + + def upsample(self, reps, tgt_units=None): + src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) + up_lens = src_lens * self.upsample_factor + if tgt_units is not None: + tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + up_lens = torch.max(up_lens, tgt_lens) + reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) + padding_mask = lengths_to_padding_mask(up_lens) + mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill( + padding_mask, 0 + ) + copied_reps = torch.gather( + reps, + 1, + mapped_inputs.unsqueeze(-1).expand( + *mapped_inputs.size(), reps.size(-1) + ), + ) + copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) + position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device) + return copied_reps, ~padding_mask, position_ids + + def forward(self, tgt_reps, labels, tgt_units): + tgt_label_reps = [] + for tgt_rep, label in zip(tgt_reps, labels): + if SPEECH_GEN_CONV_KERNEL > 0: + now_rep = tgt_rep[label != IGNORE_INDEX] + now_rep = torch.cat([self.learnable_pad_left, now_rep, self.learnable_pad_right], dim=0) + now_rep = self.input_proj(now_rep)[None] + now_rep = self.temporal_conv(now_rep.transpose(1, 2)).transpose(1, 2)[0] + tgt_label_reps.append(now_rep) + else: + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + + hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units) + + if SPEECH_GEN_CONV_KERNEL < 0: + hidden_states = self.input_proj(hidden_states) + + for layer_id, layer in enumerate(self.layers): + # if SPEECH_GEN_CONV_KERNEL: + # if layer_id == self.conv_layer_id: + # hidden_states = self.temporal_conv(hidden_states.transpose(1, 2)).transpose(1, 2) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_lens = attention_mask.long().sum(dim=-1) + ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens) + ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask) + ctc_loss = F.ctc_loss( + ctc_lprobs.transpose(0, 1), + ctc_tgt_flat, + ctc_lens, + ctc_tgt_lens, + reduction="sum", + zero_infinity=True, + blank=self.unit_vocab_size + ) + ctc_loss /= ctc_tgt_lens.sum().item() + return ctc_loss + + def predict(self, tgt_reps): + hidden_states, 
attention_mask, position_ids = self.upsample([tgt_reps]) + hidden_states = self.input_proj(hidden_states) + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size) + return ctc_pred + + +class SpeechGeneratorCEQwen(nn.Module): + def __init__(self, config): + super().__init__() + n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(","))) + _config = copy.deepcopy(config) + _config.hidden_size = n_dims + _config.num_hidden_layers = n_layers + _config.num_attention_heads = n_heads + _config.num_key_value_heads = n_kv_heads + _config.intermediate_size = n_inter_dims + _config._attn_implementation = "flash_attention_2" + self.upsample_factor = 1 + self.input_proj = nn.Linear(config.hidden_size, n_dims) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)] + ) + self.unit_vocab_size = config.unit_vocab_size + self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) + + def upsample(self, reps, tgt_units=None): + src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) + up_lens = src_lens * self.upsample_factor + if tgt_units is not None: + tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + up_lens = torch.max(up_lens, tgt_lens) + reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) + padding_mask = lengths_to_padding_mask(up_lens) + mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill( + padding_mask, 0 + ) + copied_reps = torch.gather( + reps, + 1, + mapped_inputs.unsqueeze(-1).expand( + *mapped_inputs.size(), reps.size(-1) + ), + ) + copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) + position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device) + return copied_reps, ~padding_mask, position_ids + + def forward(self, tgt_reps, labels, tgt_units): + tgt_label_reps = [] + for tgt_rep, label in zip(tgt_reps, labels): + # if SPEECH_GEN_CONV_KERNEL > 0: + # now_rep = tgt_rep[label != IGNORE_INDEX] + # now_rep = torch.cat([self.learnable_pad_left, now_rep, self.learnable_pad_right], dim=0) + # now_rep = self.input_proj(now_rep)[None] + # now_rep = self.temporal_conv(now_rep.transpose(1, 2)).transpose(1, 2)[0] + # tgt_label_reps.append(now_rep) + # else: + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units) + # if SPEECH_GEN_CONV_KERNEL < 0: + hidden_states = self.input_proj(hidden_states) + + for layer_id, layer in enumerate(self.layers): + # if SPEECH_GEN_CONV_KERNEL: + # if layer_id == self.conv_layer_id: + # hidden_states = self.temporal_conv(hidden_states.transpose(1, 2)).transpose(1, 2) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + + shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_states.size(-1)) + logits = self.output_proj(shift_hidden_states) + shift_labels = tgt_units[..., 1:].contiguous().reshape(-1) + assert shift_labels.size(0) == shift_hidden_states.size(0) + loss_fct = nn.CrossEntropyLoss() + logits = logits.float() 
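+        # NOTE: standard causal shift for the unit cross-entropy: hidden states at
+        # positions [0, T-2] (projected to logits above, cast to fp32 for stability)
+        # are scored against tgt_units at positions [1, T-1]. Assuming IGNORE_INDEX
+        # is -100, padded targets are skipped by CrossEntropyLoss's default ignore_index.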
+ loss = loss_fct(logits, shift_labels) + # loss = (loss / 1.0).sum().item() + # loss = loss.sum().item() + return loss + + # def predict(self, tgt_reps): + # hidden_states, attention_mask, position_ids = self.upsample([tgt_reps]) + # hidden_states = self.input_proj(hidden_states) + # for layer in self.layers: + # layer_outputs = layer( + # hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # ) + # hidden_states = layer_outputs[0] + # ctc_logits = self.output_proj(hidden_states) + # ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + # ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size) + # return ctc_pred + +# class SpeechGeneratorCosyVoice(nn.Module): +# def __init__(self, config): +# super().__init__() +# self.input_proj = nn.Sequential( +# nn.Linear(config.hidden_size, 1024), +# nn.GELU(), +# nn.Linear(1024, 512) +# ) +# self.cosyvoice1 = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False) +# self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True) +# self.llm = self.cosyvoice1.model.llm + +# if DISTILL_EMBEDDING: +# self.criterion = nn.CosineEmbeddingLoss() + +# def forward(self, tgt_reps, labels, answer): +# tgt_label_reps = [] +# batch_speech_tokens = [] +# embeddings = [] +# target_embeddings = [] +# if DISTILL_EMBEDDING: +# for tgt_rep, label, ans in zip(tgt_reps, labels, answer): +# # make all label id in [151644,151645,198] to IGNORE_INDEX +# label[label == 151644] = IGNORE_INDEX +# label[label == 151645] = IGNORE_INDEX +# label[label == 198] = IGNORE_INDEX +# tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) +# normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True) +# tts_text_token_all = [] +# for norm_text in normalized_text: +# tts_text_token, tts_text_token_len = self.cosyvoice1.frontend._extract_text_token(norm_text) +# tts_text_token_all.append(tts_text_token) +# tts_text_token_all = torch.cat(tts_text_token_all, dim=0) +# target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token) +# target_embeddings.append(target_embedding) +# import pdb;pdb.set_trace() +# tgt_label_reps = torch.stack(tgt_label_reps) +# target_embeddings = torch.stack(target_embeddings).squeeze(1) +# hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512) +# target_embeddings = target_embeddings.reshape(-1, 512) +# loss = self.criterion(hidden_states, target_embeddings, torch.ones(hidden_states.size(0)).to(hidden_states.device)) +# else: +# for tgt_rep, label, ans in zip(tgt_reps, labels, answer): +# # make all label id in [151644,151645,198] to IGNORE_INDEX +# label[label == 151644] = IGNORE_INDEX +# label[label == 151645] = IGNORE_INDEX +# label[label == 198] = IGNORE_INDEX +# tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) +# speech_token = self.cosyvoice.inference_label(ans, 'θ‹±ζ–‡ε₯³', stream=False) +# speech_tokens = [] +# for i,j in enumerate(speech_token): +# speech_tokens.append(j['tts_speech_token'].squeeze(0)) +# speech_tokens.append(torch.tensor([0])) +# speech_tokens = torch.cat(speech_tokens, dim=0) +# if speech_tokens.size(0) > 1: +# speech_tokens = speech_tokens[:-1] +# batch_speech_tokens.append(speech_tokens) +# embedding = self.cosyvoice.frontend.frontend_embedding('θ‹±ζ–‡ε₯³') +# embeddings.append(embedding['llm_embedding'].squeeze(0)) + +# tgt_label_reps = torch.stack(tgt_label_reps) +# batch_speech_token = 
torch.stack(batch_speech_tokens) +# embeddings = torch.stack(embeddings) +# hidden_states = self.input_proj(tgt_label_reps) +# batch = {'text_feature': hidden_states, 'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(hidden_states.size(0)), +# 'speech_token': batch_speech_token, 'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(hidden_states.size(0)), +# 'embedding': embeddings} +# output = self.llm.forward_ours(batch, 'cuda') +# loss = output['loss'] +# return loss + +class SpeechGeneratorCosyVoice(nn.Module): + def __init__(self, config): + super().__init__() + self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True) + + def forward(self, tgt_reps, labels, answer): + tgt_label_reps = [] + batch_speech_tokens = [] + embeddings = [] + target_embeddings = [] + if DISTILL_EMBEDDING: + for tgt_rep, label, ans in zip(tgt_reps, labels, answer): + # make all label id in [151644,151645,198] to IGNORE_INDEX + label[label == 151644] = IGNORE_INDEX + label[label == 151645] = IGNORE_INDEX + label[label == 198] = IGNORE_INDEX + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True) + tts_text_token_all = [] + for norm_text in normalized_text: + tts_text_token, tts_text_token_len = self.cosyvoice1.frontend._extract_text_token(norm_text) + tts_text_token_all.append(tts_text_token) + tts_text_token_all = torch.cat(tts_text_token_all, dim=0) + target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token) + target_embeddings.append(target_embedding) + import pdb;pdb.set_trace() + tgt_label_reps = torch.stack(tgt_label_reps) + target_embeddings = torch.stack(target_embeddings).squeeze(1) + hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512) + target_embeddings = target_embeddings.reshape(-1, 512) + loss = self.criterion(hidden_states, target_embeddings, torch.ones(hidden_states.size(0)).to(hidden_states.device)) + else: + for tgt_rep, label, ans in zip(tgt_reps, labels, answer): + # make all label id in [151644,151645,198] to IGNORE_INDEX + label[label == 151644] = IGNORE_INDEX + label[label == 151645] = IGNORE_INDEX + label[label == 198] = IGNORE_INDEX + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + speech_token = self.cosyvoice.inference_label(ans, 'θ‹±ζ–‡ε₯³', stream=False) + speech_tokens = [] + for i,j in enumerate(speech_token): + speech_tokens.append(j['tts_speech_token'].squeeze(0)) + speech_tokens.append(torch.tensor([0])) + speech_tokens = torch.cat(speech_tokens, dim=0) + if speech_tokens.size(0) > 1: + speech_tokens = speech_tokens[:-1] + batch_speech_tokens.append(speech_tokens) + embedding = self.cosyvoice.frontend.frontend_embedding('θ‹±ζ–‡ε₯³') + embeddings.append(embedding['llm_embedding'].squeeze(0)) + + tgt_label_reps = torch.stack(tgt_label_reps) + batch_speech_token = torch.stack(batch_speech_tokens) + embeddings = torch.stack(embeddings) + hidden_states = self.input_proj(tgt_label_reps) + batch = {'text_feature': hidden_states, 'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(hidden_states.size(0)), + 'speech_token': batch_speech_token, 'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(hidden_states.size(0)), + 'embedding': embeddings} + output = self.llm.forward_ours(batch, 'cuda') + loss = output['loss'] + return loss \ No newline at end of file diff --git a/ola/model/speech_projector/__pycache__/builder.cpython-310.pyc 
b/ola/model/speech_projector/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103404d74996ff02bd22ca9e37daeb8624916cc3 Binary files /dev/null and b/ola/model/speech_projector/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/speech_projector/__pycache__/builder.cpython-38.pyc b/ola/model/speech_projector/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f691ff24734044c18b82d18e506eab817644be9f Binary files /dev/null and b/ola/model/speech_projector/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc b/ola/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d775faac13531f2c656f872c48149be7772ae388 Binary files /dev/null and b/ola/model/speech_projector/__pycache__/speech_projector.cpython-310.pyc differ diff --git a/ola/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc b/ola/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66bc5faab4c42b9be28ae1b387f39674686807a3 Binary files /dev/null and b/ola/model/speech_projector/__pycache__/speech_projector.cpython-38.pyc differ diff --git a/ola/model/speech_projector/builder.py b/ola/model/speech_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5959e229de1e43544dae3ae3a333693ba3e824 --- /dev/null +++ b/ola/model/speech_projector/builder.py @@ -0,0 +1,9 @@ +from .speech_projector import EncoderProjectorConcat + + +def build_speech_projector(config): + projector_type = getattr(config, 'speech_projector_type', 'linear') + if projector_type == 'linear': + return EncoderProjectorConcat(config) + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/ola/model/speech_projector/speech_projector.py b/ola/model/speech_projector/speech_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0156065bc238807c57918aeeed48df575ff533 --- /dev/null +++ b/ola/model/speech_projector/speech_projector.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import math + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.speech_encoder_ds_rate + self.encoder_dim = config.speech_encoder_hidden_size + self.llm_dim = config.hidden_size + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.hidden_size) + + embed_std = 1 / math.sqrt(config.hidden_size) + self.speech_newline = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_begin = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_end = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + x = torch.cat([ + x, + self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype) + ], dim=1) + begin = self.speech_begin.reshape(1, -1).to(x.dtype) + end = self.speech_end.reshape(1, -1).to(x.dtype) + x = 
x.flatten(0, 1) + x = torch.cat([begin, x, end], dim=0) + # x = x.flatten(0, 1) + return x \ No newline at end of file diff --git a/ola/serve_ola/__init__.py b/ola/serve_ola/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/serve_ola/cli.py b/ola/serve_ola/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..86b31b6da7e8d56973ffb77b86e48b3c91cd6528 --- /dev/null +++ b/ola/serve_ola/cli.py @@ -0,0 +1,123 @@ +import argparse +import torch + +import sys +sys.path.append('/mnt/lzy/Ola') + +from ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.utils import disable_torch_init +from ola.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + + +def load_image(image_file): + if image_file.startswith('http') or image_file.startswith('https'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) + + # if 'llama-2' in model_name.lower(): + # conv_mode = "llava_llama_2" + # elif "v1" in model_name.lower(): + # conv_mode = "llava_v1" + # elif "mpt" in model_name.lower(): + # conv_mode = "mpt" + # else: + # conv_mode = "llava_v0" + conv_mode = "qwen_1_5" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + if "mpt" in model_name.lower(): + roles = ('user', 'assistant') + else: + roles = conv.roles + + image = load_image(args.image_file) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() + + while True: + try: + inp = input(f"{roles[0]}: ") + except EOFError: + inp = "" + if not inp: + print("exit...") + break + + print(f"{roles[1]}: ", end="") + + if image is not None: + # first message + if model.config.mm_use_im_start_end: + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp + else: + inp = DEFAULT_IMAGE_TOKEN + '\n' + inp + conv.append_message(conv.roles[0], inp) + image = None + else: + # later messages + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor, + do_sample=True, + temperature=0.2, + 
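+                # NOTE: temperature and max_new_tokens are hard-coded here even though
+                # --temperature and --max-new-tokens are parsed below; passing
+                # args.temperature / args.max_new_tokens instead would honor the CLI flags.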
max_new_tokens=1024, + streamer=streamer, + use_cache=True, + stopping_criteria=[stopping_criteria]) + + outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + conv.messages[-1][-1] = outputs + + if args.debug: + print("\n", {"prompt": prompt, "outputs": outputs}, "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-file", type=str, required=True) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + main(args) diff --git a/ola/serve_ola/controller.py b/ola/serve_ola/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..73d6b3c36d2a4452691f8ab53e14613018560f61 --- /dev/null +++ b/ola/serve_ola/controller.py @@ -0,0 +1,314 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import os, sys +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['REGIONAL_POOL'] = '2x' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['SKIP_LOAD_VIT'] = '1' + + +sys.path.append('/mnt/lzy/Ola') + +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from ola.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from ola.utils import build_logger, server_error_msg + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,)) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, + worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = 
self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], + check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. 
{worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], + data.get("worker_status", None)) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat( + data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=12345) + parser.add_argument("--dispatch-method", type=str, choices=[ + "lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/ola/serve_ola/examples/extreme_ironing.jpg b/ola/serve_ola/examples/extreme_ironing.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..638b078837f175039b2db49a63821288d9681daa Binary files /dev/null and b/ola/serve_ola/examples/extreme_ironing.jpg differ diff --git a/ola/serve_ola/examples/waterview.jpg b/ola/serve_ola/examples/waterview.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f44ebaba1aa493b8bab3baa4e827b76752b1869 Binary files /dev/null and b/ola/serve_ola/examples/waterview.jpg differ diff --git a/ola/serve_ola/gradio_web_server.py b/ola/serve_ola/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..66956db1475d2edbd191cff3561c21e37d5f0851 --- /dev/null +++ b/ola/serve_ola/gradio_web_server.py @@ -0,0 +1,516 @@ +import os, sys +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['REGIONAL_POOL'] = '2x' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['SKIP_LOAD_VIT'] = '1' + +sys.path.append('/mnt/lzy/Ola') + + + +import argparse +import datetime +import json +import os +import time + +import gradio as gr +import requests + +from ola.conversation import (default_conversation, conv_templates,SeparatorStyle) +from ola.constants import LOGDIR +from ola.utils import (build_logger, server_error_msg, + violates_moderation, moderation_msg) +import hashlib + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "Oryx"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "vicuna-13b": "aaaaaaa", + "koala-13b": "aaaaaab", +} + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(): + # ret = requests.post(args.controller_url + "/refresh_all_workers") + # assert ret.status_code == 200 + # ret = requests.post(args.controller_url + "/list_models") + # models = ret.json()["models"] + # models.sort(key=lambda x: priority.get(x, x)) + models = [ + 'ola-7b' + ] + logger.info(f"Models: {models}") + return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update( + value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + + +def load_demo_refresh_model_list(request: gr.Request): + logger.info(f"load_demo. 
ip: {request.client.host}") + models = get_model_list() + state = default_conversation.copy() + dropdown_update = gr.Dropdown.update( + choices=models, + value=models[0] if len(models) > 0 else "" + ) + return state, dropdown_update + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def add_text(state, text, image, image_process_mode, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") + if len(text) <= 0 and image is None: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + if args.moderate: + flagged = violates_moderation(text) + if flagged: + state.skip_next = True + return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( + no_change_btn,) * 5 + + text = text[:1536] # Hard cut-off + if image is not None: + text = text[:1200] # Hard cut-off for images + if '' not in text: + # text = '' + text + text = text + '\n' + text = (text, image, image_process_mode) + if len(state.get_images(return_pil=True)) > 0: + state = default_conversation.copy() + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. 
ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if len(state.messages) == state.offset + 2: + # First round of conversation + # if "llava" in model_name.lower(): + # if 'llama-2' in model_name.lower(): + # template_name = "llava_llama_2" + # elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + # if 'orca' in model_name.lower(): + # template_name = "mistral_orca" + # elif 'hermes' in model_name.lower(): + # template_name = "mistral_direct" + # else: + # template_name = "mistral_instruct" + # elif "zephyr" in model_name.lower(): + # template_name = "mistral_zephyr" + # elif 'hermes' in model_name.lower(): + # template_name = "mistral_direct" + # elif "v1" in model_name.lower(): + # if 'mmtag' in model_name.lower(): + # template_name = "v1_mmtag" + # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): + # template_name = "v1_mmtag" + # else: + # template_name = "llava_v1" + # elif "mpt" in model_name.lower(): + # template_name = "mpt" + # else: + # if 'mmtag' in model_name.lower(): + # template_name = "v0_mmtag" + # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): + # template_name = "v0_mmtag" + # else: + # template_name = "llava_v0" + # elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + # if 'orca' in model_name.lower(): + # template_name = "mistral_orca" + # elif 'hermes' in model_name.lower(): + # template_name = "mistral_direct" + # else: + # template_name = "mistral_instruct" + # elif 'hermes' in model_name.lower(): + # template_name = "mistral_direct" + # elif "zephyr" in model_name.lower(): + # template_name = "mistral_zephyr" + # elif "mpt" in model_name: + # template_name = "mpt_text" + # elif "llama-2" in model_name: + # template_name = "llama_2" + # else: + # template_name = "vicuna_v1" + template_name = 'qwen_1_5' + new_state = conv_templates[template_name].copy() + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + # Query worker address + # controller_url = args.controller_url + # ret = requests.post(controller_url + "/get_worker_address", + # json={"model": model_name}) + # worker_addr = ret.json()["address"] + # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # # No available worker + # if worker_addr == "": + # state.messages[-1][-1] = server_error_msg + # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + # return + + worker_addr = args.controller_url + + # Construct prompt + prompt = state.get_prompt() + + all_images = state.get_images(return_pil=True) + all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] + for image, hash in zip(all_images, all_image_hash): + t = datetime.datetime.now() + filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + image.save(filename) + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else 
state.sep2, + "images": f'List of {len(state.get_images())} images: {all_image_hash}', + } + logger.info(f"==== request ====\n{pload}") + + pload['images'] = state.get_images() + + state.messages[-1][-1] = "β–Œ" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + # Stream output + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=100) + last_print_time = time.time() + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt):].strip() + state.messages[-1][-1] = output + "β–Œ" + if time.time() - last_print_time > 0.05: + last_print_time = time.time() + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "images": all_image_hash, + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +title_markdown = (""" +# Oryx 7B Chatbot +""") + +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""") + + +learn_more_markdown = (""" +### License +# """) + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +""" + +def build_demo(embed_mode): + logger.info(f"build_demo. 
embed_mode: {embed_mode}") + textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) + with gr.Blocks(title="Oryx", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + if not embed_mode: + gr.Markdown(title_markdown) + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False) + + imagebox = gr.Image(type="pil") + image_process_mode = gr.Radio( + ["Crop", "Resize", "Pad", "Default"], + value="Default", + label="Preprocess for non-square image", visible=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples(examples=[ + [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], + [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], + ], inputs=[imagebox, textbox]) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=8): + chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Send", variant="primary") + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False) + downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False) + clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + queue=False + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + queue=False + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + queue=False + ) + + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list + ) + + clear_btn.click( + clear_history, + None, + [state, chatbot, textbox, imagebox] + btn_list, + queue=False + ) + + textbox.submit( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list + ) + + submit_btn.click( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox] + 
btn_list, + queue=False + ).then( + http_bot, + [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list + ) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [state, model_selector], + _js=get_window_url_params, + queue=False + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_refresh_model_list, + None, + [state, model_selector], + queue=False + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21002") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--model-list-mode", type=str, default="once", + choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue( + concurrency_count=args.concurrency_count, + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/ola/serve_ola/model_worker.py b/ola/serve_ola/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..73fd69ead1a66c03dfb52ecb7fb33ce4ee3c54fb --- /dev/null +++ b/ola/serve_ola/model_worker.py @@ -0,0 +1,320 @@ +""" +A model worker executes the model. +""" +import os, sys +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['REGIONAL_POOL'] = '2x' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['SKIP_LOAD_VIT'] = '1' + +sys.path.append('/mnt/lzy/Ola') + +import argparse +import asyncio +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +from functools import partial + +from ola.constants import WORKER_HEART_BEAT_INTERVAL +from ola.utils import (build_logger, server_error_msg, + pretty_print_semaphore) +from ola.model.builder import load_pretrained_model +from ola.mm_utils import process_anyres_highres_image_genli, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria +from ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from transformers import TextIteratorStreamer +from threading import Thread + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, + worker_id, no_register, + model_path, model_base, model_name, + load_8bit, load_4bit): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = 
worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( + model_path, None, self.model_name, load_8bit, load_4bit, device_map='cuda:0') + self.model = self.model.eval() + self.model = self.model.bfloat16() + self.is_multimodal = 'ola' in self.model_name.lower() + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200, f"Failed to register to controller: {r.text}" + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " + f"global_counter: {global_counter}") + + print('skip heart beat') + return + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + @torch.inference_mode() + def generate_stream(self, params): + tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor + + prompt = params["prompt"] + ori_prompt = prompt + images = params.get("images", None) + num_image_tokens = 0 + if images is not None and len(images) > 0 and self.is_multimodal: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + logger.info(f"image_sizes: {image_sizes}") + image_tensor, image_highres_tensor = process_anyres_highres_image_genli(images, image_processor, model.config) + + if type(image_tensor) is list: + image_tensor = [image_.to(self.model.device, dtype=torch.bfloat16) for image_ in image_tensor] + else: + image_tensor = image_tensor.to(self.model.device, dtype=torch.bfloat16) + + if type(image_highres_tensor) is list: + image_highres_tensor = [image_.to(self.model.device, dtype=torch.bfloat16) for image_ in image_highres_tensor] + else: + image_highres_tensor = image_highres_tensor.to(self.model.device, dtype=torch.bfloat16) + + replace_token = DEFAULT_IMAGE_TOKEN + if getattr(self.model.config, 
'mm_use_im_start_end', False): + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + + # num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches + else: + images = None + image_sizes = None + image_args = {"images": images, "images_highres": image_highres_tensor, "image_sizes": image_sizes} + else: + images = None + image_args = {} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + stop_str = '<|im_end|>' if stop_str is None else stop_str + do_sample = True if temperature > 0.001 else False + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + + # max_new_tokens = 1024 # min(max_new_tokens, max_context_length - input_ids.shape[-1] - 576) + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + thread = Thread(target=model.generate, kwargs=dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + # stopping_criteria=[stopping_criteria], + use_cache=True, + modalities=['image'] + **image_args + )) + thread.start() + + start_time = time.time() + generated_text = ori_prompt + for new_text in streamer: + generated_text += new_text + if generated_text.endswith(stop_str): + generated_text = generated_text[:-len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + end_time = time.time() + + new_generated = generated_text[len(ori_prompt):] + new_generated_tokens = tokenizer(new_generated).input_ids + token_per_second = len(new_generated_tokens) / (end_time - start_time) + print(f"token_per_second: {token_per_second}") + + def generate_stream_gate(self, params): + # try: + for x in self.generate_stream(params): + yield x + # except ValueError as e: + # print("Caught ValueError:", e) + # ret = { + # "text": server_error_msg, + # "error_code": 1, + # } + # yield json.dumps(ret).encode() + b"\0" + # except torch.cuda.CudaError as e: + # print("Caught torch.cuda.CudaError:", e) + # ret = { + # "text": server_error_msg, + # "error_code": 1, + # } + # yield json.dumps(ret).encode() + b"\0" + # except Exception as e: + # print("Caught Unknown Error", e) + # ret = { + # "text": server_error_msg, + # "error_code": 1, + # } + # yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + 
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://0.0.0.0:21002") + parser.add_argument("--controller-address", type=str, + default="http://0.0.0.0:12345") + parser.add_argument("--model-path", type=str, default="/mnt/lzy/ola-model/ola-7b") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--model-name", type=str) + parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.multi_modal: + logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + + worker = ModelWorker(args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_base, + args.model_name, + args.load_8bit, + args.load_4bit) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/ola/serve_ola/register_worker.py b/ola/serve_ola/register_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2 --- /dev/null +++ b/ola/serve_ola/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/ola/serve_ola/sglang_worker.py b/ola/serve_ola/sglang_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..703ee94939f19afa984f8cff74604f112aca63cd --- /dev/null +++ b/ola/serve_ola/sglang_worker.py @@ -0,0 +1,255 @@ +""" +A model worker executes the model. 
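+Unlike model_worker.py, this worker does not load weights locally: it forwards prompts to an SGLang runtime endpoint and streams the generated text back.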
+""" +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import re +import uvicorn +from functools import partial + +from llava.constants import WORKER_HEART_BEAT_INTERVAL +from llava.utils import (build_logger, server_error_msg, + pretty_print_semaphore) +from llava.model.builder import load_pretrained_model +from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from transformers import AutoTokenizer + +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.utils import read_jsonl, dump_state_text +from sglang.lang.interpreter import ProgramState + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +@sgl.function +def pipeline(s, prompt, max_tokens): + for p in prompt: + if type(p) is str: + s += p + else: + s += sgl.image(p) + s += sgl.gen("response", max_tokens=max_tokens) + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, sgl_endpoint, + worker_id, no_register, model_name): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + + # Select backend + backend = RuntimeEndpoint(sgl_endpoint) + sgl.set_default_backend(backend) + model_path = backend.model_info["model_path"] + + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...") + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" + f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + async def generate_stream(self, params): + ori_prompt = prompt = params["prompt"] + images = params.get("images", None) + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + images = [load_image_from_base64(image) for image in images] + # FIXME: hacky padding + images = [expand2square(image, tuple(int(x*255) for x in [0.48145466, 0.4578275, 0.40821073])) for image in images] + + # FIXME: for image-start/end token + # replace_token = DEFAULT_IMAGE_TOKEN + # if getattr(self.model.config, 'mm_use_im_start_end', False): + # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN) + prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN) + prompt = [] + for i in range(len(prompt_split)): + prompt.append(prompt_split[i]) + if i < len(images): + prompt.append(images[i]) + else: + prompt = [prompt] + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + # max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + stop_str = [stop_str] if stop_str is not None else None + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + # print(prompt) + state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True) + + generated_text = ori_prompt + async for text_outputs in state.text_async_iter(var_name="response"): + generated_text += text_outputs + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + async def generate_stream_gate(self, params): + try: + async for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, + default="http://localhost:21001") + parser.add_argument("--model-name", type=str) + parser.add_argument("--sgl-endpoint", type=str) + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + worker = ModelWorker(args.controller_address, + args.worker_address, + args.sgl_endpoint, + worker_id, + args.no_register, + args.model_name) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/ola/serve_ola/test_message.py b/ola/serve_ola/test_message.py new file mode 100644 index 0000000000000000000000000000000000000000..77f383895c21de905351bf625167c230409d9c0c --- /dev/null +++ b/ola/serve_ola/test_message.py @@ -0,0 +1,79 @@ +import os, sys +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['REGIONAL_POOL'] = '2x' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['SKIP_LOAD_VIT'] = '1' + +sys.path.append('/mnt/lzy/Ola') + +import argparse +import json + +import requests + +from llava.conversation import default_conversation, conv_templates + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = 
args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post(controller_addr + "/get_worker_address", + json={"model": args.model_name}) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + # conv = default_conversation.copy() + conv = conv_templates['v1_qwen2'].copy() + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + headers = {"User-Agent": "LLaVA Client"} + pload = { + "model": args.model_name, + "prompt": prompt, + "max_new_tokens": args.max_new_tokens, + "temperature": 0.7, + "stop": conv.sep, + } + response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, + json=pload, stream=True) + + print(prompt.replace(conv.sep, "\n"), end="") + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"].split(conv.sep)[-1] + print(output, end="\r") + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str, default="http://localhost:21001") + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument("--message", type=str, default= + "写一δΈͺ100ε­—ηš„η«₯话故事") + args = parser.parse_args() + + main() diff --git a/ola/utils.py b/ola/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36d9d0ca05846cd31f15a5e230db75ac9f03759e --- /dev/null +++ b/ola/utils.py @@ -0,0 +1,213 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import torch +import logging +import logging.handlers +import transformers + +from ola.constants import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
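+# Shared, lazily-created rotating file handler; build_logger() attaches it to every logger in the process.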
+ +handler = None + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True, encoding='UTF-8') + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. 
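+            # Only complete lines are logged; a trailing partial line stays buffered until the next write() or flush().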
+ if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. 
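+    Returns False whenever the request fails or the response cannot be parsed.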
+ """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..53d1c882d9c9a8b5d256939d86ad765683d2eb16 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +transformers==4.43.4 +accelerate==0.33.0 +wandb==0.17.4 +peft==0.11.1 +tokenizers==0.19.1 +sentencepiece==0.1.99 +shortuuid +pydantic +uvicorn +fastapi +soundfile +einops==0.6.1 +einops-exts==0.0.4 +timm==0.9.16 +openai-whisper +deepspeed==0.12.2 +loguru +av +librosa +gradio +urllib3==1.26.6 +moviepy +hyperpyyaml +onnxruntime +inflect +pynini==2.1.5 +WeTextProcessing==1.0.2 \ No newline at end of file