diff --git a/InputSans-Regular.ttf b/InputSans-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..787bf8e9ff7f43e8081a3502e5c3b892e1879913 Binary files /dev/null and b/InputSans-Regular.ttf differ diff --git a/app.py b/app.py index a699bc5b3c2e987102ca93e0ee28d601e0a93d02..810c4bf6f5a48fb2a46353470df31e0987110437 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,302 @@ +import argparse +import os +import random +# import sys +# import os +# +# BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# sys.path.append(BASE_DIR) + +import numpy as np +import torch +import torch.backends.cudnn as cudnn import gradio as gr -def greet(name): - return "Hello " + name + "!!" +from constants.constant import LIGHTER_COLOR_MAP_HEX +# NOTE: Must import LlamaTokenizer before `bubogpt.common.config` +# otherwise, it will cause seg fault when `llama_tokenizer.decode` is called + +from grounding_model import GroundingModule +from match import MatchModule +from bubogpt.common.config import Config +from bubogpt.common.dist_utils import get_rank +from bubogpt.common.registry import registry +from eval_scripts.conversation import Chat, CONV_X, DummyChat +# NOTE&TODO: put this before bubogpt import will cause circular import +# possibly because `imagebind` imports `bubogpt` and `bubogpt` also imports `imagebind` +from imagebind.models.image_bind import ModalityType +# from ner import NERModule +from tagging_model import TaggingModule + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Qualitative") + parser.add_argument("--cfg-path", help="path to configuration file.", deafult='./eval_configs/mmgpt4_eval.yaml') + parser.add_argument("--dummy", action="store_true", help="Debug Mode") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument("--ground-all", action="store_true") + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# ======================================== +# Model Initialization +# ======================================== + +print('Initializing Chat') +args = parse_args() + +assert args.dummy or (args.cfg_path is not None), "Invalid Config! Set --dummy or configurate the cfg_path!" + +if not args.dummy: + cfg = Config(args) + + # Create processors + vis_processor_cfg = cfg.datasets_cfg.default.vis_processor.eval + aud_processor_cfg = cfg.datasets_cfg.default.audio_processor.eval + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + aud_processor = registry.get_processor_class(aud_processor_cfg.name).from_config(aud_processor_cfg) + processors = {ModalityType.VISION: vis_processor, ModalityType.AUDIO: aud_processor} + + # Create model + model_config = cfg.model_cfg + model_config.device_8bit = args.gpu_id + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id)) +else: + model = None + chat = DummyChat() + +match = MatchModule(model='gpt-4') +tagging_module = TaggingModule(device='cuda:{}'.format(args.gpu_id)) +grounding_dino = GroundingModule(device='cuda:{}'.format(args.gpu_id)) +print('Initialization Finished') + + +# ======================================== +# Gradio Setting +# ======================================== + +def gradio_reset(chat_state, emb_list): + if chat_state is not None: + chat_state.messages = [] + if emb_list is not None: + emb_list = [] + return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=False), \ + gr.update(value=None, interactive=True), \ + gr.update(placeholder='Please upload your image/audio first', interactive=False), \ + gr.update(value=None), \ + gr.update(value="Upload & Start Chat", interactive=True), \ + chat_state, emb_list, gr.update(value={}) + + +def upload_x(gr_img, gr_aud, chat_state): + if gr_img is None and gr_aud is None: + return None, None, None, gr.update(interactive=True), chat_state, None, {} + chat_state = CONV_X.copy() + emb_list = [] + if gr_img is not None: + chat.upload_img(gr_img, chat_state, emb_list) + state = { + 'tags': tagging_module(gr_img) + } + # print(state) + else: + state = {} + if gr_aud is not None: + chat.upload_aud(gr_aud, chat_state, emb_list) + return gr.update(interactive=False), gr.update(interactive=False), \ + gr.update(interactive=True, placeholder='Type and press Enter'), \ + gr.update(value="Start Chatting", interactive=False), \ + chat_state, emb_list, state + + +def gradio_ask(user_message, chatbot, chat_state, text_output, last_answer): + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state, \ + gr.update(value=None, color_map=None, show_legend=False), gr.update(value=None) + if last_answer is not None: + chatbot[-1][1] = last_answer + chat.ask(user_message, chat_state) + if text_output is not None: + os.makedirs('results', exist_ok=True) + # print("****** Text output is:", text_output) + chatbot[-1][1] = ''.join(map(lambda x: x[0], text_output)) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state, gr.update(value=None, color_map=None, show_legend=False), gr.update(value=None) + + +def gradio_answer(image, chatbot, chat_state, emb_list, num_beams, temperature, entity_state): + llm_message = chat.answer(conversation=chat_state, + emb_list=emb_list, + num_beams=num_beams, + temperature=temperature, + max_new_tokens=300, + max_length=2000)[0] + if image is not None: + # new_entity_state = entity_state.value() + # new_entity_state.update({"answer": llm_message}) + entity_state["answer"] = llm_message + rich_text, match_state, color_map = match(llm_message, entity_state) + print("Original Color Map: ", color_map) + color_map = {key: LIGHTER_COLOR_MAP_HEX[color_map[key]] for key in color_map} + print("Modified Color Map: ", color_map) + chatbot[-1][1] = "The answer can be found in the textbox below and I'm trying my best to highlight the " \ + "corresponding region on the image." + # new_entity_state.update({"match_state": match_state}) + entity_state['match_state'] = match_state # item_id -> local_id + new_grounded_image = grounding_dino.draw(image, entity_state) + show_legend = bool(match_state) + print('gradio_answer ==> current state: ', entity_state) + + # if args.ground_all: + # ground_img, local_results = grounding_dino.prompt2mask(image, + # '.'.join(map(lambda x: x, state['entity'])), + # state=state) + # else: + # ground_img = None + return chatbot, chat_state, emb_list, \ + gr.update(value=rich_text, color_map=color_map, show_legend=show_legend), \ + gr.update(value=entity_state), \ + gr.update(value=llm_message), gr.update(value=new_grounded_image) + else: + chatbot[-1][1] = llm_message + return chatbot, chat_state, emb_list, \ + gr.update(value=None), \ + entity_state, \ + gr.update(value=None), gr.update(value=None) + +def grounding_fn(image, chatbot, entity_state): + # print("Grounding fn: ", entity_state) + if image and entity_state: + ground_img, local_results = grounding_dino.prompt2mask2( + image, ','.join(map(lambda x: x, entity_state['tags'])), state=entity_state + ) + entity_state['grounding'] = { + 'full': ground_img, + 'local': local_results + } + print('grounding_fn ==> current state: ', entity_state) + return chatbot, gr.update(value=ground_img, interactive=False), entity_state + return chatbot, gr.update(value=None, interactive=False), entity_state + + +def select_fn(image, ground_img, entity_state, evt: gr.SelectData): + if image is None: + return gr.update(value=None, interactive=False) + item, label = evt.value[0], evt.value[1] + + if label is None: + return ground_img + print('select_fn ==> current state: ', entity_state) + if 'grounding' not in entity_state: + ground_img, local_results = grounding_dino.prompt2mask2(image, + ','.join(map(lambda x: x[0], entity_state['tags'])), + state=entity_state) + entity_state['grounding'] = { + 'full': ground_img, + 'local': local_results + } + # local_img = entity_state['grounding']['local'][entity]['image'] + # print("DEBUG INFO: ", entity_state) + local_img = grounding_dino.draw(image, entity_state, item.lower()) + return gr.update(value=local_img, interactive=False) + + +title = """

Demo of BuboGPT

""" +description = """

This is the demo of BuboGPT. Upload and start chatting!

""" +# article = """

+# """ + +# TODO show examples below + +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + # gr.Markdown(article) + + with gr.Row(): + with gr.Column(scale=0.5): + image = gr.Image(type="pil") + grounded_image = gr.Image(type="pil", interactive=False) + audio = gr.Audio() + upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") + clear = gr.Button("Restart") + + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + label="beam search numbers", + ) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + with gr.Column(): + chat_state = gr.State() + last_answer = gr.State() + entity_state = gr.State(value={}) + emb_list = gr.State() + chatbot = gr.Chatbot(label='BindGPT-4') + text_output = gr.HighlightedText(value=None, label="Response", show_legend=False) + text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False) + + upload_button.click( + upload_x, [image, audio, chat_state], + [image, audio, text_input, upload_button, chat_state, emb_list, entity_state]).then( + grounding_fn, + [image, chatbot, entity_state], + [chatbot, grounded_image, entity_state] + ) + + text_input.submit(gradio_ask, + [text_input, chatbot, chat_state, text_output, last_answer], + [text_input, chatbot, chat_state, text_output, last_answer] + ).then( + gradio_answer, + [image, chatbot, chat_state, emb_list, num_beams, temperature, entity_state], + [chatbot, chat_state, emb_list, text_output, entity_state, last_answer, grounded_image] + ) + + clear.click(gradio_reset, + [chat_state, emb_list], + [chatbot, image, grounded_image, audio, text_input, text_output, + upload_button, chat_state, emb_list, entity_state], + queue=False) + + text_output.select( + select_fn, + [image, grounded_image, entity_state], + [grounded_image] + ) -iface = gr.Interface(fn=greet, inputs="text", outputs="text") -iface.launch() \ No newline at end of file +demo.launch(enable_queue=True) diff --git a/bubogpt/__init__.py b/bubogpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89bd10e5a9d5ae51fceeddcae87321b538766ad1 --- /dev/null +++ b/bubogpt/__init__.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import sys + +from omegaconf import OmegaConf + +from bubogpt.common.registry import registry + +from bubogpt.datasets.builders import * +from bubogpt.models import * +from bubogpt.processors import * +from bubogpt.tasks import * + + +root_dir = os.path.dirname(os.path.abspath(__file__)) +default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) + +registry.register_path("library_root", root_dir) +repo_root = os.path.join(root_dir, "..") +registry.register_path("repo_root", repo_root) +cache_root = os.path.join(repo_root, default_cfg.env.cache_root) +registry.register_path("cache_root", cache_root) + +registry.register("MAX_INT", sys.maxsize) +registry.register("SPLIT_NAMES", ["train", "val", "test"]) diff --git a/bubogpt/common/__init__.py b/bubogpt/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/common/config.py b/bubogpt/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7db25b13a2414fc6e05205ea6e84eb3a1957b78b --- /dev/null +++ b/bubogpt/common/config.py @@ -0,0 +1,473 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from bubogpt.common.registry import registry + +# logging.info = print + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + if not config.run.evaluate: + dataset_config = self.build_dataset_config(config) + else: + dataset_config = OmegaConf.create({"datasets": config.datasets}) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." + + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hierarchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." + ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path) if dataset_config_path is not None else {}, + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. + + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. + """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add argumetns for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Do not support for epoch-based runner because how to define an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. Support 'cuda' or 'cpu' as for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["genearte", "rank"], + help="""Inference method to use for question answering. If rank, requires a answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/bubogpt/common/dist_utils.py b/bubogpt/common/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9280150bf5122d51bb810a9f0258a233e7088647 --- /dev/null +++ b/bubogpt/common/dist_utils.py @@ -0,0 +1,137 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/bubogpt/common/gradcam.py b/bubogpt/common/gradcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0 --- /dev/null +++ b/bubogpt/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/bubogpt/common/logger.py b/bubogpt/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf8ac8897cc53ccdf23778120c9d2c6566be58e --- /dev/null +++ b/bubogpt/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from bubogpt.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/bubogpt/common/optims.py b/bubogpt/common/optims.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbf89c36073f22f444432d0e603a23ea88e3b3d --- /dev/null +++ b/bubogpt/common/optims.py @@ -0,0 +1,119 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from bubogpt.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + iters_per_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.iters_per_epoch = iters_per_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + total_cur_step = cur_epoch * self.iters_per_epoch + cur_step + if total_cur_step < self.warmup_steps: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + epoch=total_cur_step, + optimizer=self.optimizer, + max_epoch=self.max_epoch * self.iters_per_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/bubogpt/common/registry.py b/bubogpt/common/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..564d59ca58bb12b3f863513de90ec6d90fabba34 --- /dev/null +++ b/bubogpt/common/registry.py @@ -0,0 +1,333 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. + + Usage: + + from bubogpt.common.registry import registry + from bubogpt.datasets.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + # TODO: merge them or split builders by modality + from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder + from bubogpt.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder + from bubogpt.datasets.builders.multimodal_base_dataset_builder import MultimodalBaseDatasetBuilder + + assert issubclass( + builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder, MultimodalBaseDatasetBuilder) + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + + def wrap(task_cls): + from bubogpt.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + + def wrap(model_cls): + from bubogpt.models import BaseModel + + assert issubclass( + model_cls, BaseModel + ), "All models must inherit BaseModel class" + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + + def wrap(processor_cls): + from bubogpt.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from bubogpt.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from bubogpt.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. + default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. + Usage:: + + from mmf.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/bubogpt/common/utils.py b/bubogpt/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8045339b215127f66bf78e226faac07fe496964a --- /dev/null +++ b/bubogpt/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from bubogpt.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M")[:-1] + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no indirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading an URL from google drive requires confirmation when + the file of the size is too big (google drive notifies that + anti-viral checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: + from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. + """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. + """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. + """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/bubogpt/configs/datasets/aud_img_neg/default.yaml b/bubogpt/configs/datasets/aud_img_neg/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08efb6906650a8df1e641f66d0616d59d5dc05b3 --- /dev/null +++ b/bubogpt/configs/datasets/aud_img_neg/default.yaml @@ -0,0 +1,10 @@ +datasets: + aud_img_neg: + data_type: audio_image + build_info: + image: + storage: /path/to/cc_sbu_align + ann_files: ['filter_cap.json'] + audio: + storage: /path/to/clotho + ann_files: ['audio_cap.json'] diff --git a/bubogpt/configs/datasets/audioset/defaults.yaml b/bubogpt/configs/datasets/audioset/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8257e338991f1118343cfba04ed8b3dd2051167c --- /dev/null +++ b/bubogpt/configs/datasets/audioset/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + audioset: + data_type: audio + build_info: + storage: /path/to/AudioSet_SL/AudioSet_SL{00..54}.tar diff --git a/bubogpt/configs/datasets/bbc/defaults.yaml b/bubogpt/configs/datasets/bbc/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3bc330d8c091b69df1deb8a3daecfc6c699d30a --- /dev/null +++ b/bubogpt/configs/datasets/bbc/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + bbc: + data_type: audio + build_info: + storage: /path/to/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar diff --git a/bubogpt/configs/datasets/cc12m/defaults.yaml b/bubogpt/configs/datasets/cc12m/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..889e111324e3e3e25a733be3be23518190347a92 --- /dev/null +++ b/bubogpt/configs/datasets/cc12m/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + cc12m: + data_type: images + build_info: + storage: /path/to/cc12m_web/{000000..002221}.tar diff --git a/bubogpt/configs/datasets/cc_sbu/align.yaml b/bubogpt/configs/datasets/cc_sbu/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eee6bf7d357ef8388f808bc69f5ede284f3d3135 --- /dev/null +++ b/bubogpt/configs/datasets/cc_sbu/align.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu_align: + data_type: images + build_info: + storage: /path/to/cc_sbu_align diff --git a/bubogpt/configs/datasets/cc_sbu/defaults.yaml b/bubogpt/configs/datasets/cc_sbu/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60390eece551fe06a0f7c3ebb395351794b9f5f1 --- /dev/null +++ b/bubogpt/configs/datasets/cc_sbu/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + cc_sbu: + data_type: images + build_info: + storage: /path/to/cc_sbu_dataset/{00000..01255}.tar diff --git a/bubogpt/configs/datasets/clotho/align.yaml b/bubogpt/configs/datasets/clotho/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc06329fb39239e27ed650ac499020b1c59fde72 --- /dev/null +++ b/bubogpt/configs/datasets/clotho/align.yaml @@ -0,0 +1,5 @@ +datasets: + clotho_align: + data_type: audio + build_info: + storage: /path/to/clotho diff --git a/bubogpt/configs/datasets/freesound/defaults.yaml b/bubogpt/configs/datasets/freesound/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad8dab5b2d4205965ebbe2b758e17444be51c79b --- /dev/null +++ b/bubogpt/configs/datasets/freesound/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + freesound: + data_type: audio + build_info: + storage: /path/to/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar diff --git a/bubogpt/configs/datasets/laion/defaults.yaml b/bubogpt/configs/datasets/laion/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bad62901619c0a9e34619a400290f3e18083899 --- /dev/null +++ b/bubogpt/configs/datasets/laion/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + laion: + data_type: images + build_info: + storage: /path/to/laion_dataset/{00000..10488}.tar diff --git a/bubogpt/configs/datasets/soundbible/defaults.yaml b/bubogpt/configs/datasets/soundbible/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..addf9ddf1ee21d8a71ab71fcd355515c7635f229 --- /dev/null +++ b/bubogpt/configs/datasets/soundbible/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + soundbible: + data_type: audio + build_info: + storage: /path/to/SoundBible0.tar diff --git a/bubogpt/configs/datasets/vggss/align.yaml b/bubogpt/configs/datasets/vggss/align.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fc913e58c22542d22dd9a5baf25de7ceb0ed053 --- /dev/null +++ b/bubogpt/configs/datasets/vggss/align.yaml @@ -0,0 +1,6 @@ +datasets: + vggss_align: + data_type: audio_image + build_info: + storage: /path/to/vggss + ann_files: ["vggss_mult_prefix.json"] diff --git a/bubogpt/configs/default.yaml b/bubogpt/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d645ad6f6c10e473674704fc6ab2b1f668f4e7ff --- /dev/null +++ b/bubogpt/configs/default.yaml @@ -0,0 +1,5 @@ +env: + # For default users + # cache_root: "cache" + # For internal use with persistent storage + cache_root: "/export/home/.cache/bubogpt" diff --git a/bubogpt/configs/models/mmgpt4.yaml b/bubogpt/configs/models/mmgpt4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb8aabe19ce88f021d11d42e300feb61486617d8 --- /dev/null +++ b/bubogpt/configs/models/mmgpt4.yaml @@ -0,0 +1,30 @@ +model: + arch: mm_gpt4 + + # Imagebind + freeze_imagebind: True + + # Q-Former + freeze_qformer: True + q_former_model: "/path/to/blip2_pretrained_flant5xxl.pth" + num_query_token: 32 + + # Vicuna + llama_model: "/path/to/vicuna-7b-v0/" + + # generation configs + prompt: "" + +preprocess: + vis_processor: + train: + name: "imagebind_vision_train" + image_size: 224 + eval: + name: "imagebind_vision_eval" + image_size: 224 + text_processor: + train: + name: "imagebind_caption" + eval: + name: "imagebind_caption" diff --git a/bubogpt/datasets/__init__.py b/bubogpt/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/datasets/builders/__init__.py b/bubogpt/datasets/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85fa4fcac5b01564efbefc4f3c8f199405bd708 --- /dev/null +++ b/bubogpt/datasets/builders/__init__.py @@ -0,0 +1,90 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from bubogpt.datasets.builders.image_base_dataset_builder import load_dataset_config +from bubogpt.datasets.builders.image_text_pair_builder import ( + CCSBUBuilderImage, + LaionBuilderImage, + CCSBUAlignBuilderImage, + LlavaInstruct150Builder, +) +from bubogpt.datasets.builders.audio_text_pair_builder import ( + BBCBuilder, + AudioSetBuilder, + SoundBibleBuilder, + FreeSoundBuilder +) +from bubogpt.datasets.builders.audio_image_text_builder import ( + VGGSSBuilderAudioImage +) +from bubogpt.common.registry import registry + +__all__ = [ + "CCSBUBuilderImage", + "LaionBuilderImage", + "CCSBUAlignBuilderImage", + "LlavaInstruct150Builder", + # Audio builders + "BBCBuilder", + "AudioSetBuilder", + "SoundBibleBuilder", + "FreeSoundBuilder", + # Audio Image builders + "VGGSSBuilderAudioImage" +] + + +def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): + """ + Example + + >>> dataset = load_dataset("coco_caption", cfg=None) + >>> splits = dataset.keys() + >>> print([len(dataset[split]) for split in splits]) + + """ + if cfg_path is None: + cfg = None + else: + cfg = load_dataset_config(cfg_path) + + try: + builder = registry.get_builder_class(name)(cfg) + except TypeError: + print( + f"Dataset {name} not found. Available datasets:\n" + + ", ".join([str(k) for k in dataset_zoo.get_names()]) + ) + exit(1) + + if vis_path is not None: + if data_type is None: + # use default data type in the config + data_type = builder.config.data_type + + assert ( + data_type in builder.config.build_info + ), f"Invalid data_type {data_type} for {name}." + + builder.config.build_info.get(data_type).storage = vis_path + + dataset = builder.build_datasets() + return dataset + + +class DatasetZoo: + def __init__(self) -> None: + self.dataset_zoo = { + k: list(v.DATASET_CONFIG_DICT.keys()) + for k, v in sorted(registry.mapping["builder_name_mapping"].items()) + } + + def get_names(self): + return list(self.dataset_zoo.keys()) + + +dataset_zoo = DatasetZoo() diff --git a/bubogpt/datasets/builders/audio_base_dataset_builder.py b/bubogpt/datasets/builders/audio_base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..584fd7c2995305b2bfecdc4250cadebdd7fe020c --- /dev/null +++ b/bubogpt/datasets/builders/audio_base_dataset_builder.py @@ -0,0 +1,142 @@ +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import bubogpt.common.utils as utils +from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from bubogpt.common.registry import registry +from bubogpt.datasets.builders import load_dataset_config +from bubogpt.processors.base_processor import BaseProcessor + + +class AudioBaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + aud_proc_cfg = self.config.get("audio_processor") + txt_proc_cfg = self.config.get("text_processor") + + if aud_proc_cfg is not None: + aud_train_cfg = aud_proc_cfg.get("train") + aud_eval_cfg = aud_proc_cfg.get("eval") + + self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg) + self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_aud() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the audio-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) diff --git a/bubogpt/datasets/builders/audio_image_text_builder.py b/bubogpt/datasets/builders/audio_image_text_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..246913cfa0c61e27c58722353316f37a049d2490 --- /dev/null +++ b/bubogpt/datasets/builders/audio_image_text_builder.py @@ -0,0 +1,105 @@ +import logging +import os +import warnings + +from bubogpt.common.registry import registry +from bubogpt.datasets.builders.multimodal_base_dataset_builder import MultimodalBaseDatasetBuilder +from bubogpt.datasets.datasets.audio_image.audio_image_datasets import AudioLocalizationDataset, AudioImageNegDataset + + +@registry.register_builder("vggss_align") +class VGGSSBuilderAudioImage(MultimodalBaseDatasetBuilder): + train_dataset_cls = AudioLocalizationDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/vggss/align.yaml", + "3k": "configs/datasets/vggss/align3k.yaml", + "5k": "configs/datasets/vggss/align5k.yaml", + "31k": "configs/datasets/vggss/align31k.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + storage_path = build_info.storage + + datasets = dict() + + if not os.path.exists(storage_path): + warnings.warn("storage path {} does not exist.".format(storage_path)) + print("Building datasets with: ", self.get_ann_files()) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + processors={**{ + modal: self.processors[modal]["train"] for modal in self.data_type + }, **{ + "text": self.processors["text"]["train"] + }}, + roots={ + modal: os.path.join(storage_path, f"{modal}s") for modal in self.data_type + }, + # ann_paths=[os.path.join(storage_path, 'vggsound_balanced.json')], + ann_paths=self.get_ann_files(), + ) + + return datasets + + def get_ann_files(self): + ann_files = self.config.build_info.get("ann_files", ["vggsound_balanced.json"]) + return [os.path.join(self.config.build_info.storage, fname) for fname in ann_files] + + +@registry.register_builder("aud_img_neg") +class NegBuilderAudioImage(MultimodalBaseDatasetBuilder): + train_dataset_cls = AudioImageNegDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/aud_img_neg/default.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + # storage_path = build_info.storage + storage_path = { + "image": build_info.image.storage, + "audio": build_info.audio.storage, + } + ann_files = { + "image": build_info.image.ann_files, + "audio": build_info.audio.ann_files, + } + ann_paths = { + modal: [os.path.join(storage_path[modal], fname) for fname in ann_files[modal]] for modal in self.data_type + } + + datasets = dict() + + for path in storage_path.values(): + if not os.path.exists(path): + warnings.warn("storage path {} does not exist.".format(path)) + print("Building datasets with: ", ann_paths) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + processors={**{ + modal: self.processors[modal]["train"] for modal in self.data_type + }, **{ + "text": self.processors["text"]["train"] + }}, + roots={ + modal: os.path.join(storage_path[modal], f"{modal}") for modal in self.data_type + }, + ann_paths=ann_paths, + ) + + return datasets diff --git a/bubogpt/datasets/builders/audio_text_pair_builder.py b/bubogpt/datasets/builders/audio_text_pair_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbf4e53dd2c4f1ce7f1e2f361ccb9b4a9200078 --- /dev/null +++ b/bubogpt/datasets/builders/audio_text_pair_builder.py @@ -0,0 +1,88 @@ +import os +import logging +import warnings + +from bubogpt.common.registry import registry +from bubogpt.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder +from bubogpt.datasets.datasets.audio_caption import GenericAudioDataset, AudioCaptionDataset + + +class GenericAudioBuilder(AudioBaseDatasetBuilder): + train_dataset_cls = GenericAudioDataset + + def _download_ann(self): + pass + + def _download_aud(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + audio_processor=self.audio_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("bbc") +class BBCBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"} + + +@registry.register_builder("audioset") +class AudioSetBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"} + + +@registry.register_builder("soundbible") +class SoundBibleBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"} + + +@registry.register_builder("freesound") +class FreeSoundBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"} + + +@registry.register_builder("clotho_align") +class ClothoAlignBuilderAudio(GenericAudioBuilder): + train_dataset_cls = AudioCaptionDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/clotho/align.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + storage_path = build_info.storage + + datasets = dict() + + if not os.path.exists(storage_path): + warnings.warn("storage path {} does not exist.".format(storage_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + audio_processor=self.audio_processors["train"], + text_processor=self.text_processors["train"], + audio_root=os.path.join(storage_path, 'all'), + ann_paths=[os.path.join(storage_path, 'audio_cap.json')], + ) + + return datasets diff --git a/bubogpt/datasets/builders/image_base_dataset_builder.py b/bubogpt/datasets/builders/image_base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..942ff86bf9ca13af4efee94af5e8b0a481b25353 --- /dev/null +++ b/bubogpt/datasets/builders/image_base_dataset_builder.py @@ -0,0 +1,238 @@ +""" + This file is from + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import bubogpt.common.utils as utils +from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from bubogpt.common.registry import registry +from bubogpt.processors.base_processor import BaseProcessor + + +class ImageBaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + vis_proc_cfg = self.config.get("vis_processor") + txt_proc_cfg = self.config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + + self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) + self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + if cls.DATASET_CONFIG_DICT[type] is None: + return None + else: + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_vis() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the vision-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + def _download_vis(self): + + storage_path = self.config.build_info.get(self.data_type).storage + storage_path = utils.get_cache_path(storage_path) + + if not os.path.exists(storage_path): + warnings.warn( + f""" + The specified path {storage_path} for visual inputs does not exist. + Please provide a correct path to the visual inputs or + refer to datasets/download_scripts/README.md for downloading instructions. + """ + ) + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. + """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in ann_info.keys(): + if split not in ["train", "val", "test"]: + continue + + is_train = split == "train" + + # processors + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + text_processor = ( + self.text_processors["train"] + if is_train + else self.text_processors["eval"] + ) + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + abs_ann_paths = [] + for ann_path in ann_paths: + if not os.path.isabs(ann_path): + ann_path = utils.get_cache_path(ann_path) + abs_ann_paths.append(ann_path) + ann_paths = abs_ann_paths + + # visual data storage path + vis_path = os.path.join(vis_info.storage, split) + + if not os.path.isabs(vis_path): + # vis_path = os.path.join(utils.get_cache_path(), vis_path) + vis_path = utils.get_cache_path(vis_path) + + if not os.path.exists(vis_path): + warnings.warn("storage path {} does not exist.".format(vis_path)) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + text_processor=text_processor, + ann_paths=ann_paths, + vis_root=vis_path, + ) + + return datasets + + +def load_dataset_config(cfg_path): + cfg = OmegaConf.load(cfg_path).datasets + cfg = cfg[list(cfg.keys())[0]] + + return cfg diff --git a/bubogpt/datasets/builders/image_text_pair_builder.py b/bubogpt/datasets/builders/image_text_pair_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa404362f043aad9575aead930e7c72627a8123 --- /dev/null +++ b/bubogpt/datasets/builders/image_text_pair_builder.py @@ -0,0 +1,189 @@ +import os +import logging +import warnings + +from bubogpt.common.registry import registry +from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder +from bubogpt.datasets.datasets.image_caption.laion_dataset import LaionDataset +from bubogpt.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, \ + CCSBUAlignDatasetImageImageCaptionDataset, CCDataset +from bubogpt.datasets.datasets.image_caption.llava_dataset import LlavaInstruct150Dataset + +@registry.register_builder("cc_sbu") +class CCSBUBuilderImage(ImageBaseDatasetBuilder): + train_dataset_cls = CCSBUDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vision_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("laion") +class LaionBuilderImage(ImageBaseDatasetBuilder): + train_dataset_cls = LaionDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vision_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("cc_sbu_align") +class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder): + train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/cc_sbu/align.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + self.build_processors() + + build_info = self.config.build_info + storage_path = build_info.storage + + datasets = dict() + + if not os.path.exists(storage_path): + warnings.warn("storage path {} does not exist.".format(storage_path)) + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vision_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_paths=[os.path.join(storage_path, 'filter_cap.json')], + vis_root=os.path.join(storage_path, 'image'), + ) + + return datasets + + +@registry.register_builder("cc12m") +class CC12MBuilder(ImageBaseDatasetBuilder): + train_dataset_cls = CCDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/cc12m/defaults.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("llava_instruct150") +class LlavaInstruct150Builder(ImageBaseDatasetBuilder): + train_dataset_cls = LlavaInstruct150Dataset + + DATASET_CONFIG_DICT = {"default": None} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + + def build(self): + self.build_processors() + + datasets = dict() + split = "train" + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + vis_root="/path/to/dataset/COCO_2014", + ann_paths=[os.path.join("/path/to/dataset/llava/annotations", subset + '.json') + for subset in ["complex_reasoning_77k", "conversation_58k", "detail_23k"]], + ) + return datasets + + +# from bubogpt.datasets.builders.image_text_pair_builder import LlavaInstruct150Builder + +if __name__ == "__main__": + from omegaconf import OmegaConf + from itertools import islice + + data_cfg = OmegaConf.create({ + "vis_processor": {"train": {"name": "imagebind_vision_train", "image_size": 224}}, + "text_processor": {"train": {"name": "imagebind_caption"}}, + "data_type": "image", + }) + + builder = LlavaInstruct150Builder(data_cfg) + + datasets = builder.build_datasets() + + datasets["train"].check_existence() + + for sample in islice(datasets["train"], 10): + print(sample["vision"].shape, sample["prompt"], sample["text_input"]) diff --git a/bubogpt/datasets/builders/multimodal_base_dataset_builder.py b/bubogpt/datasets/builders/multimodal_base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..ca61ab87470d563b2c1eff4bdcf0c3f1efb55c0f --- /dev/null +++ b/bubogpt/datasets/builders/multimodal_base_dataset_builder.py @@ -0,0 +1,74 @@ +import logging + +import torch.distributed as dist + +import bubogpt.common.utils as utils +from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from bubogpt.common.registry import registry +from bubogpt.datasets.builders import load_dataset_config +from bubogpt.processors.base_processor import BaseProcessor + + +class MultimodalBaseDatasetBuilder(): + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type.split("_") + # It will be a list like ["audio", "image"], etc. + + # Add "text" manually here. + + self.processors = {modal: {"train": BaseProcessor(), "eval": BaseProcessor()} + for modal in [*self.data_type, "text"]} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + for modal in [*self.data_type, "text"]: + proc_cfg = self.config.get("{}_processor".format(modal)) + if proc_cfg is not None: + train_cfg = proc_cfg.get("train") + eval_cfg = proc_cfg.get("eval") + self.processors[modal]["train"] = self._build_proc_from_cfg(train_cfg) + self.processors[modal]["eval"] = self._build_proc_from_cfg(eval_cfg) + + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + pass diff --git a/bubogpt/datasets/data_utils.py b/bubogpt/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..280387308fdf8c0790b5ccaa27014d8158da60ed --- /dev/null +++ b/bubogpt/datasets/data_utils.py @@ -0,0 +1,215 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import random +from typing import List, Iterable + +import decord +import webdataset as wds +import torch +from torch.utils.data import IterableDataset, Dataset, ConcatDataset + +from bubogpt.common.registry import registry + +decord.bridge.set_bridge("torch") +MAX_INT = registry.get("MAX_INT") + + +class WrappedConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) + + +class WrappedChainDataset(wds.DataPipeline): + r"""Dataset for chaining multiple :class:`DataPipeline` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: List[wds.DataPipeline]) -> None: + super().__init__() + self.datasets = datasets + self.prob = [] + self.names = [] + for dataset in self.datasets: + if hasattr(dataset, 'name'): + self.names.append(dataset.name) + else: + self.names.append('Unknown') + if hasattr(dataset, 'sample_ratio'): + self.prob.append(dataset.sample_ratio) + else: + self.prob.append(1) + logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") + + def __iter__(self): + datastreams = [iter(dataset) for dataset in self.datasets] + while True: + select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] + yield next(select_datastream) + + +def apply_to_sample(f, sample): + if len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + + return apply_to_sample(_move_to_cuda, sample) + + +def move_to_cpu(sample): + def _move_to_cpu(tensor): + return tensor.cpu() + + return apply_to_sample(_move_to_cpu, sample) + + +def prepare_sample(samples, cuda_enabled=True): + if cuda_enabled: + samples = move_to_cuda(samples) + + # TODO fp16 support + + return samples + + +def reorg_datasets_by_split(datasets): + """ + Organizes datasets by split. + + Args: + datasets: dict of torch.utils.data.Dataset objects by name. + + Returns: + Dict of datasets by split {split_name: List[Datasets]}. + """ + # if len(datasets) == 1: + # return datasets[list(datasets.keys())[0]] + # else: + reorg_datasets = dict() + + # reorganize by split + for _, dataset in datasets.items(): + for split_name, dataset_split in dataset.items(): + if split_name not in reorg_datasets: + reorg_datasets[split_name] = [dataset_split] + else: + reorg_datasets[split_name].append(dataset_split) + + return reorg_datasets + + +def concat_datasets(datasets): + """ + Concatenates multiple datasets into a single dataset. + + It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support + generic IterableDataset because it requires creating separate samplers. + + Now only supports conctenating training datasets and assuming validation and testing + have only a single dataset. This is because metrics should not be computed on the concatenated + datasets. + + Args: + datasets: dict of torch.utils.data.Dataset objects by split. + + Returns: + Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, + "val" and "test" remain the same. + + If the input training datasets contain both map-style and DataPipeline datasets, returns + a tuple, where the first element is a concatenated map-style dataset and the second + element is a chained DataPipeline dataset. + + """ + # concatenate datasets in the same split + for split_name in datasets: + if split_name != "train": + assert ( + len(datasets[split_name]) == 1 + ), "Do not support multiple {} datasets.".format(split_name) + datasets[split_name] = datasets[split_name][0] + else: + iterable_datasets, map_datasets = [], [] + for dataset in datasets[split_name]: + if isinstance(dataset, wds.DataPipeline): + logging.info( + "Dataset {} is IterableDataset, can't be concatenated.".format( + dataset + ) + ) + iterable_datasets.append(dataset) + elif isinstance(dataset, IterableDataset): + raise NotImplementedError( + "Do not support concatenation of generic IterableDataset." + ) + else: + map_datasets.append(dataset) + + # if len(iterable_datasets) > 0: + # concatenate map-style datasets and iterable-style datasets separately + if len(iterable_datasets) > 1: + chained_datasets = ( + WrappedChainDataset(iterable_datasets) + ) + elif len(iterable_datasets) == 1: + chained_datasets = iterable_datasets[0] + else: + chained_datasets = None + + concat_datasets = ( + WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None + ) + + train_datasets = concat_datasets, chained_datasets + train_datasets = tuple([x for x in train_datasets if x is not None]) + train_datasets = ( + train_datasets[0] if len(train_datasets) == 1 else train_datasets + ) + + datasets[split_name] = train_datasets + + return datasets diff --git a/bubogpt/datasets/datasets/__init__.py b/bubogpt/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/datasets/datasets/audio_caption/__init__.py b/bubogpt/datasets/datasets/audio_caption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7da2ede72a29df703b832afbc3caa771a4f266d2 --- /dev/null +++ b/bubogpt/datasets/datasets/audio_caption/__init__.py @@ -0,0 +1 @@ +from bubogpt.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset, AudioCaptionDataset diff --git a/bubogpt/datasets/datasets/audio_caption/audio_caption_datasets.py b/bubogpt/datasets/datasets/audio_caption/audio_caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b42a855bbd62c3d2effa9e8e84880aa82fa404f7 --- /dev/null +++ b/bubogpt/datasets/datasets/audio_caption/audio_caption_datasets.py @@ -0,0 +1,70 @@ +import json +import os +import torchaudio +import random +import tempfile + +from torch.utils.data import Dataset, default_collate +import webdataset as wds +from bubogpt.datasets.datasets.base_dataset import BaseDualDataset + + +class GenericAudioDataset(BaseDualDataset): + def __init__(self, audio_processor, text_processor, location): + super().__init__(x_processor=audio_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode(wds.torch_audio, handler=wds.warn_and_continue), + wds.to_tuple("flac", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "audio": sample[0], + # [clips_per_video, channel, mel_bins, time_steps] + "text_input": self.text_processor(sample[1]["caption"]), + } + + +class AudioCaptionDataset(BaseDualDataset): + def __init__(self, audio_processor, text_processor, audio_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(audio_processor, text_processor, audio_root, ann_paths) + + self.audio_ids = {} + n = 0 + for ann in self.annotation: + audio_id = ann["audio_id"] + if audio_id not in self.audio_ids.keys(): + self.audio_ids[audio_id] = n + n += 1 + + with open("prompts/alignment_audio.txt") as f: + self.prompts = f.read().splitlines() + print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts)) + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + audio_file = ann["audio_id"] + ".wav" + audio_path = os.path.join(self.x_root, audio_file) + audio = torchaudio.load(audio_path) + audio = self.x_processor(audio) + caption = self.text_processor(ann["caption"]) + + return { + "audio": audio, + "text_input": caption, + # "audio_id": self.audio_ids[ann["audio_id"]], + "prompt": random.choice(self.prompts), + } diff --git a/bubogpt/datasets/datasets/audio_image/__init__.py b/bubogpt/datasets/datasets/audio_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/datasets/datasets/audio_image/audio_image_datasets.py b/bubogpt/datasets/datasets/audio_image/audio_image_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..79318ad968ff544c777c081ab61b5b95ebf59927 --- /dev/null +++ b/bubogpt/datasets/datasets/audio_image/audio_image_datasets.py @@ -0,0 +1,92 @@ +import os +import random +import json +import torchaudio +from torch.utils.data import Dataset +from PIL import Image +from bubogpt.datasets.datasets.base_dataset import BaseMultiSourceDataset +import webdataset as wds + + +class AudioLocalizationDataset(BaseMultiSourceDataset): + def __init__(self, processors, roots, ann_paths): + super().__init__(processors, roots, ann_paths) + + with open("prompts/alignment_audio_image_region.txt") as f: + self.prompts = f.read().splitlines() + print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts)) + + def __getitem__(self, index): + ann = self.annotation[index] + + audio_file = ann["audio_id"] + ".wav" + image_file = ann["image_id"] + ".jpg" + audio_path = os.path.join(self.roots["audio"], audio_file) + image_path = os.path.join(self.roots["image"], image_file) + + audio = torchaudio.load(audio_path) + image = Image.open(image_path).convert("RGB") + audio = self.processors["audio"](audio) + image = self.processors["image"](image) + caption = self.processors["text"](ann["caption"]) + + return { + "audio": audio, + "vision": image, + "text_input": caption, + "prompt": random.choice(self.prompts), + } + + +class AudioImageNegDataset(Dataset): + def __init__(self, processors, roots, ann_paths) -> None: + super().__init__() + + self.processors = processors + self.roots = roots + self.ann_paths = ann_paths + + self.img_annotation = [] + for ann_path in ann_paths['image']: + self.img_annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + self.aud_annotation = [] + for ann_path in ann_paths['audio']: + self.aud_annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + with open("prompts/alignment_audio_image_neg.txt") as f: + self.prompts = f.read().splitlines() + print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts)) + + def __len__(self): + return len(self.img_annotation) + + def __getitem__(self, index): + + img_ann = self.img_annotation[index] + + img_file = '{}.jpg'.format(img_ann["image_id"]) + image_path = os.path.join(self.roots['image'], img_file) + image = Image.open(image_path).convert("RGB") + image = self.processors['image'](image) + + aud_index = random.randint(0, len(self.aud_annotation)-1) + aud_ann = self.aud_annotation[aud_index] + + audio_file = aud_ann["audio_id"] + ".wav" + audio_path = os.path.join(self.roots['audio'], audio_file) + audio = torchaudio.load(audio_path) + audio = self.processors['audio'](audio) + prompt = random.choice(self.prompts) + if "related" in prompt: + prefix = "They seem unrelated. " + else: + prefix = "They seem unrelated. " if random.random() < 0.5 else "" + caption = self.processors['text'](prefix + img_ann["caption"] + aud_ann["caption"]) + + return { + 'audio': audio, + 'vision': image, + 'text_input': caption, + 'prompt': prompt, + } diff --git a/bubogpt/datasets/datasets/base_dataset.py b/bubogpt/datasets/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb69ec2c1adcf2fed3b3bff48bbbbec7e300f7d --- /dev/null +++ b/bubogpt/datasets/datasets/base_dataset.py @@ -0,0 +1,79 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +from typing import Iterable + +from torch.utils.data import Dataset +from torch.utils.data.dataloader import default_collate + + +class BaseDualDataset(Dataset): + def __init__( + self, x_processor=None, text_processor=None, x_root=None, ann_paths=[] + ): + """ + x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.) + ann_root (string): directory to store the annotation file + """ + self.x_root = x_root + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + self.x_processor = x_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, x_processor, text_processor): + self.x_processor = x_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + + +class BaseMultiSourceDataset(Dataset): + def __init__( + self, processors=None, roots=None, ann_paths=[] + ): + """ + processors (Dict[str, Processor]): The processors of different modalities. + roots (Dict[str, str]): The roots of different modalities, Deprecated + ann_root (string): directory to store the annotation file + """ + self.roots = roots + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + + self.processors = processors + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, processors): + self.processors = processors + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) diff --git a/bubogpt/datasets/datasets/dataloader_utils.py b/bubogpt/datasets/datasets/dataloader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6ab19184608453f47fba95583e0865456164e6 --- /dev/null +++ b/bubogpt/datasets/datasets/dataloader_utils.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +import random +import torch +from bubogpt.datasets.data_utils import move_to_cuda +from torch.utils.data import DataLoader + + +class MultiIterLoader: + """ + A simple wrapper for iterating over multiple iterators. + + Args: + loaders (List[Loader]): List of Iterator loaders. + ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. + """ + + def __init__(self, loaders, ratios=None): + # assert all loaders has __next__ method + for loader in loaders: + assert hasattr( + loader, "__next__" + ), "Loader {} has no __next__ method.".format(loader) + + if ratios is None: + ratios = [1.0] * len(loaders) + else: + assert len(ratios) == len(loaders) + ratios = [float(ratio) / sum(ratios) for ratio in ratios] + + self.loaders = loaders + self.ratios = ratios + + def __next__(self): + # random sample from each loader by ratio + loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] + return next(self.loaders[loader_idx]) + + +class PrefetchLoader(object): + """ + Modified from https://github.com/ChenRocks/UNITER. + + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + is_tuple = isinstance(batch, tuple) + if is_tuple: + task, batch = batch + + if is_tuple: + yield task, batch + else: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class IterLoader: + """ + A wrapper to convert DataLoader as an infinite iterator. + + Modified from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py + """ + + def __init__(self, dataloader: DataLoader, use_distributed: bool = False): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._use_distributed = use_distributed + self._epoch = 0 + + @property + def epoch(self) -> int: + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __iter__(self): + return self + + def __len__(self): + return len(self._dataloader) diff --git a/bubogpt/datasets/datasets/image_caption/__init__.py b/bubogpt/datasets/datasets/image_caption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/datasets/datasets/image_caption/cc_sbu_dataset.py b/bubogpt/datasets/datasets/image_caption/cc_sbu_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..692d6071648d61b9014855d2f45442e4dcde035a --- /dev/null +++ b/bubogpt/datasets/datasets/image_caption/cc_sbu_dataset.py @@ -0,0 +1,68 @@ +import os +from PIL import Image +import webdataset as wds +from bubogpt.datasets.datasets.base_dataset import BaseDualDataset +from bubogpt.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset + + +class CCSBUDataset(BaseDualDataset): + def __init__(self, vision_processor, text_processor, location): + super().__init__(x_processor=vision_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "vision": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + } + + +class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset): + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.x_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.x_processor(image) + caption = ann["caption"] + + return { + "vision": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class CCDataset(BaseDualDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(x_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "vision": sample[0], + "text_input": sample[1], + } diff --git a/bubogpt/datasets/datasets/image_caption/image_caption_datasets.py b/bubogpt/datasets/datasets/image_caption/image_caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..630c352d8a7b078b84612cea2445c074d852d91d --- /dev/null +++ b/bubogpt/datasets/datasets/image_caption/image_caption_datasets.py @@ -0,0 +1,73 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +from bubogpt.datasets.datasets.base_dataset import BaseDualDataset +from PIL import Image + +from bubogpt.datasets.datasets.mixins.mixins import __ImageDisplMixin + + +class ImageCaptionDataset(BaseDualDataset, __ImageDisplMixin): + def __init__(self, vision_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vision_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + img_file = '{:0>12}.jpg'.format(ann["image_id"]) + image_path = os.path.join(self.x_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.x_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "vision": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class CaptionEvalDataset(BaseDualDataset, __ImageDisplMixin): + def __init__(self, vision_processor, text_processor, x_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vision_processor, text_processor, x_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.x_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.x_processor(image) + + return { + "vision": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } \ No newline at end of file diff --git a/bubogpt/datasets/datasets/image_caption/laion_dataset.py b/bubogpt/datasets/datasets/image_caption/laion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e4698c565fde00d822e8057883d7ad461f923e41 --- /dev/null +++ b/bubogpt/datasets/datasets/image_caption/laion_dataset.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import webdataset as wds +from bubogpt.datasets.datasets.base_dataset import BaseDualDataset + + +class LaionDataset(BaseDualDataset): + def __init__(self, vision_processor, text_processor, location): + super().__init__(x_processor=vision_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "vision": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + } + diff --git a/bubogpt/datasets/datasets/image_caption/llava_dataset.py b/bubogpt/datasets/datasets/image_caption/llava_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7356fade5b6e1d8251804e793c4338c132ee789 --- /dev/null +++ b/bubogpt/datasets/datasets/image_caption/llava_dataset.py @@ -0,0 +1,72 @@ +import os +import json +import random +from PIL import Image +import webdataset as wds +from bubogpt.datasets.datasets.base_dataset import BaseDualDataset +from bubogpt.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset + + +class LlavaInstruct150Dataset(BaseDualDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(x_processor=vis_processor, text_processor=text_processor) + self.vis_root = vis_root + self.ann_paths = ann_paths + + self.data_list = data_list = [] + # for split in ["complex_reasoning_77k", "conversation_58k", "detail_23k"]: + # with open(os.path.join(vis_root, f'annotations/{split}.json'), 'r') as f: + # data_list.extend(json.load(f)) + for ann_path in ann_paths: + with open(ann_path) as f: + data_list.extend(json.load(f)) + + self.annotation = [] + for item in data_list: + image_id = item['id'] + conversations = item['conversations'] + for conv_id in range(len(conversations) //2 ): + question = conversations[2*conv_id]['value'] + answer = conversations[2 * conv_id+1]['value'] + self.annotation.append({'image_id':image_id, 'question':question, 'answer':answer}) + + # llava prompts + self.prompts = [ + " ", + " Quesion: ", + " A detail answer to the question is", + " Quesion: detail answer:", + " Based on the image, respond to this question with a detail answer: Answer:", + " Use the provided image to answer the question: ", + " What is the answer to the following question? ", + ] + print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts)) + # self.prompt_template = '###Human: {} ###Assistant: ' + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, "train2014/COCO_train2014_{:0>12}.jpg".format(ann["image_id"])) + image = Image.open(image_path).convert("RGB") + image = self.x_processor(image) + + question = ann['question'] + question = question.replace('\n', '').replace('\n', '') + # prompt = self.prompt_template.format(random.choice(self.prompts)) + prompt = random.choice(self.prompts) + prompt = prompt.replace('', question) + + return { + "vision": image, + "prompt": prompt, + "text_input": ann["answer"], + } + + def check_existence(self): + from tqdm import tqdm + for i in tqdm(range(len(self.data_list))): + image_id = self.data_list[i]["id"] + image_path = os.path.join(self.vis_root, "train2014/COCO_train2014_{:0>12}.jpg".format(image_id)) + if not os.path.exists(image_path): + print(f'Image does not exist: {image_path}') + print("Checking sucessful!") diff --git a/bubogpt/datasets/datasets/mixins/__init__.py b/bubogpt/datasets/datasets/mixins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bubogpt/datasets/datasets/mixins/mixins.py b/bubogpt/datasets/datasets/mixins/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..1c677fdf8348632286ba22dc9f021ff7d3d2125b --- /dev/null +++ b/bubogpt/datasets/datasets/mixins/mixins.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + + +class __ImageDisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "vision": sample["vision"], + } + ) + + +class __AudioDisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + # TODO: Finish the Audio Display Mixin + ''' + return OrderedDict( + { + } + ) + ''' + + raise NotImplementedError + diff --git a/bubogpt/models/Qformer.py b/bubogpt/models/Qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603 --- /dev/null +++ b/bubogpt/models/Qformer.py @@ -0,0 +1,1216 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bubogpt/models/__init__.py b/bubogpt/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c90893dc624306e08f660c12e0d553ab9ea927 --- /dev/null +++ b/bubogpt/models/__init__.py @@ -0,0 +1,200 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import torch +from omegaconf import OmegaConf + +from bubogpt.common.registry import registry +from bubogpt.models.base_model import BaseModel +from bubogpt.models.blip2 import Blip2Base +from bubogpt.processors.base_processor import BaseProcessor +from bubogpt.models.mm_gpt4 import MMGPT4 + + +__all__ = [ + "load_model", + "BaseModel", + "Blip2Base", + "MMGPT4" +] + + +def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): + """ + Load supported models. + + To list all available models and types in registry: + >>> from bubogpt.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + checkpoint (str): path or to checkpoint. Default: None. + Note that expecting the checkpoint to have the same keys in state_dict as the model. + + Returns: + model (torch.nn.Module): model. + """ + + model = registry.get_model_class(name).from_pretrained(model_type=model_type) + + if checkpoint is not None: + model.load_checkpoint(checkpoint) + + if is_eval: + model.eval() + + if device == "cpu": + model = model.float() + + return model.to(device) + + +def load_preprocess(config): + """ + Load preprocessor configs and construct preprocessors. + + If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. + + Args: + config (dict): preprocessor configs. + + Returns: + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + + Key is "train" or "eval" for processors used in training and evaluation respectively. + """ + + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else BaseProcessor() + ) + + vis_processors = dict() + txt_processors = dict() + + vis_proc_cfg = config.get("vis_processor") + txt_proc_cfg = config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + else: + vis_train_cfg = None + vis_eval_cfg = None + + vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) + vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + else: + txt_train_cfg = None + txt_eval_cfg = None + + txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) + txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) + + return vis_processors, txt_processors + + +def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): + """ + Load model and its related preprocessors. + + List all available models and types in registry: + >>> from bubogpt.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + """ + model_cls = registry.get_model_class(name) + + # load model + model = model_cls.from_pretrained(model_type=model_type) + + if is_eval: + model.eval() + + # load preprocess + cfg = OmegaConf.load(model_cls.default_config_path(model_type)) + if cfg is not None: + preprocess_cfg = cfg.preprocess + + vis_processors, txt_processors = load_preprocess(preprocess_cfg) + else: + vis_processors, txt_processors = None, None + logging.info( + f"""No default preprocess for model {name} ({model_type}). + This can happen if the model is not finetuned on downstream datasets, + or it is not intended for direct use without finetuning. + """ + ) + + if device == "cpu" or device == torch.device("cpu"): + model = model.float() + + return model.to(device), vis_processors, txt_processors + + +class ModelZoo: + """ + A utility class to create string representation of available model architectures and types. + + >>> from bubogpt.models import model_zoo + >>> # list all available models + >>> print(model_zoo) + >>> # show total number of models + >>> print(len(model_zoo)) + """ + + def __init__(self) -> None: + self.model_zoo = { + k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) + for k, v in registry.mapping["model_name_mapping"].items() + } + + def __str__(self) -> str: + return ( + "=" * 50 + + "\n" + + f"{'Architectures':<30} {'Types'}\n" + + "=" * 50 + + "\n" + + "\n".join( + [ + f"{name:<30} {', '.join(types)}" + for name, types in self.model_zoo.items() + ] + ) + ) + + def __iter__(self): + return iter(self.model_zoo.items()) + + def __len__(self): + return sum([len(v) for v in self.model_zoo.values()]) + + +model_zoo = ModelZoo() diff --git a/bubogpt/models/base_model.py b/bubogpt/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7e236381c6b395535cb609b9217c394345241445 --- /dev/null +++ b/bubogpt/models/base_model.py @@ -0,0 +1,247 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import numpy as np +import torch +import torch.nn as nn +from bubogpt.common.dist_utils import download_cached_file, is_dist_avail_and_initialized +from bubogpt.common.utils import get_abs_path, is_url +from omegaconf import OmegaConf + + +class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + @property + def device(self): + return list(self.parameters())[0].device + + def load_checkpoint(self, url_or_filename): + """ + Load from a finetuned checkpoint. + + This should expect no mismatch in the model keys and the checkpoint keys. + """ + + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint.keys(): + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + @classmethod + def from_pretrained(cls, model_type): + """ + Build a pretrained model from default configuration file, specified by model_type. + + Args: + - model_type (str): model type, specifying architecture and checkpoints. + + Returns: + - model (nn.Module): pretrained or finetuned model, depending on the configuration. + """ + model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model + model = cls.from_config(model_cfg) + + return model + + @classmethod + def default_config_path(cls, model_type): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def load_checkpoint_from_config(self, cfg, **kwargs): + """ + Load checkpoint as specified in the config file. + + If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. + When loading the pretrained model, each task-specific architecture may define their + own load_from_pretrained() method. + """ + load_finetuned = cfg.get("load_finetuned", True) + if load_finetuned: + finetune_path = cfg.get("finetuned", None) + assert ( + finetune_path is not None + ), "Found load_finetuned is True, but finetune_path is None." + self.load_checkpoint(url_or_filename=finetune_path) + else: + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + assert "Found load_finetuned is False, but pretrain_path is None." + self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + +class BaseEncoder(nn.Module): + """ + Base class for primitive encoders, such as ViT, TimeSformer, etc. + """ + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +class SharedQueueMixin: + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr : ptr + batch_size] = image_feats.T + self.text_queue[:, ptr : ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = concat_all_gather(idxs) + self.idx_queue[:, ptr : ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr + + +class MomentumDistilationMixin: + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data = param_m.data * self.momentum + param.data * ( + 1.0 - self.momentum + ) + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/bubogpt/models/blip2.py b/bubogpt/models/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..054c4bdeebfcc209a678443bfd2d3268f6ade6f5 --- /dev/null +++ b/bubogpt/models/blip2.py @@ -0,0 +1,221 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import contextlib +import logging +import os +import time +import datetime + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +import bubogpt.common.dist_utils as dist_utils +from bubogpt.common.dist_utils import download_cached_file +from bubogpt.common.utils import is_url +from bubogpt.common.logger import MetricLogger +from bubogpt.models.base_model import BaseModel +from bubogpt.models.Qformer import BertConfig, BertLMHeadModel +from bubogpt.models.eva_vit import create_eva_vit_g +from transformers import BertTokenizer + + +class Blip2Base(BaseModel): + @classmethod + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision + ): + assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_feat = model.forward_text(text_input) + text_embed = F.normalize(model.text_proj(text_feat)) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + + vit_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["vision"] + + image = image.to(model.device) + image_feat, vit_feat = model.forward_image(image) + image_embed = model.vision_proj(image_feat) + image_embed = F.normalize(image_embed, dim=-1) + + vit_feats.append(vit_feat.cpu()) + image_embeds.append(image_embed) + + vit_feats = torch.cat(vit_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = [] + for image_embed in image_embeds: + sim_q2t = image_embed @ text_embeds.t() + sim_i2t, _ = sim_q2t.max(0) + sims_matrix.append(sim_i2t) + sims_matrix = torch.stack(sims_matrix, dim=0) + + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[topk_idx], + text_atts=text_atts[topk_idx], + ).float() + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[topk_idx.cpu()].to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[start + i].repeat(k_test, 1), + text_atts=text_atts[start + i].repeat(k_test, 1), + ).float() + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/bubogpt/models/blip2_outputs.py b/bubogpt/models/blip2_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..e8722b1fedaec1e31e39d8c80f911b8ff79bbb75 --- /dev/null +++ b/bubogpt/models/blip2_outputs.py @@ -0,0 +1,110 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) + + +@dataclass +class BlipSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipIntermediateOutput(ModelOutput): + """ + Data class for intermediate outputs of BLIP models. + + image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). + text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). + + image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). + text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). + + encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. + encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. + + decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. + decoder_labels (torch.LongTensor): labels for the captioning loss. + + itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). + itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) + + """ + + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class BlipOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[BlipSimilarity] = None + + intermediate_output: BlipIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_lm: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipOutputFeatures(ModelOutput): + """ + Data class of features from BlipFeatureExtractor. + + Args: + image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional + image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional + text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional + text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/bubogpt/models/eva_vit.py b/bubogpt/models/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..74d241d18e9cfb07934ba4fc4ce7f84990c2b1f2 --- /dev/null +++ b/bubogpt/models/eva_vit.py @@ -0,0 +1,442 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from bubogpt.common.dist_utils import download_cached_file + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) +# print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/bubogpt/models/mm_gpt4.py b/bubogpt/models/mm_gpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..c92b51ab59d61631224e417e2f2ced99883f3302 --- /dev/null +++ b/bubogpt/models/mm_gpt4.py @@ -0,0 +1,307 @@ +import random +from typing import Dict, Tuple, List, Union + +import torch +import torch.nn as nn +import re +from torch import Tensor +from transformers import LlamaTokenizer +from omegaconf import DictConfig + +from imagebind.models.image_bind import imagebind_huge, ImageBindJoiner, ModalityType, replace_joiner_vision +from bubogpt.common.registry import registry +from bubogpt.models.blip2 import BaseModel +from bubogpt.models.modeling_llama import LlamaForCausalLM + + +def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]: + if not prompt_list: + return prompt_list + input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None]) + prompt_modal_sets = [set(re.findall("<([^<>]+)>", prompt)) for prompt in prompt_list] + results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if + prompt_modal_set == input_modal_set] + return results + + +def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]: + prompt_modalities = re.findall("<([^<>]+)>", prompt) + return [input_embeds[modality.lower()] for modality in prompt_modalities] + + +def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor: + embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None] + return torch.cat(embeds, dim=dim) + + +def filter_modalities(inputs): + filtered_inputs = {} + + for k in ModalityType.__dict__.values(): + if k in inputs: + filtered_inputs[k] = inputs[k] + + return filtered_inputs + + +@registry.register_model("mm_gpt4") +class MMGPT4(BaseModel): + """ + ImageBind GPT-LLAMA model. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna": "configs/models/mmgpt4.yaml", + } + + def __init__( + self, + joiner_cfg: DictConfig, + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + freeze_imagebind=True, + freeze_qformer=False, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=128, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + with_bind_head=False, + freeze_llm=True, + use_blip_vision=False, + proj_model="", + ): + super().__init__() + assert not low_resource, "Low Resource Mode is Currently Unavailable." + + self.low_resource = low_resource + + print('Loading ImageBind') + self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind, + with_head=with_bind_head, use_blip_vision=use_blip_vision) + print('Loading ImageBind Done') + + print(f'Loading LLAMA from {llama_model}') + self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + + self.llama_model = LlamaForCausalLM.from_pretrained( + llama_model, + torch_dtype=torch.float16, + ) + + if freeze_llm: + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + print('Loading LLAMA Done') + + print('Loading Q-Former and Adapter/Projector') + self.multimodal_joiner = ImageBindJoiner(joiner_cfg, output_dim=self.llama_model.config.hidden_size) + if use_blip_vision: + replace_joiner_vision(self.multimodal_joiner, q_former_model, proj_model) + print('Loading Q-Former and Adapter/Projector Done') + + self.max_txt_len = max_txt_len + self.end_sym = end_sym + + print("Preparing Prompts") + self.prompt_template = prompt_template + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + self.prompt_list = [prompt_template.format(p) for p in raw_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + print("Preparing Prompts Done") + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + import contextlib + return contextlib.nullcontext() + + def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + with self.maybe_autocast(): + imagebind_outputs = self.multimodal_encoder(inputs) + llama_inputs = self.multimodal_joiner(imagebind_outputs) + return llama_inputs + + def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: Union[str, list]) -> Tuple[Tensor, Tensor]: + if isinstance(prompt, (list, tuple)): + bs = list(inputs.values())[0].shape[0] + assert bs == len(prompt) + + return self.batch_prompt_wrap(inputs, prompt) + elif isinstance(prompt, (str, type(None))): + return self.single_prompt_wrap(inputs, prompt) + else: + raise NotImplementedError(f"Prompt type: {type(prompt)} not supported.") + + def single_prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]: + if not prompt: + input_embeds = concat_all_embeddings(inputs, dim=1) + attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) + return input_embeds, attns_input + input_embeds_list = arrange_modalities(inputs, prompt) + batch_size = input_embeds_list[0].shape[0] + prompt_slices = prompt.split('') + prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False) + .to(input_embeds_list[0].device) for prompt_slice in prompt_slices] + prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1) + for prompt_token in prompt_tokens] + result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list) + for emb in pair] + [prompt_embeds[-1]] + wrapped_input_embeds = torch.cat(result_embeds, dim=1) + wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1], + dtype=torch.long).to(wrapped_input_embeds.device) + return wrapped_input_embeds, wrapped_atts_input + + def batch_prompt_wrap(self, inputs: Dict[str, Tensor], prompts: List[str]) -> Tuple[Tensor, Tensor]: + device = list(inputs.values())[0].device + # This one only works for visual prompting + prompt_slices = [prompt.split('') for prompt in prompts] + slice_batch = list(zip(*prompt_slices)) + + prompt_tokens = [self.llama_tokenizer(slice, + return_tensors="pt", + add_special_tokens=False, + padding="longest", + truncation=True, + max_length=self.max_txt_len).to(device) + for slice in slice_batch] + prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids) for prompt_token in prompt_tokens] + prompt_masks = [prompt_token.attention_mask for prompt_token in prompt_tokens] + + # NOTE: assuming moalities are the same within a batch + input_embeds_list = arrange_modalities(inputs, prompts[0]) + input_mask_list = [torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(device) for input_embeds in input_embeds_list] + result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list) for emb in pair] + [prompt_embeds[-1]] + result_masks = [mask for pair in zip(prompt_masks[:-1], input_mask_list) for mask in pair] + [prompt_masks[-1]] + wrapped_input_embeds = torch.cat(result_embeds, dim=1) + wrapped_atts_input = torch.cat(result_masks, dim=1) + return wrapped_input_embeds, wrapped_atts_input + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + # filter `inputs` as it may contain informatioins other than modalities + modality_inputs = filter_modalities(inputs) + embeds = self.encode_inputs(modality_inputs) + filtered_prompts = filter_prompt(embeds, self.prompt_list) + if "prompt" in inputs: + assert isinstance(inputs["prompt"], (list, tuple)) + prompt = [self.prompt_template.format(p) for p in inputs["prompt"]] + elif filtered_prompts: + prompt = random.choice(filtered_prompts) + else: + prompt = None + # NOTE&TODO: add support for a list of prompts + input_embs, input_atts = self.prompt_wrap(embeds, prompt) + + # NOTE: No modifications from the next line to the end. Except for the autocast part. + + self.llama_tokenizer.padding_side = "right" + + text = [t + self.end_sym for t in inputs["text_input"]] + + to_regress_tokens = self.llama_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False + ).to(input_embs.device) + + targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 + ) + + empty_targets = ( + torch.ones([input_atts.shape[0], input_atts.shape[1] + 1], + dtype=torch.long).to(input_embs.device).fill_(-100) # plus one for bos + ) + targets = torch.cat([empty_targets, targets], dim=1) + + batch_size = input_embs.shape[0] + bos = torch.ones([batch_size, 1], + dtype=to_regress_tokens.input_ids.dtype, + device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + atts_bos = input_atts[:, :1] + + to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) + inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1) + attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1) + + with self.maybe_autocast(): + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {"loss": loss} + + @classmethod + def from_config(cls, cfg): + joiner_cfg = cfg.get("joiner_cfg") + q_former_model = cfg.get( + "q_former_model", + "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth", + ) + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + freeze_imagebind = cfg.get("freeze_imagebind", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 128) + end_sym = cfg.get("end_sym", '\n') + with_bind_head = cfg.get("with_bind_head", False) + freeze_llm = cfg.get("freeze_llm", True) + use_blip_vision = cfg.get("use_blip_vision", False) + proj_model = cfg.get("proj_model", "checkpoints/prerained_minigpt4_7b.pth") + + model = cls( + joiner_cfg=joiner_cfg, + q_former_model=q_former_model, + freeze_imagebind=freeze_imagebind, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + with_bind_head=with_bind_head, + freeze_llm=freeze_llm, + use_blip_vision=use_blip_vision, + proj_model=proj_model, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + if isinstance(ckpt_path, str): + ckpt_path = [ckpt_path] + for cur_ckpt_path in ckpt_path: + print("Load ImageBind-LLM Checkpoint: {}".format(cur_ckpt_path)) + ckpt = torch.load(cur_ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model \ No newline at end of file diff --git a/bubogpt/models/modeling_llama.py b/bubogpt/models/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..12d980e189d902fb1a6d9ea05dc3ca91959b1c8c --- /dev/null +++ b/bubogpt/models/modeling_llama.py @@ -0,0 +1,755 @@ +# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.llama.configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + batch_size, seq_length, _ = inputs_embeds.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + query_embeds = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "query_embeds": query_embeds, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + diff --git a/bubogpt/processors/__init__.py b/bubogpt/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bb0171393b0e2af7f05de671a02e6946edd5ba --- /dev/null +++ b/bubogpt/processors/__init__.py @@ -0,0 +1,47 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from bubogpt.processors.base_processor import BaseProcessor +from bubogpt.processors.blip_processors import ( + Blip2ImageTrainProcessor, + Blip2ImageEvalProcessor, + BlipCaptionProcessor, +) +from bubogpt.processors.imagebind_vision_processor import ( + ImageBindCaptionProcessor, + ImageBindVisionTrainProcessor, + ImageBindVisionEvalProcessor +) +from bubogpt.processors.imagebind_audio_processor import ( + ImageBindAudioTrainProcessor, + ImageBindAudioEvalProcessor, +) + +from bubogpt.common.registry import registry + +__all__ = [ + "BaseProcessor", + "Blip2ImageTrainProcessor", + "Blip2ImageEvalProcessor", + "BlipCaptionProcessor", + "ImageBindCaptionProcessor", + "ImageBindVisionTrainProcessor", + "ImageBindVisionEvalProcessor", + "ImageBindAudioTrainProcessor", + "ImageBindAudioEvalProcessor", +] + + +def load_processor(name, cfg=None): + """ + Example + + >>> processor = load_processor("alpro_video_train", cfg=None) + """ + processor = registry.get_processor_class(name).from_config(cfg) + + return processor diff --git a/bubogpt/processors/audio_augment.py b/bubogpt/processors/audio_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..869db2ec71c8da01fa7d0f707cedce2d4d8fa738 --- /dev/null +++ b/bubogpt/processors/audio_augment.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# @Author : Xinhao Mei @CVSSP, University of Surrey +# @E-mail : x.mei@surrey.ac.uk + +""" + Implemenation of SpecAugment++, + Adapated from Qiuqiang Kong's trochlibrosa: + https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py +""" + +import torch +import torch.nn as nn + + +class DropStripes: + + def __init__(self, dim, drop_width, stripes_num): + """ Drop stripes. + args: + dim: int, dimension along which to drop + drop_width: int, maximum width of stripes to drop + stripes_num: int, how many stripes to drop + """ + super(DropStripes, self).__init__() + + assert dim in [2, 3] # dim 2: time; dim 3: frequency + + self.dim = dim + self.drop_width = drop_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channels, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + for n in range(batch_size): + self.transform_slice(input[n], total_width) + + return input + + def transform_slice(self, e, total_width): + """ e: (channels, time_steps, freq_bins)""" + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + e[:, bgn: bgn + distance, :] = 0 + elif self.dim == 3: + e[:, :, bgn: bgn + distance] = 0 + + +class MixStripes: + + def __init__(self, dim, mix_width, stripes_num): + """ Mix stripes + args: + dim: int, dimension along which to mix + mix_width: int, maximum width of stripes to mix + stripes_num: int, how many stripes to mix + """ + + super(MixStripes, self).__init__() + + assert dim in [2, 3] + + self.dim = dim + self.mix_width = mix_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channel, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + rand_sample = input[torch.randperm(batch_size)] + for i in range(batch_size): + self.transform_slice(input[i], rand_sample[i], total_width) + return input + + def transform_slice(self, input, random_sample, total_width): + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \ + 0.5 * random_sample[:, bgn: bgn + distance, :] + elif self.dim == 3: + input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \ + 0.5 * random_sample[:, :, bgn: bgn + distance] + + +class CutStripes: + + def __init__(self, dim, cut_width, stripes_num): + """ Cutting stripes with another randomly selected sample in mini-batch. + args: + dim: int, dimension along which to cut + cut_width: int, maximum width of stripes to cut + stripes_num: int, how many stripes to cut + """ + + super(CutStripes, self).__init__() + + assert dim in [2, 3] + + self.dim = dim + self.cut_width = cut_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channel, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + rand_sample = input[torch.randperm(batch_size)] + for i in range(batch_size): + self.transform_slice(input[i], rand_sample[i], total_width) + return input + + def transform_slice(self, input, random_sample, total_width): + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :] + elif self.dim == 3: + input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance] + + +class SpecAugmentation: + + def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num, + mask_type='mixture'): + """Spec augmetation and SpecAugment++. + [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. + and Le, Q.V., 2019. Specaugment: A simple data augmentation method + for automatic speech recognition. arXiv preprint arXiv:1904.08779. + [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space + Data Augmentation Method for Acoustic Scene Classification. arXiv + preprint arXiv:2103.16858. + + Args: + time_drop_width: int + time_stripes_num: int + freq_drop_width: int + freq_stripes_num: int + mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting) + """ + + super(SpecAugmentation, self).__init__() + + if mask_type == 'zero_value': + self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width, + stripes_num=freq_stripes_num) + elif mask_type == 'mixture': + self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width, + stripes_num=freq_stripes_num) + elif mask_type == 'cutting': + self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width, + stripes_num=freq_stripes_num) + else: + raise NameError('No such mask type in SpecAugment++') + + def __call__(self, inputs): + # x should be in size [batch_size, channel, time_steps, freq_bins] + x = self.time_augmentator(inputs) + x = self.freq_augmentator(x) + return x diff --git a/bubogpt/processors/base_processor.py b/bubogpt/processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41 --- /dev/null +++ b/bubogpt/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/bubogpt/processors/blip_processors.py b/bubogpt/processors/blip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4ac341b4e31e02aefa841678bf857bf1cdf990 --- /dev/null +++ b/bubogpt/processors/blip_processors.py @@ -0,0 +1,141 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from bubogpt.common.registry import registry +from bubogpt.processors.base_processor import BaseProcessor +from bubogpt.processors.vision_augment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +@registry.register_processor("blip_caption") +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +@registry.register_processor("blip2_image_train") +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("blip2_image_eval") +class Blip2ImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) \ No newline at end of file diff --git a/bubogpt/processors/imagebind_audio_processor.py b/bubogpt/processors/imagebind_audio_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c299dd01d364eb7a22f71f3ff63bddd43285df60 --- /dev/null +++ b/bubogpt/processors/imagebind_audio_processor.py @@ -0,0 +1,187 @@ +import math +from typing import Union, List + +import torch +import torchaudio +from omegaconf import OmegaConf +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler +from torch import Tensor +from torch_time_stretch import time_stretch + +from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \ + get_random_clip_timepoints +from bubogpt.datasets.data_utils import move_to_cuda, move_to_cpu +from bubogpt.processors.base_processor import BaseProcessor +from torchvision import transforms +from bubogpt.common.registry import registry +from bubogpt.processors.audio_augment import SpecAugmentation + + +class ImageBindAudioBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + num_mel_bins=None, target_length=None, clip_sample_method="Random", use_global=False): + super().__init__() + self.mean = -4.268 if mean is None else mean + self.std = 9.138 if std is None else std + self.target_sr = 16000 if target_sr is None else target_sr + self.num_mel_bins = num_mel_bins + self.target_length = target_length + self.clip_duration = clip_duration + self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method) + self.normalize = transforms.Normalize(self.mean, self.std) + self.use_global = use_global + + def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method): + if clip_duration is None or clips_per_video is None: + return None + if clip_sample_method == "Constant": + return ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + elif clip_sample_method == "Random": + return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video) + else: + raise NotImplementedError + + def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor: + waveform = torchaudio.functional.resample(waveform, orig_freq=origin_sr, new_freq=self.target_sr) + all_duration = waveform.size(1) / self.target_sr + num_repeat = self.clip_duration / all_duration + if num_repeat < 1: # all duration > clip duration + return waveform + flatten_waves = torch.tile(waveform, dims=[1, int(num_repeat) + 1]) # [1, N * L] + return flatten_waves[:, :self.clip_duration * self.target_sr] + + def global_stretching(self, waveform: Tensor) -> Tensor: + # NOTE: directly applying "waveform[:, ::shrink_ratio]" is FORBIDDEN! + # NOTE: May be Deprecated, TOO SLOW. + # shrink_ratio = self.clip_duration * self.target_sr / waveform.size(1) + # return move_to_cpu(time_stretch(move_to_cuda(waveform.unsqueeze(0)), shrink_ratio, self.target_sr)[0]) + return waveform + + def clip_sample(self, waveform: Tensor) -> List[Tensor]: + if self.clip_sampler is None: + return [waveform] + elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler): + all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr) + elif isinstance(self.clip_sampler, RandomMultiClipSampler): + all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr) + else: + raise NotImplementedError + all_clips = [] + for clip_timepoints in all_clips_timepoints: + start_pos = int(clip_timepoints[0] * self.target_sr) + end_pos = int(clip_timepoints[1] * self.target_sr) + waveform_clip = waveform[:, start_pos: end_pos] + all_clips.append(waveform_clip) + return all_clips + + def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]: + if isinstance(waveforms, Tensor): + return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length) + else: + return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length) + for waveform in waveforms] + + +@registry.register_processor("imagebind_audio_train") +class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + clip_sample_method="Random", use_global=False, num_mel_bins=None, target_length=None, + time_drop_width=13, time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, + mask_type='mixture'): + super().__init__(mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + clip_sample_method=clip_sample_method, use_global=use_global) + self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num, + freq_drop_width, freq_stripes_num, mask_type) + + def __call__(self, item): + # item: Tuple[Tensor, int] + waveform, origin_sr = item[0], item[1] + waveform = self.waveform_resample(waveform, origin_sr) + waveform_clips = self.clip_sample(waveform) + if self.use_global: + waveform_clips.append(self.global_stretching(waveform)) + melspec_clips = self.waveform_melspec(waveform_clips) + normed_melspecs = [self.normalize(clip) for clip in melspec_clips] + all_clips = torch.stack(normed_melspecs, dim=0) + # all_clips: [clips_per_video, channel, mel_bins, time_steps] + # augment: [batch_size, channel, time_steps, freq_bins] + augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1) + return augmented_clips + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + target_sr = cfg.get("target_sr", 16000) + clip_duration = cfg.get("clip_duration", None) + clips_per_video = cfg.get("clips_per_video", None) + num_mel_bins = cfg.get("num_mel_bins", 128) + target_length = cfg.get("target_length", 204) + time_drop_width = cfg.get("time_drop_width", 13) + time_stripes_num = cfg.get("time_stripes_num", 2) + # 13 * 2 / 204 = 12.75% Time Mask + freq_drop_width = cfg.get("freq_drop_width", 8) + freq_stripes_num = cfg.get("freq_stripes_num", 2) + # 8 * 2 / 128 = 12.5% Freq Mask + mask_type = cfg.get("mask_type", 'mixture') + use_global = cfg.get("use_global", False) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls( + mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + time_drop_width=time_drop_width, time_stripes_num=time_stripes_num, + freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num, + mask_type=mask_type, use_global=use_global + ) + + +@registry.register_processor("imagebind_audio_eval") +class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + clip_sample_method="Constant", use_global=False, num_mel_bins=None, target_length=None): + super().__init__(mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + clip_sample_method=clip_sample_method, use_global=use_global) + + def __call__(self, item): + # item: Tuple[Tensor, int] + waveform, origin_sr = item[0], item[1] + waveform = self.waveform_resample(waveform, origin_sr) + waveform_clips = self.clip_sample(waveform) + if self.use_global: + waveform_clips.append(self.global_stretching(waveform)) + melspec_clips = self.waveform_melspec(waveform_clips) + normed_melspecs = [self.normalize(clip) for clip in melspec_clips] + all_clips = torch.stack(normed_melspecs, dim=0) + # all_clips: [clips_per_video, channel, mel_bins, time_steps] + return all_clips + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + target_sr = cfg.get("target_sr", 16000) + clip_duration = cfg.get("clip_duration", None) + clips_per_video = cfg.get("clips_per_video", None) + num_mel_bins = cfg.get("num_mel_bins", 128) + target_length = cfg.get("target_length", 204) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls( + mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length + ) diff --git a/bubogpt/processors/imagebind_vision_processor.py b/bubogpt/processors/imagebind_vision_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1f572116490fce47dd099f346792ca535bf946 --- /dev/null +++ b/bubogpt/processors/imagebind_vision_processor.py @@ -0,0 +1,151 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from bubogpt.common.registry import registry +from bubogpt.processors.base_processor import BaseProcessor +from bubogpt.processors.vision_augment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class ImageBindVisionBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + super().__init__() + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +# Note: The config of caption processor is different from the ones in BLIP2 / MiniGPT4 +@registry.register_processor("imagebind_caption") +class ImageBindCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + # Note: Actually no prompts are used here. + super().__init__() + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 150) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([\n\"()*#~])", + " ", + caption, + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # # truncate caption Note: Deprecated. + # caption_words = caption.split(" ") + # if len(caption_words) > self.max_words: + # caption = " ".join(caption_words[: self.max_words]) + + return caption + + +# Note: The training config of vision processor keeps the same as BLIP2 / MiniGPT4 +@registry.register_processor("imagebind_vision_train") +class ImageBindVisionTrainProcessor(ImageBindVisionBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +# Changed. +@registry.register_processor("imagebind_vision_eval") +class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + image_size, interpolation=InterpolationMode.BICUBIC + ), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) + + + diff --git a/bubogpt/processors/vision_augment.py b/bubogpt/processors/vision_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b --- /dev/null +++ b/bubogpt/processors/vision_augment.py @@ -0,0 +1,398 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import cv2 +import numpy as np + +import torch + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32( + [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] + ) * factor + np.float32([[0.114], [0.587], [0.299]]) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = ( + np.array([(el - mean) * factor + mean for el in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR + ).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +class VideoRandomAugment(object): + def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): + self.N = N + self.M = M + self.p = p + self.tensor_in_tensor_out = tensor_in_tensor_out + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N, replace=False) + return [(op, self.M) for op in sampled_ops] + + def __call__(self, frames): + assert ( + frames.shape[-1] == 3 + ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." + + if self.tensor_in_tensor_out: + frames = frames.numpy().astype(np.uint8) + + num_frames = frames.shape[0] + + ops = num_frames * [self.get_random_ops()] + apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] + + frames = torch.stack( + list(map(self._aug, frames, ops, apply_or_not)), dim=0 + ).float() + + return frames + + def _aug(self, img, ops, apply_or_not): + for i, (name, level) in enumerate(ops): + if not apply_or_not[i]: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return torch.from_numpy(img) + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/bubogpt/runners/__init__.py b/bubogpt/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c3234b7ce28fb41cf2ccf86f2e1d79489394091 --- /dev/null +++ b/bubogpt/runners/__init__.py @@ -0,0 +1,10 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from bubogpt.runners.runner_base import RunnerBase + +__all__ = ["RunnerBase"] diff --git a/bubogpt/runners/runner_base.py b/bubogpt/runners/runner_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b85bbfa64bc3dda086ca79a45c6d0cc883a56b73 --- /dev/null +++ b/bubogpt/runners/runner_base.py @@ -0,0 +1,699 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import json +import logging +import os +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import webdataset as wds +from bubogpt.common.dist_utils import ( + download_cached_file, + get_rank, + get_world_size, + is_main_process, + main_process, +) +from bubogpt.common.registry import registry +from bubogpt.common.utils import is_url +from bubogpt.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset +from bubogpt.datasets.datasets.dataloader_utils import ( + IterLoader, + MultiIterLoader, + PrefetchLoader, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler + + +@registry.register_runner("runner_base") +class RunnerBase: + """ + A runner class to train and evaluate a model given a task and datasets. + + The runner uses pytorch distributed data parallel by default. Future release + will support other distributed frameworks. + """ + + def __init__(self, cfg, task, model, datasets, job_id): + self.config = cfg + self.job_id = job_id + + self.task = task + self.datasets = datasets + + self._model = model + + self._wrapped_model = None + self._device = None + self._optimizer = None + self._scaler = None + self._dataloaders = None + self._lr_sched = None + + self.start_epoch = 0 + + # self.setup_seeds() + self.setup_output_dir() + + @property + def device(self): + if self._device is None: + self._device = torch.device(self.config.run_cfg.device) + + return self._device + + @property + def use_distributed(self): + return self.config.run_cfg.distributed + + @property + def model(self): + """ + A property to get the DDP-wrapped model on the device. + """ + # move model to device + if self._model.device != self.device: + self._model = self._model.to(self.device) + + # distributed training wrapper + if self.use_distributed: + if self._wrapped_model is None: + self._wrapped_model = DDP( + self._model, device_ids=[self.config.run_cfg.gpu], + find_unused_parameters=True + ) + else: + self._wrapped_model = self._model + + return self._wrapped_model + + @property + def optimizer(self): + # TODO make optimizer class and configurations + if self._optimizer is None: + num_parameters = 0 + p_wd, p_non_wd = [], [] + for n, p in self.model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + print(n) + if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: + p_non_wd.append(p) + else: + p_wd.append(p) + num_parameters += p.data.nelement() + logging.info("number of trainable parameters: %d" % num_parameters) + optim_params = [ + { + "params": p_wd, + "weight_decay": float(self.config.run_cfg.weight_decay), + }, + {"params": p_non_wd, "weight_decay": 0}, + ] + beta2 = self.config.run_cfg.get("beta2", 0.999) + self._optimizer = torch.optim.AdamW( + optim_params, + lr=float(self.config.run_cfg.init_lr), + weight_decay=float(self.config.run_cfg.weight_decay), + betas=(0.9, beta2), + ) + + return self._optimizer + + @property + def scaler(self): + amp = self.config.run_cfg.get("amp", False) + + if amp: + if self._scaler is None: + self._scaler = torch.cuda.amp.GradScaler() + + return self._scaler + + @property + def lr_scheduler(self): + """ + A property to get and create learning rate scheduler by split just in need. + """ + if self._lr_sched is None: + lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched) + + # max_epoch = self.config.run_cfg.max_epoch + max_epoch = self.max_epoch + # min_lr = self.config.run_cfg.min_lr + min_lr = self.min_lr + # init_lr = self.config.run_cfg.init_lr + init_lr = self.init_lr + + # optional parameters + decay_rate = self.config.run_cfg.get("lr_decay_rate", None) + warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1) + warmup_steps = self.config.run_cfg.get("warmup_steps", 0) + iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None) + + if iters_per_epoch is None: + try: + iters_per_epoch = len(self.dataloaders['train']) + except (AttributeError, TypeError): + iters_per_epoch = 10000 + + self._lr_sched = lr_sched_cls( + optimizer=self.optimizer, + max_epoch=max_epoch, + iters_per_epoch=iters_per_epoch, + min_lr=min_lr, + init_lr=init_lr, + decay_rate=decay_rate, + warmup_start_lr=warmup_start_lr, + warmup_steps=warmup_steps, + ) + + return self._lr_sched + + @property + def dataloaders(self) -> dict: + """ + A property to get and create dataloaders by split just in need. + + If no train_dataset_ratio is provided, concatenate map-style datasets and + chain wds.DataPipe datasets separately. Training set becomes a tuple + (ConcatDataset, ChainDataset), both are optional but at least one of them is + required. The resultant ConcatDataset and ChainDataset will be sampled evenly. + + If train_dataset_ratio is provided, create a MultiIterLoader to sample + each dataset by ratios during training. + + Currently do not support multiple datasets for validation and test. + + Returns: + dict: {split_name: (tuples of) dataloader} + """ + if self._dataloaders is None: + + # concatenate map-style datasets and chain wds.DataPipe datasets separately + # training set becomes a tuple (ConcatDataset, ChainDataset), both are + # optional but at least one of them is required. The resultant ConcatDataset + # and ChainDataset will be sampled evenly. + logging.info( + "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." + ) + + datasets = reorg_datasets_by_split(self.datasets) + self.datasets = datasets + # self.datasets = concat_datasets(datasets) + + # print dataset statistics after concatenation/chaining + for split_name in self.datasets: + if isinstance(self.datasets[split_name], tuple) or isinstance( + self.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, WrappedChainDataset] + else 0 + for d in self.datasets[split_name] + ] + ) + + else: + if hasattr(self.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(self.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." + ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + + # create dataloaders + split_names = sorted(self.datasets.keys()) + + datasets = [self.datasets[split] for split in split_names] + is_trains = [split in self.train_splits for split in split_names] + + batch_sizes = [ + self.config.run_cfg.batch_size_train + if split == "train" + else self.config.run_cfg.batch_size_eval + for split in split_names + ] + + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + + dataloaders = self.create_loaders( + datasets=datasets, + num_workers=self.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, + ) + + self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + return self._dataloaders + + @property + def cuda_enabled(self): + return self.device.type == "cuda" + + @property + def max_epoch(self): + return int(self.config.run_cfg.max_epoch) + + @property + def log_freq(self): + log_freq = self.config.run_cfg.get("log_freq", 50) + return int(log_freq) + + @property + def init_lr(self): + return float(self.config.run_cfg.init_lr) + + @property + def min_lr(self): + return float(self.config.run_cfg.min_lr) + + @property + def accum_grad_iters(self): + return int(self.config.run_cfg.get("accum_grad_iters", 1)) + + @property + def valid_splits(self): + valid_splits = self.config.run_cfg.get("valid_splits", []) + + if len(valid_splits) == 0: + logging.info("No validation splits found.") + + return valid_splits + + @property + def test_splits(self): + test_splits = self.config.run_cfg.get("test_splits", []) + + return test_splits + + @property + def train_splits(self): + train_splits = self.config.run_cfg.get("train_splits", []) + + if len(train_splits) == 0: + logging.info("Empty train splits.") + + return train_splits + + @property + def evaluate_only(self): + """ + Set to True to skip training. + """ + return self.config.run_cfg.evaluate + + @property + def use_dist_eval_sampler(self): + return self.config.run_cfg.get("use_dist_eval_sampler", True) + + @property + def resume_ckpt_path(self): + return self.config.run_cfg.get("resume_ckpt_path", None) + + @property + def train_loader(self): + train_dataloader = self.dataloaders["train"] + + return train_dataloader + + def setup_output_dir(self): + lib_root = Path(registry.get_path("library_root")) + + output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id + result_dir = output_dir / "result" + + output_dir.mkdir(parents=True, exist_ok=True) + result_dir.mkdir(parents=True, exist_ok=True) + + registry.register_path("result_dir", str(result_dir)) + registry.register_path("output_dir", str(output_dir)) + + self.result_dir = result_dir + self.output_dir = output_dir + + def train(self): + start_time = time.time() + best_agg_metric = 0 + best_epoch = 0 + + self.log_config() + + # resume from checkpoint if specified + if not self.evaluate_only and self.resume_ckpt_path is not None: + self._load_checkpoint(self.resume_ckpt_path) + + for cur_epoch in range(self.start_epoch, self.max_epoch): + # training phase + if not self.evaluate_only: + logging.info("Start training") + train_stats = self.train_epoch(cur_epoch) + self.log_stats(split_name="train", stats=train_stats) + + # evaluation phase + if len(self.valid_splits) > 0: + for split_name in self.valid_splits: + logging.info("Evaluating on {}.".format(split_name)) + + val_log = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch + ) + if val_log is not None: + if is_main_process(): + assert ( + "agg_metrics" in val_log + ), "No agg_metrics found in validation log." + + agg_metrics = val_log["agg_metrics"] + if agg_metrics > best_agg_metric and split_name == "val": + best_epoch, best_agg_metric = cur_epoch, agg_metrics + + self._save_checkpoint(cur_epoch, is_best=True) + + val_log.update({"best_epoch": best_epoch}) + self.log_stats(val_log, split_name) + + else: + # if no validation split is provided, we just save the checkpoint at the end of each epoch. + if not self.evaluate_only: + self._save_checkpoint(cur_epoch, is_best=False) + + if self.evaluate_only: + break + + if self.config.run_cfg.distributed: + dist.barrier() + + # testing phase + test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch + self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Training time {}".format(total_time_str)) + + def evaluate(self, cur_epoch="best", skip_reload=False): + test_logs = dict() + + if len(self.test_splits) > 0: + for split_name in self.test_splits: + test_logs[split_name] = self.eval_epoch( + split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload + ) + + return test_logs + + def train_epoch(self, epoch): + # train + self.model.train() + + return self.task.train_epoch( + epoch=epoch, + model=self.model, + data_loader=self.train_loader, + optimizer=self.optimizer, + scaler=self.scaler, + lr_scheduler=self.lr_scheduler, + cuda_enabled=self.cuda_enabled, + log_freq=self.log_freq, + accum_grad_iters=self.accum_grad_iters, + ) + + @torch.no_grad() + def eval_epoch(self, split_name, cur_epoch, skip_reload=False): + """ + Evaluate the model on a given split. + + Args: + split_name (str): name of the split to evaluate on. + cur_epoch (int): current epoch. + skip_reload_best (bool): whether to skip reloading the best checkpoint. + During training, we will reload the best checkpoint for validation. + During testing, we will use provided weights and skip reloading the best checkpoint . + """ + data_loader = self.dataloaders.get(split_name, None) + assert data_loader, "data_loader for split {} is None.".format(split_name) + + # TODO In validation, you need to compute loss as well as metrics + # TODO consider moving to model.before_evaluation() + model = self.unwrap_dist_model(self.model) + if not skip_reload and cur_epoch == "best": + model = self._reload_best_model(model) + model.eval() + + self.task.before_evaluation( + model=model, + dataset=self.datasets[split_name], + ) + results = self.task.evaluation(model, data_loader) + + if results is not None: + return self.task.after_evaluation( + val_result=results, + split_name=split_name, + epoch=cur_epoch, + ) + + def unwrap_dist_model(self, model): + if self.use_distributed: + return model.module + else: + return model + + def create_loaders( + self, + datasets, + num_workers, + batch_sizes, + is_trains, + collate_fns, + dataset_ratios=None, + ): + """ + Create dataloaders for training and validation. + """ + + def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): + # create a single dataloader for each split + if isinstance(dataset, WrappedChainDataset) or isinstance( + dataset, wds.DataPipeline + ): + # wds.WebdDataset instance are chained together + # webdataset.DataPipeline has its own sampler and collate_fn + loader = iter( + DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + ) + ) + else: + # map-style dataset are concatenated together + # setup distributed sampler + if self.use_distributed: + sampler = DistributedSampler( + dataset, + shuffle=is_train, + num_replicas=get_world_size(), + rank=get_rank(), + ) + if not self.use_dist_eval_sampler: + # e.g. retrieval evaluation + sampler = sampler if is_train else None + else: + sampler = None + + loader = DataLoader( + dataset, + batch_size=bsz, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + shuffle=sampler is None and is_train, + collate_fn=collate_fn, + drop_last=True if is_train else False, + ) + loader = PrefetchLoader(loader) + + if is_train: + loader = IterLoader(loader, use_distributed=self.use_distributed) + + return loader + + def regroup_by_data_type(dataset): + if not isinstance(dataset, (tuple, list)): + return [dataset] + + dtypes = set([d.data_type for d in dataset]) + type2data = {} + for dtype in dtypes: + type2data[dtype] = [d for d in dataset if d.data_type == dtype] + + return list(type2data.values()), dtypes + + def get_data_type_ratio(datasets): + ratios = [] + for type_dataests in datasets: + type_ratio = None + for dataset in type_dataests: + if hasattr(dataset, 'dtype_ratio') and dataset.dtype_ratio is not None: + type_ratio = dataset.dtype_ratio + ratios.append(type_ratio) + + if any([x is None for x in ratios]): + ratios = [] + else: + return ratios + + # Use sample ratio as the data_type ratio + for type_datasets in datasets: + ratios.append(sum([d.sample_ratio for d in type_datasets])) + return ratios + + loaders = [] + for mix_dataset, bsz, is_train, collate_fn in zip( + datasets, batch_sizes, is_trains, collate_fns + ): + mix_dataset, dtypes = regroup_by_data_type(mix_dataset) + mix_loader = [] + for dataset in mix_dataset: + if isinstance(dataset, list) or isinstance(dataset, tuple): + dataset_ratios = None + if hasattr(dataset[0], 'sample_ratio'): + dataset_ratios = [d.sample_ratio for d in dataset] + loader = MultiIterLoader( + loaders=[ + _create_loader(d, num_workers, bsz, is_train, collate_fn[i]) + for i, d in enumerate(dataset) + ], + ratios=dataset_ratios, + ) + else: + loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn) + mix_loader.append(loader) + print(f"There are {len(mix_dataset)} of data types, They are:", dtypes) + if len(mix_loader) == 1: + loaders.append(mix_loader[0]) + else: + loader_ratios = get_data_type_ratio(mix_dataset) + print("Data type ratios are: ", loader_ratios) + merged_loader = MultiIterLoader(loaders=mix_loader, ratios=loader_ratios) + loaders.append(merged_loader) + + return loaders + + @main_process + def _save_checkpoint(self, cur_epoch, is_best=False): + """ + Save the checkpoint at the current epoch. + """ + model_no_ddp = self.unwrap_dist_model(self.model) + param_grad_dic = { + k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() + } + state_dict = model_no_ddp.state_dict() + for k in list(state_dict.keys()): + if k in param_grad_dic.keys() and not param_grad_dic[k]: + # delete parameters that do not require gradient + del state_dict[k] + save_obj = { + "model": state_dict, + "optimizer": self.optimizer.state_dict(), + "config": self.config.to_dict(), + "scaler": self.scaler.state_dict() if self.scaler else None, + "epoch": cur_epoch, + } + save_to = os.path.join( + self.output_dir, + "checkpoint_{}.pth".format("best" if is_best else cur_epoch), + ) + logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to)) + torch.save(save_obj, save_to) + + def _reload_best_model(self, model): + """ + Load the best checkpoint for evaluation. + """ + checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth") + + logging.info("Loading checkpoint from {}.".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + try: + model.load_state_dict(checkpoint["model"]) + except RuntimeError as e: + logging.warning( + """ + Key mismatch when loading checkpoint. This is expected if only part of the model is saved. + Trying to load the model with strict=False. + """ + ) + model.load_state_dict(checkpoint["model"], strict=False) + return model + + def _load_checkpoint(self, url_or_filename): + """ + Resume from a checkpoint. + """ + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location=self.device) + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location=self.device) + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scaler and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.start_epoch = checkpoint["epoch"] + 1 + logging.info("Resume checkpoint from {}".format(url_or_filename)) + + @main_process + def log_stats(self, stats, split_name): + if isinstance(stats, dict): + log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}} + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(log_stats) + "\n") + elif isinstance(stats, list): + pass + + @main_process + def log_config(self): + with open(os.path.join(self.output_dir, "log.txt"), "a") as f: + f.write(json.dumps(self.config.to_dict(), indent=4) + "\n") diff --git a/bubogpt/tasks/__init__.py b/bubogpt/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b98c918e501e40b6f74f41fdb6ab4550b24b683 --- /dev/null +++ b/bubogpt/tasks/__init__.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from bubogpt.common.registry import registry +from bubogpt.tasks.base_task import BaseTask +from bubogpt.tasks.image_text_pretrain import ImageTextPretrainTask + + +def setup_task(cfg): + assert "task" in cfg.run_cfg, "Task name must be provided." + + task_name = cfg.run_cfg.task + task = registry.get_task_class(task_name).setup_task(cfg=cfg) + assert task is not None, "Task {} not properly registered.".format(task_name) + + return task + + +__all__ = [ + "BaseTask", + "ImageTextPretrainTask", +] diff --git a/bubogpt/tasks/base_task.py b/bubogpt/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..421cc24c2d7748495d4282aa294f64e4c4a248a3 --- /dev/null +++ b/bubogpt/tasks/base_task.py @@ -0,0 +1,305 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +import torch.distributed as dist +from bubogpt.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized +from bubogpt.common.logger import MetricLogger, SmoothedValue +from bubogpt.common.registry import registry +from bubogpt.datasets.data_utils import prepare_sample + + +def get_data_type(config, builder): + data_type = config.get("data_type", None) + if data_type is not None: + return data_type + + if isinstance(builder.data_type, str): + return builder.data_type + elif isinstance(builder.data_type, (tuple, list)): + return "_".join(builder.data_type) + else: + raise RuntimeError(f"Data type: {builder.data_type} not recognized!") + + +class BaseTask: + def __init__(self, **kwargs): + super().__init__() + + self.inst_id_key = "instance_id" + + @classmethod + def setup_task(cls, **kwargs): + return cls() + + def build_model(self, cfg): + model_config = cfg.model_cfg + + model_cls = registry.get_model_class(model_config.arch) + return model_cls.from_config(model_config) + + def build_datasets(self, cfg): + """ + Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. + Download dataset and annotations automatically if not exist. + + Args: + cfg (common.config.Config): _description_ + + Returns: + dict: Dictionary of torch.utils.data.Dataset objects by split. + """ + + datasets = dict() + + datasets_config = cfg.datasets_cfg + + assert len(datasets_config) > 0, "At least one dataset has to be specified." + + for name in datasets_config: + dataset_config = datasets_config[name] + + builder = registry.get_builder_class(name)(dataset_config) + dataset = builder.build_datasets() + + dataset['train'].name = name + if 'sample_ratio' in dataset_config: + dataset['train'].sample_ratio = dataset_config.sample_ratio + for k, v in dataset.items(): + v.data_type = get_data_type(datasets_config, builder) + if 'dtype_ratio' in dataset_config: + v.dtype_ratio = dataset_config['dtype_ratio'] + else: + v.dtype_ratio = None + + datasets[name] = dataset + + return datasets + + def train_step(self, model, samples): + loss = model(samples)["loss"] + return loss + + def valid_step(self, model, samples): + raise NotImplementedError + + def before_evaluation(self, model, dataset, **kwargs): + model.before_evaluation(dataset=dataset, task_type=type(self)) + + def after_evaluation(self, **kwargs): + pass + + def inference_step(self): + raise NotImplementedError + + def evaluation(self, model, data_loader, cuda_enabled=True): + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation" + # TODO make it configurable + print_freq = 10 + + results = [] + + for samples in metric_logger.log_every(data_loader, print_freq, header): + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + + eval_output = self.valid_step(model=model, samples=samples) + results.extend(eval_output) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return results + + def train_epoch( + self, + epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + iters_per_epoch=lr_scheduler.iters_per_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def train_iters( + self, + epoch, + start_iters, + iters_per_inner_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + cuda_enabled=False, + log_freq=50, + accum_grad_iters=1, + ): + return self._train_inner_loop( + epoch=epoch, + start_iters=start_iters, + iters_per_epoch=iters_per_inner_epoch, + model=model, + data_loader=data_loader, + optimizer=optimizer, + scaler=scaler, + lr_scheduler=lr_scheduler, + log_freq=log_freq, + cuda_enabled=cuda_enabled, + accum_grad_iters=accum_grad_iters, + ) + + def _train_inner_loop( + self, + epoch, + iters_per_epoch, + model, + data_loader, + optimizer, + lr_scheduler, + scaler=None, + start_iters=None, + log_freq=50, + cuda_enabled=False, + accum_grad_iters=1, + ): + """ + An inner training loop compatible with both epoch-based and iter-based training. + + When using epoch-based, training stops after one epoch; when using iter-based, + training stops after #iters_per_epoch iterations. + """ + use_amp = scaler is not None + + if not hasattr(data_loader, "__next__"): + # convert to iterator if not already + data_loader = iter(data_loader) + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) + + # if iter-based runner, schedule lr based on inner epoch. + logging.info( + "Start training epoch {}, {} iters per inner epoch.".format( + epoch, iters_per_epoch + ) + ) + header = "Train: data epoch: [{}]".format(epoch) + if start_iters is None: + # epoch-based runner + inner_epoch = epoch + else: + # In iter-based runner, we schedule the learning rate based on iterations. + inner_epoch = start_iters // iters_per_epoch + header = header + "; inner epoch [{}]".format(inner_epoch) + + for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): + # if using iter-based runner, we stop after iters_per_epoch iterations. + if i >= iters_per_epoch: + break + + samples = next(data_loader) + + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + samples.update( + { + "epoch": inner_epoch, + "num_iters_per_epoch": iters_per_epoch, + "iters": i, + } + ) + + lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) + + with torch.cuda.amp.autocast(enabled=use_amp): + loss = self.train_step(model=model, samples=samples) + + # after_train_step() + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # update gradients every accum_grad_iters iterations + if (i + 1) % accum_grad_iters == 0: + if use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # after train_epoch() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logging.info("Averaged stats: " + str(metric_logger.global_avg())) + return { + k: "{:.3f}".format(meter.global_avg) + for k, meter in metric_logger.meters.items() + } + + @staticmethod + def save_result(result, result_dir, filename, remove_duplicate=""): + import json + + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, get_rank()) + ) + final_result_file = os.path.join(result_dir, "%s.json" % filename) + + json.dump(result, open(result_file, "w")) + + if is_dist_avail_and_initialized(): + dist.barrier() + + if is_main_process(): + logging.warning("rank %d starts merging results." % get_rank()) + # combine results from all processes + result = [] + + for rank in range(get_world_size()): + result_file = os.path.join( + result_dir, "%s_rank%d.json" % (filename, rank) + ) + res = json.load(open(result_file, "r")) + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + json.dump(result, open(final_result_file, "w")) + print("result file saved to %s" % final_result_file) + + return final_result_file diff --git a/bubogpt/tasks/image_text_pretrain.py b/bubogpt/tasks/image_text_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..ac02e140a5af65a13f68451e084b124caa32cbc3 --- /dev/null +++ b/bubogpt/tasks/image_text_pretrain.py @@ -0,0 +1,18 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from bubogpt.common.registry import registry +from bubogpt.tasks.base_task import BaseTask + + +@registry.register_task("image_text_pretrain") +class ImageTextPretrainTask(BaseTask): + def __init__(self): + super().__init__() + + def evaluation(self, model, data_loader, cuda_enabled=True): + pass diff --git a/checkpoints/bubogpt_7b.pth b/checkpoints/bubogpt_7b.pth new file mode 100644 index 0000000000000000000000000000000000000000..0dbf239811b4b33022296f6f5d3134d81072f326 --- /dev/null +++ b/checkpoints/bubogpt_7b.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f226a9d71f093cdcae1a0e44c4228fe09321c102f3f610ce2a3b09a25ec48af8 +size 746968666 diff --git a/checkpoints/groundingdino_swint_ogc.pth b/checkpoints/groundingdino_swint_ogc.pth new file mode 100644 index 0000000000000000000000000000000000000000..5cdf6bcd10d491abf170a78eca4fcebf76aa791a --- /dev/null +++ b/checkpoints/groundingdino_swint_ogc.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799 +size 693997677 diff --git a/checkpoints/ram_swin_large_14m.pth b/checkpoints/ram_swin_large_14m.pth new file mode 100644 index 0000000000000000000000000000000000000000..a477cb695d2f8da77df0af6a77b687d3cf78fd4a --- /dev/null +++ b/checkpoints/ram_swin_large_14m.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15c729c793af28b9d107c69f85836a1356d76ea830d4714699fb62e55fcc08ed +size 5625634877 diff --git a/checkpoints/sam_vit_h_4b8939.pth b/checkpoints/sam_vit_h_4b8939.pth new file mode 100644 index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72 --- /dev/null +++ b/checkpoints/sam_vit_h_4b8939.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e +size 2564550879 diff --git a/eval_configs/GroundingDINO_SwinT_OGC.yaml b/eval_configs/GroundingDINO_SwinT_OGC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85b08736dd7fe98b5fe7ff18c333430c572e5980 --- /dev/null +++ b/eval_configs/GroundingDINO_SwinT_OGC.yaml @@ -0,0 +1,43 @@ +batch_size: 1 +modelname: "groundingdino" +backbone: "swin_T_224_1k" +position_embedding: "sine" +pe_temperatureH: 20 +pe_temperatureW: 20 +return_interm_indices: [1, 2, 3] +backbone_freeze_keywords: None +enc_layers: 6 +dec_layers: 6 +pre_norm: False +dim_feedforward: 2048 +hidden_dim: 256 +dropout: 0.0 +nheads: 8 +num_queries: 900 +query_dim: 4 +num_patterns: 0 +num_feature_levels: 4 +enc_n_points: 4 +dec_n_points: 4 +two_stage_type: "standard" +two_stage_bbox_embed_share: False +two_stage_class_embed_share: False +transformer_activation: "relu" +dec_pred_bbox_embed_share: True +dn_box_noise_scale: 1.0 +dn_label_noise_ratio: 0.5 +dn_label_coef: 1.0 +dn_bbox_coef: 1.0 +embed_init_tgt: True +dn_labelbook_size: 2000 +max_text_len: 256 +text_encoder_type: "bert-base-uncased" +use_text_enhancer: True +use_fusion_layer: True +use_checkpoint: True +use_transformer_ckpt: True +use_text_cross_attention: True +text_dropout: 0.0 +fusion_dropout: 0.0 +fusion_droppath: 0.1 +sub_sentence_present: True diff --git a/eval_configs/mmgpt4_eval.yaml b/eval_configs/mmgpt4_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ac3ae066b125fc63f126478e48bb851ba36cb9c --- /dev/null +++ b/eval_configs/mmgpt4_eval.yaml @@ -0,0 +1,46 @@ +model: + arch: mm_gpt4 + model_type: pretrain_vicuna + freeze_imagebind: True + freeze_qformer: False + max_txt_len: 160 + end_sym: "###" + low_resource: False + prompt_path: "prompts/alignment.txt" + prompt_template: '###Human: {} ###Assistant: ' + ckpt: ['checkpoints/bubogpt_7b.pth', + # 'checkpoints/mmgpt2_stage1_audio.pth', + # 'checkpoints/mmgpt2_stage2_mm_5k.pth', + ] + with_bind_head: False + use_blip_vision: True + joiner_cfg: + # NOTE: uncomment below to share qformer across modalities + # share_key: vision + vision: + feat_dim: 1408 + post_dims: [768,] + num_query_token: 32 + freeze_qformer: True + audio: + feat_dim: 768 + + +datasets: + default: # Double check + vis_processor: + eval: + name: "imagebind_vision_eval" + image_size: 224 + text_processor: + eval: + name: "imagebind_caption" + audio_processor: + eval: + name: "imagebind_audio_eval" + use_global: True + clip_duration: 5 + clips_per_video: 6 +run: + task: image_text_pretrain + evaluate: True diff --git a/eval_configs/mmgpt4_eval_13b.yaml b/eval_configs/mmgpt4_eval_13b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36a9a0c6756bdecacedc45028359a4863e73fef8 --- /dev/null +++ b/eval_configs/mmgpt4_eval_13b.yaml @@ -0,0 +1,51 @@ +model: + arch: mm_gpt4 + model_type: pretrain_vicuna + freeze_imagebind: True + freeze_qformer: False + max_txt_len: 160 + end_sym: "###" + low_resource: False + prompt_path: "prompts/alignment.txt" + prompt_template: '###Human: {} ###Assistant: ' + ckpt: [ + "bubogpt/output/mmgpt4_stage2_mm_blipvision_13b/20230701204/checkpoint_4.pth", + ] + with_bind_head: False + use_blip_vision: True + proj_model: "checkpoints/prerained_minigpt4_13b.pth" + llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/" + joiner_cfg: + # NOTE: uncomment below to share qformer across modalities + # share_key: vision + vision: + feat_dim: 1408 + post_dims: [768,] + num_query_token: 32 + freeze_qformer: True + audio: + feat_dim: 768 + + +datasets: + default: # Double check + vis_processor: + eval: + name: "imagebind_vision_eval" + image_size: 224 + text_processor: + eval: + name: "imagebind_caption" + audio_processor: + eval: + name: "imagebind_audio_eval" + # d2c18 + # clip_duration: 2 + # clips_per_video: 18 + # d5c6 + use_global: True + clip_duration: 5 + clips_per_video: 6 +run: + task: image_text_pretrain + evaluate: True diff --git a/eval_scripts/__init__.py b/eval_scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/eval_scripts/conversation.py b/eval_scripts/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..3118b813475b5913e2b62ab97115b0ba9ebde07a --- /dev/null +++ b/eval_scripts/conversation.py @@ -0,0 +1,216 @@ +import dataclasses +from copy import deepcopy +from types import SimpleNamespace +from typing import List, Union, Dict, Tuple +import numpy as np + +import torch +from PIL import Image +from torch import nn, Tensor +from transformers import StoppingCriteria, StoppingCriteriaList + +from eval_scripts.eval_utils import load_image, load_audio +from imagebind.models.image_bind import ModalityType +from bubogpt import BaseProcessor + +Roles = SimpleNamespace( + HUMAN="Human", + ASSISTANT="Assistant" +) + + +class Message: + def __init__(self, role: str, content: Union[str, None]): + self.role = role + self.content = content + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + messages: List[Message] + sep: str = "###" + + def get_prompt(self): + ret = self.system + self.sep + for message in self.messages: + if message.content: + ret += message.role + ": " + message.content + self.sep + else: + ret += message.role + ":" + return ret + + def append_message(self, role, content): + self.messages.append(Message(role, content)) + + def copy(self): + return Conversation( + system=self.system, + messages=deepcopy(self.messages), + sep=self.sep) + + def dict(self): + return { + "system": self.system, + "messages": [(msg.role, msg.content) for msg in self.messages], + "sep": self.sep + } + + +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +CONV_X = Conversation( + # system="Give the following ..." + # "You will be able to ... once I provide it to you. Please answer my questions.", + system="Give the following image: ImageContent or audio: . " + "You will be able to see the image/audio once I provide it to you. Please answer my questions.", + messages=[], + sep="###", +) + + +# TODO: If needed and possible, rewrite this file and re-organize the definition of components. + +class DummyChat: + def __init__(self, dummy_answer=None, *args, **kwargs): + self.dummy_answer = dummy_answer + + def ask(self, text, conversation): + conversation.append_message(Roles.HUMAN, text) + + def answer(self, *args, **kwargs): + if self.dummy_answer is not None: + return self.dummy_answer, None + else: + print(kwargs) + return kwargs["conversation"].messages[-1].content, None + + def upload_img(self, *args, **kwargs): + pass + + def upload_aud(self, *args, **kwargs): + pass + + + +class Chat: + def __init__(self, + model: nn.Module, + processors: Dict[str, BaseProcessor], + device: str = 'cuda:0' + ): + self.device = device + self.model = model + self.processors = processors + stop_words_ids = [torch.tensor([835]).to(self.device), + torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + self.just_uploaded = False + + def ask(self, text, conversation): + # NOTE: the hard code for postfix is removed. + # end_token = '' + # if len(conversation.messages) > 0 and conversation.messages[-1].role == Roles.HUMAN \ + # and conversation.messages[-1].content[-len(end_token):] == end_token: + if self.just_uploaded: + conversation.messages[-1].content = ' '.join([conversation.messages[-1].content, text]) + self.just_uploaded = False + else: + conversation.append_message(Roles.HUMAN, text) + + def answer(self, conversation, emb_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + # Generate an answer written by LLaMA + conversation.append_message(Roles.ASSISTANT, None) + embs = self.get_context_emb(conversation, emb_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. ' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + output_token = outputs[0] + if output_token[0] == 0: # the model might output a unknown token at the beginning. remove it + output_token = output_token[1:] + if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + output_text = output_text.split('###')[0] # remove the stop sign '###' + output_text = output_text.split('Assistant:')[-1].strip() + conversation.messages[-1].content = output_text + return output_text, output_token.cpu().numpy() + + def upload_img(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]): + # Upload Image, Encode Image and Create a new message from human. + image = load_image(image, self.processors[ModalityType.VISION]).to(self.device) + if hasattr(self.model, "encode_img"): + # To compitable with minigpt4 + image_emb, _ = self.model.encode_img(image) + else: + all_embeddings = self.model.encode_inputs({ModalityType.VISION: image}) + image_emb = all_embeddings[ModalityType.VISION] + emb_list.append(image_emb) + conversation.append_message(Roles.HUMAN, "") + self.just_uploaded = True + + # def upload_img_mini(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]): + # # Upload Image, Encode Image and Create a new message from human. + # image = load_image(image, self.processors[ModalityType.VISION]).to(self.device) + # image_emb, _ = self.model.encode_img(image) + # emb_list.append(image_emb) + # conversation.append_message(Roles.HUMAN, "") + + def upload_aud(self, audio: Union[str, Tuple[int, np.ndarray]], conversation: Conversation, emb_list: List[Tensor]): + # Upload Audio, Encode Audio and Create a new message from human. + audio = load_audio(audio, self.processors[ModalityType.AUDIO]).to(self.device) + audio = audio.float() + all_embeddings = self.model.encode_inputs({ModalityType.AUDIO: audio}) + audio_emb = all_embeddings[ModalityType.AUDIO] + emb_list.append(audio_emb) + conversation.append_message(Roles.HUMAN, "") + self.just_uploaded = True + + def get_context_emb(self, conversation: Conversation, emb_list: List[Tensor]): + # Insert the embeddings into the prompts and queries. + # NOTE: Assume the placeholders have been aligned to the embeddings! + prompt = conversation.get_prompt() + print(prompt) + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(emb_list) + 1, "Unmatched numbers of placeholders and embeddings." + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], emb_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs \ No newline at end of file diff --git a/eval_scripts/eval_utils.py b/eval_scripts/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82ead0cd221a5ce71fbff23977d2a6313cbd4db2 --- /dev/null +++ b/eval_scripts/eval_utils.py @@ -0,0 +1,36 @@ +import torch +import torchaudio +from PIL import Image +import numpy as np + + +def load_image(image, image_processor): + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') + image = image_processor(raw_image).unsqueeze(0) + elif isinstance(image, Image.Image): + raw_image = image + image = image_processor(raw_image).unsqueeze(0) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + return image + + +def load_audio(audio, audio_processor): + if isinstance(audio, str): # is a audio path + raw_audio = torchaudio.load(audio) + audio = audio_processor(raw_audio) + elif isinstance(audio, tuple): + sample_rate, raw_waveform = audio + waveform = raw_waveform / np.iinfo(raw_waveform.dtype).max + if waveform.ndim == 1: + waveform = torch.from_numpy(waveform[None, :]) + elif waveform.ndim == 2: + waveform = torch.from_numpy(waveform).mean(1).unsqueeze(0) + else: + raise NotImplementedError # "No such data!" + audio = audio_processor((waveform, sample_rate)) + else: + raise NotImplementedError + return audio.unsqueeze(0) diff --git a/grounding_model.py b/grounding_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b09c4d79aa21011a6bb482ec924c85861b88ad25 --- /dev/null +++ b/grounding_model.py @@ -0,0 +1,386 @@ +import PIL +import numpy as np +import torch +import torch.nn as nn +import torchvision +from yacs.config import CfgNode as CN +from PIL import ImageDraw +from segment_anything import build_sam, SamPredictor +from segment_anything.utils.amg import remove_small_regions +from PIL import ImageDraw, ImageFont + +import groundingdino.util.transforms as T +from constants.constant import DARKER_COLOR_MAP, LIGHTER_COLOR_MAP, COLORS +from groundingdino import build_groundingdino +from groundingdino.util.predict import predict +from groundingdino.util.utils import clean_state_dict + + +def load_groundingdino_model(model_config_path, model_checkpoint_path): + args = CN.load_cfg(open(model_config_path, "r")) + model = build_groundingdino(args) + checkpoint = torch.load(model_checkpoint_path, map_location="cpu") + load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) + print('loading GroundingDINO:', load_res) + _ = model.eval() + return model + + +class GroundingModule(nn.Module): + def __init__(self, device='cpu'): + super().__init__() + self.device = device + sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth" + groundingdino_checkpoint = "./checkpoints/groundingdino_swint_ogc.pth" + groundingdino_config_file = "./eval_configs/GroundingDINO_SwinT_OGC.yaml" + + self.grounding_model = load_groundingdino_model(groundingdino_config_file, + groundingdino_checkpoint).to(device) + self.grounding_model.eval() + + sam = build_sam(checkpoint=sam_checkpoint).to(device) + sam.eval() + self.sam_predictor = SamPredictor(sam) + + @torch.no_grad() + def prompt2mask(self, original_image, prompt, state, box_threshold=0.35, text_threshold=0.25, num_boxes=10): + def image_transform_grounding(init_image): + transform = T.Compose([ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + image, _ = transform(init_image, None) # 3, h, w + return init_image, image + + image_np = np.array(original_image, dtype=np.uint8) + prompt = prompt.lower() + prompt = prompt.strip() + if not prompt.endswith("."): + prompt = prompt + "." + _, image_tensor = image_transform_grounding(original_image) + print('==> Box grounding with "{}"...'.format(prompt)) + with torch.cuda.amp.autocast(enabled=True): + boxes, logits, phrases = predict(self.grounding_model, + image_tensor, prompt, box_threshold, text_threshold, device=self.device) + print(phrases) + # from PIL import Image, ImageDraw, ImageFont + H, W = original_image.size[1], original_image.size[0] + + draw_img = original_image.copy() + draw = ImageDraw.Draw(draw_img) + color_boxes = [] + color_masks = [] + local_results = [original_image.copy() for _ in range(len(state['entity']))] + + local2entity = {} + for obj_ind, (box, label) in enumerate(zip(boxes, phrases)): + # from 0..1 to 0..W, 0..H + box = box * torch.Tensor([W, H, W, H]) + # from xywh to xyxy + box[:2] -= box[2:] / 2 + box[2:] += box[:2] + # random color + for i, s in enumerate(state['entity']): + # print(label.lower(), i[0].lower(), label.lower() == i[0].lower()) + if label.lower() == s[0].lower(): + local2entity[obj_ind] = i + break + + if obj_ind not in local2entity: + print('Color not found', label) + color = "grey" # In grey mode. + # tuple(np.random.randint(0, 255, size=3).tolist()) + else: + for i, s in enumerate(state['entity']): + # print(label.lower(), i[0].lower(), label.lower() == i[0].lower()) + if label.lower() == s[0].lower(): + local2entity[obj_ind] = i + break + + if obj_ind not in local2entity: + print('Color not found', label) + color = tuple(np.random.randint(0, 255, size=3).tolist()) + else: + color = state['entity'][local2entity[obj_ind]][3] + color_boxes.append(color) + print(color_boxes) + # draw + x0, y0, x1, y1 = box + x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) + + draw.rectangle([x0, y0, x1, y1], outline=color, width=10) + # font = ImageFont.load_default() + font = ImageFont.truetype('InputSans-Regular.ttf', int(H / 512.0 * 30)) + if hasattr(font, "getbbox"): + bbox = draw.textbbox((x0, y0), str(label), font) + else: + w, h = draw.textsize(str(label), font) + bbox = (x0, y0, w + x0, y0 + h) + draw.rectangle(bbox, fill=color) + draw.text((x0, y0), str(label), fill="white", font=font) + + if obj_ind in local2entity: + local_draw = ImageDraw.Draw(local_results[local2entity[obj_ind]]) + local_draw.rectangle([x0, y0, x1, y1], outline=color, width=10) + local_draw.rectangle(bbox, fill=color) + local_draw.text((x0, y0), str(label), fill="white", font=font) + + if boxes.size(0) > 0: + print('==> Mask grounding...') + boxes = boxes * torch.Tensor([W, H, W, H]) + boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 + boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] + + self.sam_predictor.set_image(image_np) + + transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2]) + with torch.cuda.amp.autocast(enabled=True): + masks, _, _ = self.sam_predictor.predict_torch( + point_coords=None, + point_labels=None, + boxes=transformed_boxes.to(self.device), + multimask_output=False, + ) + + # remove small disconnected regions and holes + fine_masks = [] + for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] + fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) + masks = np.stack(fine_masks, axis=0)[:, np.newaxis] + masks = torch.from_numpy(masks) + + num_obj = min(len(logits), num_boxes) + mask_map = None + + full_img = None + for obj_ind in range(num_obj): + # box = boxes[obj_ind] + + m = masks[obj_ind][0] + + if full_img is None: + full_img = np.zeros((m.shape[0], m.shape[1], 3)) + mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16) + local_image = np.zeros((m.shape[0], m.shape[1], 3)) + + mask_map[m != 0] = obj_ind + 1 + # color_mask = np.random.random((1, 3)).tolist()[0] + color_mask = np.array(color_boxes[obj_ind]) / 255.0 + full_img[m != 0] = color_mask + local_image[m != 0] = color_mask + # if local_results[local2entity[obj_ind]] is not None: + # local_image[m == 0] = np.asarray(local_results[local2entity[obj_ind]])[m == 0] + local_image = (local_image * 255).astype(np.uint8) + local_image = PIL.Image.fromarray(local_image) + if local_results[local2entity[obj_ind]] is not None: + local_results[local2entity[obj_ind]] = PIL.Image.blend(local_results[local2entity[obj_ind]], + local_image, 0.5) + full_img = (full_img * 255).astype(np.uint8) + full_img = PIL.Image.fromarray(full_img) + draw_img = PIL.Image.blend(draw_img, full_img, 0.5) + + return draw_img, local_results + + # def draw_text(self, entity_state, entity, text): + # local_img = entity_state['grounding']['local'][entity]['image'].copy() + # H, W = local_img.width, local_img.height + # font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30)) + # + # for x0, y0 in entity_state['grounding']['local'][entity]['text_positions']: + # color = entity_state['grounding']['local'][entity]['color'] + # local_draw = ImageDraw.Draw(local_img) + # if hasattr(font, "getbbox"): + # bbox = local_draw.textbbox((x0, y0), str(text), font) + # else: + # w, h = local_draw.textsize(str(text), font) + # bbox = (x0, y0, w + x0, y0 + h) + # + # local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color]) + # local_draw.text((x0, y0), str(text), fill="white", font=font) + # return local_img + + def draw(self, original_image, entity_state, item=None): + original_image = original_image.copy() + W, H = original_image.width, original_image.height + font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30)) + local_image = np.zeros((H, W, 3)) + local_mask = np.zeros((H, W), dtype=bool) + + def draw_item(img, item): + nonlocal local_image, local_mask + entity = entity_state['match_state'][item] + ei = entity_state['grounding']['local'][entity] + color = ei['color'] + local_draw = ImageDraw.Draw(img) + for x0, y0, x1, y1 in ei['entity_positions']: + local_draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color], + width=int(min(H, W) / 512.0 * 10)) + for x0, y0 in ei['text_positions']: + if hasattr(font, "getbbox"): + bbox = local_draw.textbbox((x0, y0), str(item), font) + else: + w, h = local_draw.textsize(str(item), font) + bbox = (x0, y0, w + x0, y0 + h) + + local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color]) + local_draw.text((x0, y0), str(item), fill="white", font=font) + for m in ei['masks']: + local_image[m != 0] = np.array(LIGHTER_COLOR_MAP[color]) / 255.0 + local_mask = np.logical_or(local_mask, m) + # local_image = (local_image * 255).astype(np.uint8) + # local_image = PIL.Image.fromarray(local_image) + # img = PIL.Image.blend(img, local_image, 0.5) + return img + + if item is None: + for item in entity_state['match_state'].keys(): + original_image = draw_item(original_image, item) + else: + original_image = draw_item(original_image, item) + local_image[local_mask == 0] = (np.array(original_image) / 255.0)[local_mask == 0] + local_image = (local_image * 255).astype(np.uint8) + local_image = PIL.Image.fromarray(local_image) + + original_image = PIL.Image.blend(original_image, local_image, 0.5) + return original_image + + @torch.no_grad() + def prompt2mask2(self, original_image, prompt, state, box_threshold=0.25, + text_threshold=0.2, iou_threshold=0.5, num_boxes=10): + def image_transform_grounding(init_image): + transform = T.Compose([ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + image, _ = transform(init_image, None) # 3, h, w + return init_image, image + + image_np = np.array(original_image, dtype=np.uint8) + prompt = prompt.lower() + prompt = prompt.strip() + if not prompt.endswith("."): + prompt = prompt + "." + _, image_tensor = image_transform_grounding(original_image) + print('==> Box grounding with "{}"...'.format(prompt)) + with torch.cuda.amp.autocast(enabled=True): + boxes, logits, phrases = predict(self.grounding_model, + image_tensor, prompt, box_threshold, text_threshold, device=self.device) + print('==> Box grounding results {}...'.format(phrases)) + + # boxes_filt = boxes.cpu() + # # use NMS to handle overlapped boxes + # print(f"==> Before NMS: {boxes_filt.shape[0]} boxes") + # nms_idx = torchvision.ops.nms(boxes_filt, logits, iou_threshold).numpy().tolist() + # boxes_filt = boxes_filt[nms_idx] + # phrases = [phrases[idx] for idx in nms_idx] + # print(f"==> After NMS: {boxes_filt.shape[0]} boxes") + # boxes = boxes_filt + + # from PIL import Image, ImageDraw, ImageFont + H, W = original_image.size[1], original_image.size[0] + + draw_img = original_image.copy() + draw = ImageDraw.Draw(draw_img) + color_boxes = [] + color_masks = [] + + entity_dict = {} + for obj_ind, (box, label) in enumerate(zip(boxes, phrases)): + # from 0..1 to 0..W, 0..H + box = box * torch.Tensor([W, H, W, H]) + # from xywh to xyxy + box[:2] -= box[2:] / 2 + box[2:] += box[:2] + if label not in entity_dict: + entity_dict[label] = { + 'color': COLORS[len(entity_dict) % (len(COLORS) - 1)], + # 'image': original_image.copy(), + 'text_positions': [], + 'entity_positions': [], + 'masks': [] + } + color = entity_dict[label]['color'] + + color_boxes.append(DARKER_COLOR_MAP[color]) + color_masks.append(LIGHTER_COLOR_MAP[color]) + + # draw + x0, y0, x1, y1 = box + x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) + + draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color], width=10) + font = ImageFont.truetype('InputSans-Regular.ttf', int(min(H, W) / 512.0 * 30)) + if hasattr(font, "getbbox"): + bbox = draw.textbbox((x0, y0), str(label), font) + else: + w, h = draw.textsize(str(label), font) + bbox = (x0, y0, w + x0, y0 + h) + + draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color]) + draw.text((x0, y0), str(label), fill="white", font=font) + + # local_img = entity_dict[label]['image'] + # local_draw = ImageDraw.Draw(local_img) + # local_draw.rectangle([x0, y0, x1, y1], outline=DARKER_COLOR_MAP[color], width=10) + entity_dict[label]['text_positions'].append((x0, y0)) + entity_dict[label]['entity_positions'].append((x0, y0, x1, y1)) + # local_draw.rectangle(bbox, fill=DARKER_COLOR_MAP[color]) + # local_draw.text((x0, y0), str(label), fill="white", font=font) + + if boxes.size(0) > 0: + print('==> Mask grounding...') + boxes = boxes * torch.Tensor([W, H, W, H]) + boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 + boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] + + self.sam_predictor.set_image(image_np) + + transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes, + image_np.shape[:2]).to(self.device) + with torch.cuda.amp.autocast(enabled=True): + masks, _, _ = self.sam_predictor.predict_torch( + point_coords=None, + point_labels=None, + boxes=transformed_boxes.to(self.device), + multimask_output=False, + ) + + # remove small disconnected regions and holes + fine_masks = [] + for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] + fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) + masks = np.stack(fine_masks, axis=0)[:, np.newaxis] + masks = torch.from_numpy(masks) + + mask_map = None + + full_img = None + for obj_ind, (box, label) in enumerate(zip(boxes, phrases)): + + m = masks[obj_ind][0] + + if full_img is None: + full_img = np.zeros((m.shape[0], m.shape[1], 3)) + mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16) + # local_image = np.zeros((m.shape[0], m.shape[1], 3)) + + mask_map[m != 0] = obj_ind + 1 + color_mask = np.array(color_masks[obj_ind]) / 255.0 + + full_img[m != 0] = color_mask + + entity_dict[label]['masks'].append(m) + # local_image[m != 0] = color_mask + # local_image[m == 0] = (np.array(entity_dict[label]['image']) / 255.0)[m == 0] + # + # local_image = (local_image * 255).astype(np.uint8) + # local_image = PIL.Image.fromarray(local_image) + # entity_dict[label]['image'] = PIL.Image.blend(entity_dict[label]['image'], local_image, 0.5) + + full_img = (full_img * 255).astype(np.uint8) + full_img = PIL.Image.fromarray(full_img) + draw_img = PIL.Image.blend(draw_img, full_img, 0.5) + print('==> Entity list: {}'.format(list(entity_dict.keys()))) + return draw_img, entity_dict diff --git a/groundingdino/__init__.py b/groundingdino/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9529332047bbb4a8602d2d634b25e83376b05ab --- /dev/null +++ b/groundingdino/__init__.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +from groundingdino.models.GroundingDINO.groundingdino import build_groundingdino diff --git a/groundingdino/models/GroundingDINO/__init__.py b/groundingdino/models/GroundingDINO/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/groundingdino/models/GroundingDINO/backbone/__init__.py b/groundingdino/models/GroundingDINO/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76e4b272b479a26c63d120c818c140870cd8c287 --- /dev/null +++ b/groundingdino/models/GroundingDINO/backbone/__init__.py @@ -0,0 +1 @@ +from .backbone import build_backbone diff --git a/groundingdino/models/GroundingDINO/backbone/backbone.py b/groundingdino/models/GroundingDINO/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6940f1b46f16cd7c94bf79a7996897604292ca8c --- /dev/null +++ b/groundingdino/models/GroundingDINO/backbone/backbone.py @@ -0,0 +1,221 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" + +from typing import Dict, List + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from groundingdino.util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding +from .swin_transformer import build_swin_transformer + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + def __init__( + self, + backbone: nn.Module, + train_backbone: bool, + num_channels: int, + return_interm_indices: list, + ): + super().__init__() + for name, parameter in backbone.named_parameters(): + if ( + not train_backbone + or "layer2" not in name + and "layer3" not in name + and "layer4" not in name + ): + parameter.requires_grad_(False) + + return_layers = {} + for idx, layer_index in enumerate(return_interm_indices): + return_layers.update( + {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} + ) + + # if len: + # if use_stage1_feature: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + # else: + # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + # else: + # return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + # import ipdb; ipdb.set_trace() + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__( + self, + name: str, + train_backbone: bool, + dilation: bool, + return_interm_indices: list, + batch_norm=FrozenBatchNorm2d, + ): + if name in ["resnet18", "resnet34", "resnet50", "resnet101"]: + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=batch_norm, + ) + else: + raise NotImplementedError("Why you can get here with name {}".format(name)) + # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." + assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] + num_channels_all = [256, 512, 1024, 2048] + num_channels = num_channels_all[4 - len(return_interm_indices) :] + super().__init__(backbone, train_backbone, num_channels, return_interm_indices) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + """ + Useful args: + - backbone: backbone name + - lr_backbone: + - dilation + - return_interm_indices: available: [0,1,2,3], [1,2,3], [3] + - backbone_freeze_keywords: + - use_checkpoint: for swin only for now + + """ + position_embedding = build_position_encoding(args) + train_backbone = True + if not train_backbone: + raise ValueError("Please set lr_backbone > 0") + return_interm_indices = args.return_interm_indices + assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] + args.backbone_freeze_keywords + use_checkpoint = getattr(args, "use_checkpoint", False) + + if args.backbone in ["resnet50", "resnet101"]: + backbone = Backbone( + args.backbone, + train_backbone, + args.dilation, + return_interm_indices, + batch_norm=FrozenBatchNorm2d, + ) + bb_num_channels = backbone.num_channels + elif args.backbone in [ + "swin_T_224_1k", + "swin_B_224_22k", + "swin_B_384_22k", + "swin_L_224_22k", + "swin_L_384_22k", + ]: + pretrain_img_size = int(args.backbone.split("_")[-2]) + backbone = build_swin_transformer( + args.backbone, + pretrain_img_size=pretrain_img_size, + out_indices=tuple(return_interm_indices), + dilation=False, + use_checkpoint=use_checkpoint, + ) + + bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] + else: + raise NotImplementedError("Unknown backbone {}".format(args.backbone)) + + assert len(bb_num_channels) == len( + return_interm_indices + ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" + + model = Joiner(backbone, position_embedding) + model.num_channels = bb_num_channels + assert isinstance( + bb_num_channels, List + ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) + # import ipdb; ipdb.set_trace() + return model diff --git a/groundingdino/models/GroundingDINO/backbone/position_encoding.py b/groundingdino/models/GroundingDINO/backbone/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..eac7e896bbe85a670824bfe8ef487d0535d5bd99 --- /dev/null +++ b/groundingdino/models/GroundingDINO/backbone/position_encoding.py @@ -0,0 +1,186 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DINO +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copied from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + +from groundingdino.util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + # if os.environ.get("SHILONG_AMP", None) == '1': + # eps = 1e-4 + # else: + # eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingSineHW(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperatureH = temperatureH + self.temperatureW = temperatureW + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + # import ipdb; ipdb.set_trace() + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats) + pos_x = x_embed[:, :, :, None] / dim_tx + + dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + + # import ipdb; ipdb.set_trace() + + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSineHW( + N_steps, + temperatureH=args.pe_temperatureH, + temperatureW=args.pe_temperatureW, + normalize=True, + ) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c66194deb5dd370e797e57e2712f44303e568cc --- /dev/null +++ b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py @@ -0,0 +1,802 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DINO +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------- +# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from groundingdino.util.misc import NestedTensor + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + dilation (bool): if True, the output size if 16x downsample, ow 32x downsample. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + dilation=False, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.dilation = dilation + + # if use_checkpoint: + # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!") + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + # prepare downsample list + downsamplelist = [PatchMerging for i in range(self.num_layers)] + downsamplelist[-1] = None + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + if self.dilation: + downsamplelist[-2] = None + num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2 + for i_layer in range(self.num_layers): + layer = BasicLayer( + # dim=int(embed_dim * 2 ** i_layer), + dim=num_features[i_layer], + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + downsample=downsamplelist[i_layer], + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # def init_weights(self, pretrained=None): + # """Initialize the weights in backbone. + # Args: + # pretrained (str, optional): Path to pre-trained weights. + # Defaults to None. + # """ + + # def _init_weights(m): + # if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + # if isinstance(m, nn.Linear) and m.bias is not None: + # nn.init.constant_(m.bias, 0) + # elif isinstance(m, nn.LayerNorm): + # nn.init.constant_(m.bias, 0) + # nn.init.constant_(m.weight, 1.0) + + # if isinstance(pretrained, str): + # self.apply(_init_weights) + # logger = get_root_logger() + # load_checkpoint(self, pretrained, strict=False, logger=logger) + # elif pretrained is None: + # self.apply(_init_weights) + # else: + # raise TypeError('pretrained must be a str or None') + + def forward_raw(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + # import ipdb; ipdb.set_trace() + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + # in: + # torch.Size([2, 3, 1024, 1024]) + # outs: + # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \ + # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])] + return tuple(outs) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + # in: + # torch.Size([2, 3, 1024, 1024]) + # out: + # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \ + # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])] + + # collect for nesttensors + outs_dict = {} + for idx, out_i in enumerate(outs): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0] + outs_dict[idx] = NestedTensor(out_i, mask) + + return outs_dict + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swin_transformer(modelname, pretrain_img_size, **kw): + assert modelname in [ + "swin_T_224_1k", + "swin_B_224_22k", + "swin_B_384_22k", + "swin_L_224_22k", + "swin_L_384_22k", + ] + + model_para_dict = { + "swin_T_224_1k": dict( + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7 + ), + "swin_B_224_22k": dict( + embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7 + ), + "swin_B_384_22k": dict( + embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12 + ), + "swin_L_224_22k": dict( + embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7 + ), + "swin_L_384_22k": dict( + embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12 + ), + } + kw_cgf = model_para_dict[modelname] + kw_cgf.update(kw) + model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf) + return model + + +if __name__ == "__main__": + model = build_swin_transformer("swin_L_384_22k", 384, dilation=True) + x = torch.rand(2, 3, 1024, 1024) + y = model.forward_raw(x) + import ipdb + + ipdb.set_trace() + x = torch.rand(2, 3, 384, 384) + y = model.forward_raw(x) diff --git a/groundingdino/models/GroundingDINO/bertwarper.py b/groundingdino/models/GroundingDINO/bertwarper.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cf9779b270e1aead32845006f8b881fcba37ad --- /dev/null +++ b/groundingdino/models/GroundingDINO/bertwarper.py @@ -0,0 +1,273 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import Tensor, nn +from torchvision.ops.boxes import nms +from transformers import BertConfig, BertModel, BertPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions + + +class BertModelWarper(nn.Module): + def __init__(self, bert_model): + super().__init__() + # self.bert = bert_modelc + + self.config = bert_model.config + self.embeddings = bert_model.embeddings + self.encoder = bert_model.encoder + self.pooler = bert_model.pooler + + self.get_extended_attention_mask = bert_model.get_extended_attention_mask + self.invert_attention_mask = bert_model.invert_attention_mask + self.get_head_mask = bert_model.get_head_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': + # import ipdb; ipdb.set_trace() + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class TextEncoderShell(nn.Module): + def __init__(self, text_encoder): + super().__init__() + self.text_encoder = text_encoder + self.config = self.text_encoder.config + + def forward(self, **kw): + # feed into text encoder + return self.text_encoder(**kw) + + +def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer): + """Generate attention mask between each pair of special tokens + Args: + input_ids (torch.Tensor): input ids. Shape: [bs, num_token] + special_tokens_mask (list): special tokens mask. + Returns: + torch.Tensor: attention mask between each special tokens. + """ + input_ids = tokenized["input_ids"] + bs, num_token = input_ids.shape + # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens + special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool() + for special_token in special_tokens_list: + special_tokens_mask |= input_ids == special_token + + # idxs: each row is a list of indices of special tokens + idxs = torch.nonzero(special_tokens_mask) + + # generate attention mask and positional ids + attention_mask = ( + torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1) + ) + position_ids = torch.zeros((bs, num_token), device=input_ids.device) + previous_col = 0 + for i in range(idxs.shape[0]): + row, col = idxs[i] + if (col == 0) or (col == num_token - 1): + attention_mask[row, col, col] = True + position_ids[row, col] = 0 + else: + attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True + position_ids[row, previous_col + 1 : col + 1] = torch.arange( + 0, col - previous_col, device=input_ids.device + ) + + previous_col = col + + # # padding mask + # padding_mask = tokenized['attention_mask'] + # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool() + + return attention_mask, position_ids.to(torch.long) + + +def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer): + """Generate attention mask between each pair of special tokens + Args: + input_ids (torch.Tensor): input ids. Shape: [bs, num_token] + special_tokens_mask (list): special tokens mask. + Returns: + torch.Tensor: attention mask between each special tokens. + """ + input_ids = tokenized["input_ids"] + bs, num_token = input_ids.shape + # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens + special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool() + for special_token in special_tokens_list: + special_tokens_mask |= input_ids == special_token + + # idxs: each row is a list of indices of special tokens + idxs = torch.nonzero(special_tokens_mask) + + # generate attention mask and positional ids + attention_mask = ( + torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1) + ) + position_ids = torch.zeros((bs, num_token), device=input_ids.device) + cate_to_token_mask_list = [[] for _ in range(bs)] + previous_col = 0 + for i in range(idxs.shape[0]): + row, col = idxs[i] + if (col == 0) or (col == num_token - 1): + attention_mask[row, col, col] = True + position_ids[row, col] = 0 + else: + attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True + position_ids[row, previous_col + 1 : col + 1] = torch.arange( + 0, col - previous_col, device=input_ids.device + ) + c2t_maski = torch.zeros((num_token), device=input_ids.device).bool() + c2t_maski[previous_col + 1 : col] = True + cate_to_token_mask_list[row].append(c2t_maski) + previous_col = col + + cate_to_token_mask_list = [ + torch.stack(cate_to_token_mask_listi, dim=0) + for cate_to_token_mask_listi in cate_to_token_mask_list + ] + + # # padding mask + # padding_mask = tokenized['attention_mask'] + # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool() + + return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h new file mode 100644 index 0000000000000000000000000000000000000000..c7408eba007b424194618baa63726657e36875e3 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h @@ -0,0 +1,64 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "ms_deform_attn_cuda.h" +#endif + +namespace groundingdino { + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +} // namespace groundingdino \ No newline at end of file diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..551243fdadfd1682b5dc6628623b67a79b3f6c74 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp @@ -0,0 +1,43 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + +namespace groundingdino { + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +} // namespace groundingdino diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..b2b88e8c46f19b6db0933163e57ccdb51180f517 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h @@ -0,0 +1,35 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +namespace groundingdino { + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + +} // namespace groundingdino diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..d04fae8a9a45c11e4e74f3035e94762796da4096 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu @@ -0,0 +1,156 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + +namespace groundingdino { + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} + +} // namespace groundingdino \ No newline at end of file diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..ad1311a78f61303616504eb991aaa9c4a93d9948 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +namespace groundingdino { + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + +} // namespace groundingdino \ No newline at end of file diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6bc2acb7aea0eab2e9e91e769a16861e1652c284 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/groundingdino/models/GroundingDINO/csrc/cuda_version.cu b/groundingdino/models/GroundingDINO/csrc/cuda_version.cu new file mode 100644 index 0000000000000000000000000000000000000000..64569e34ffb250964de27e33e7a53f3822270b9e --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/cuda_version.cu @@ -0,0 +1,7 @@ +#include + +namespace groundingdino { +int get_cudart_version() { + return CUDART_VERSION; +} +} // namespace groundingdino diff --git a/groundingdino/models/GroundingDINO/csrc/vision.cpp b/groundingdino/models/GroundingDINO/csrc/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f2c50c82909bbd5492c163d634af77a3ba1781 --- /dev/null +++ b/groundingdino/models/GroundingDINO/csrc/vision.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +#include "MsDeformAttn/ms_deform_attn.h" + +namespace groundingdino { + +#ifdef WITH_CUDA +extern int get_cudart_version(); +#endif + +std::string get_cuda_version() { +#ifdef WITH_CUDA + std::ostringstream oss; + + // copied from + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 + auto printCudaStyleVersion = [&](int v) { + oss << (v / 1000) << "." << (v / 10 % 100); + if (v % 10 != 0) { + oss << "." << (v % 10); + } + }; + printCudaStyleVersion(get_cudart_version()); + return oss.str(); +#else + return std::string("not available"); +#endif +} + +// similar to +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp +std::string get_compiler_version() { + std::ostringstream ss; +#if defined(__GNUC__) +#ifndef __clang__ + { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } +#endif +#endif + +#if defined(__clang_major__) + { + ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." + << __clang_patchlevel__; + } +#endif + +#if defined(_MSC_VER) + { ss << "MSVC " << _MSC_FULL_VER; } +#endif + return ss.str(); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} + +} // namespace groundingdino \ No newline at end of file diff --git a/groundingdino/models/GroundingDINO/fuse_modules.py b/groundingdino/models/GroundingDINO/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2753b3ddee43c7a9fe28d1824db5d786e7e1ad59 --- /dev/null +++ b/groundingdino/models/GroundingDINO/fuse_modules.py @@ -0,0 +1,297 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath + + +class FeatureResizer(nn.Module): + """ + This class takes as input a set of embeddings of dimension C1 and outputs a set of + embedding of dimension C2, after a linear transformation, dropout and normalization (LN). + """ + + def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): + super().__init__() + self.do_ln = do_ln + # Object feature encoding + self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) + self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, encoder_features): + x = self.fc(encoder_features) + if self.do_ln: + x = self.layer_norm(x) + output = self.dropout(x) + return output + + +def l1norm(X, dim, eps=1e-8): + """L1-normalize columns of X""" + norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps + X = torch.div(X, norm) + return X + + +def l2norm(X, dim, eps=1e-8): + """L2-normalize columns of X""" + norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps + X = torch.div(X, norm) + return X + + +def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8): + """ + query: (n_context, queryL, d) + context: (n_context, sourceL, d) + """ + batch_size_q, queryL = query.size(0), query.size(1) + batch_size, sourceL = context.size(0), context.size(1) + + # Get attention + # --> (batch, d, queryL) + queryT = torch.transpose(query, 1, 2) + + # (batch, sourceL, d)(batch, d, queryL) + # --> (batch, sourceL, queryL) + attn = torch.bmm(context, queryT) + if raw_feature_norm == "softmax": + # --> (batch*sourceL, queryL) + attn = attn.view(batch_size * sourceL, queryL) + attn = nn.Softmax()(attn) + # --> (batch, sourceL, queryL) + attn = attn.view(batch_size, sourceL, queryL) + elif raw_feature_norm == "l2norm": + attn = l2norm(attn, 2) + elif raw_feature_norm == "clipped_l2norm": + attn = nn.LeakyReLU(0.1)(attn) + attn = l2norm(attn, 2) + else: + raise ValueError("unknown first norm type:", raw_feature_norm) + # --> (batch, queryL, sourceL) + attn = torch.transpose(attn, 1, 2).contiguous() + # --> (batch*queryL, sourceL) + attn = attn.view(batch_size * queryL, sourceL) + attn = nn.Softmax()(attn * smooth) + # --> (batch, queryL, sourceL) + attn = attn.view(batch_size, queryL, sourceL) + # --> (batch, sourceL, queryL) + attnT = torch.transpose(attn, 1, 2).contiguous() + + # --> (batch, d, sourceL) + contextT = torch.transpose(context, 1, 2) + # (batch x d x sourceL)(batch x sourceL x queryL) + # --> (batch, d, queryL) + weightedContext = torch.bmm(contextT, attnT) + # --> (batch, queryL, d) + weightedContext = torch.transpose(weightedContext, 1, 2) + + return weightedContext, attnT + + +class BiMultiHeadAttention(nn.Module): + def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None): + super(BiMultiHeadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.v_dim = v_dim + self.l_dim = l_dim + + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + self.scale = self.head_dim ** (-0.5) + self.dropout = dropout + + self.v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.l_proj = nn.Linear(self.l_dim, self.embed_dim) + self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) + + self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) + self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) + + self.stable_softmax_2d = True + self.clamp_min_for_underflow = True + self.clamp_max_for_overflow = True + + self._reset_parameters() + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _reset_parameters(self): + nn.init.xavier_uniform_(self.v_proj.weight) + self.v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.l_proj.weight) + self.l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_v_proj.weight) + self.values_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_l_proj.weight) + self.values_l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_v_proj.weight) + self.out_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_l_proj.weight) + self.out_l_proj.bias.data.fill_(0) + + def forward(self, v, l, attention_mask_v=None, attention_mask_l=None): + """_summary_ + + Args: + v (_type_): bs, n_img, dim + l (_type_): bs, n_text, dim + attention_mask_v (_type_, optional): _description_. bs, n_img + attention_mask_l (_type_, optional): _description_. bs, n_text + + Returns: + _type_: _description_ + """ + # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': + # import ipdb; ipdb.set_trace() + bsz, tgt_len, _ = v.size() + + query_states = self.v_proj(v) * self.scale + key_states = self._shape(self.l_proj(l), -1, bsz) + value_v_states = self._shape(self.values_v_proj(v), -1, bsz) + value_l_states = self._shape(self.values_l_proj(l), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_v_states = value_v_states.view(*proj_shape) + value_l_states = value_l_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if self.stable_softmax_2d: + attn_weights = attn_weights - attn_weights.max() + + if self.clamp_min_for_underflow: + attn_weights = torch.clamp( + attn_weights, min=-50000 + ) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attn_weights = torch.clamp( + attn_weights, max=50000 + ) # Do not increase 50000, data type half has quite limited range + + attn_weights_T = attn_weights.transpose(1, 2) + attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0] + if self.clamp_min_for_underflow: + attn_weights_l = torch.clamp( + attn_weights_l, min=-50000 + ) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attn_weights_l = torch.clamp( + attn_weights_l, max=50000 + ) # Do not increase 50000, data type half has quite limited range + + # mask vison for language + if attention_mask_v is not None: + attention_mask_v = ( + attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) + ) + attn_weights_l.masked_fill_(attention_mask_v, float("-inf")) + + attn_weights_l = attn_weights_l.softmax(dim=-1) + + # mask language for vision + if attention_mask_l is not None: + attention_mask_l = ( + attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) + ) + attn_weights.masked_fill_(attention_mask_l, float("-inf")) + attn_weights_v = attn_weights.softmax(dim=-1) + + attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training) + attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training) + + attn_output_v = torch.bmm(attn_probs_v, value_l_states) + attn_output_l = torch.bmm(attn_probs_l, value_v_states) + + if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}" + ) + + if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): + raise ValueError( + f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}" + ) + + attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output_v = attn_output_v.transpose(1, 2) + attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) + + attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) + attn_output_l = attn_output_l.transpose(1, 2) + attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) + + attn_output_v = self.out_v_proj(attn_output_v) + attn_output_l = self.out_l_proj(attn_output_l) + + return attn_output_v, attn_output_l + + +# Bi-Direction MHA (text->image, image->text) +class BiAttentionBlock(nn.Module): + def __init__( + self, + v_dim, + l_dim, + embed_dim, + num_heads, + dropout=0.1, + drop_path=0.0, + init_values=1e-4, + cfg=None, + ): + """ + Inputs: + embed_dim - Dimensionality of input and attention feature vectors + hidden_dim - Dimensionality of hidden layer in feed-forward network + (usually 2-4x larger than embed_dim) + num_heads - Number of heads to use in the Multi-Head Attention block + dropout - Amount of dropout to apply in the feed-forward network + """ + super(BiAttentionBlock, self).__init__() + + # pre layer norm + self.layer_norm_v = nn.LayerNorm(v_dim) + self.layer_norm_l = nn.LayerNorm(l_dim) + self.attn = BiMultiHeadAttention( + v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout + ) + + # add layer scale for training stability + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True) + self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True) + + def forward(self, v, l, attention_mask_v=None, attention_mask_l=None): + v = self.layer_norm_v(v) + l = self.layer_norm_l(l) + delta_v, delta_l = self.attn( + v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l + ) + # v, l = v + delta_v, l + delta_l + v = v + self.drop_path(self.gamma_v * delta_v) + l = l + self.drop_path(self.gamma_l * delta_l) + return v, l + + # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None) diff --git a/groundingdino/models/GroundingDINO/groundingdino.py b/groundingdino/models/GroundingDINO/groundingdino.py new file mode 100644 index 0000000000000000000000000000000000000000..73048bf8e5d8ab155ed32ddc9f35237f479caa54 --- /dev/null +++ b/groundingdino/models/GroundingDINO/groundingdino.py @@ -0,0 +1,385 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR model and criterion classes. +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +import copy +from typing import List + +import torch +import torch.nn.functional as F +from torch import nn + +from groundingdino.util import get_tokenlizer +from groundingdino.util.misc import ( + NestedTensor, + inverse_sigmoid, + nested_tensor_from_tensor_list, +) + +from groundingdino.models.GroundingDINO.backbone import build_backbone +from groundingdino.models.GroundingDINO.bertwarper import ( + BertModelWarper, + generate_masks_with_special_tokens_and_transfer_map, +) +from groundingdino.models.GroundingDINO.transformer import build_transformer +from groundingdino.models.GroundingDINO.utils import MLP, ContrastiveEmbed + + +class GroundingDINO(nn.Module): + """This is the Cross-Attention Detector module that performs object detection""" + + def __init__( + self, + backbone, + transformer, + num_queries, + aux_loss=False, + iter_update=False, + query_dim=2, + num_feature_levels=1, + nheads=8, + # two stage + two_stage_type="no", # ['no', 'standard'] + dec_pred_bbox_embed_share=True, + two_stage_class_embed_share=True, + two_stage_bbox_embed_share=True, + num_patterns=0, + dn_number=100, + dn_box_noise_scale=0.4, + dn_label_noise_ratio=0.5, + dn_labelbook_size=100, + text_encoder_type="bert-base-uncased", + sub_sentence_present=True, + max_text_len=256, + ): + """Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + self.hidden_dim = hidden_dim = transformer.d_model + self.num_feature_levels = num_feature_levels + self.nheads = nheads + self.max_text_len = 256 + self.sub_sentence_present = sub_sentence_present + + # setting query dim + self.query_dim = query_dim + assert query_dim == 4 + + # for dn training + self.num_patterns = num_patterns + self.dn_number = dn_number + self.dn_box_noise_scale = dn_box_noise_scale + self.dn_label_noise_ratio = dn_label_noise_ratio + self.dn_labelbook_size = dn_labelbook_size + + # bert + # print("Text Encoder Type is ", text_encoder_type) + self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type) + self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type) + self.bert.pooler.dense.weight.requires_grad_(False) + self.bert.pooler.dense.bias.requires_grad_(False) + self.bert = BertModelWarper(bert_model=self.bert) + + self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True) + nn.init.constant_(self.feat_map.bias.data, 0) + nn.init.xavier_uniform_(self.feat_map.weight.data) + # freeze + + # special tokens + self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"]) + + # prepare input projection layers + if num_feature_levels > 1: + num_backbone_outs = len(backbone.num_channels) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!" + self.input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ] + ) + + self.backbone = backbone + self.aux_loss = aux_loss + self.box_pred_damping = box_pred_damping = None + + self.iter_update = iter_update + assert iter_update, "Why not iter_update?" + + # prepare pred layers + self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share + # prepare class & box embed + _class_embed = ContrastiveEmbed() + + _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) + + if dec_pred_bbox_embed_share: + box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)] + else: + box_embed_layerlist = [ + copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers) + ] + class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)] + self.bbox_embed = nn.ModuleList(box_embed_layerlist) + self.class_embed = nn.ModuleList(class_embed_layerlist) + self.transformer.decoder.bbox_embed = self.bbox_embed + self.transformer.decoder.class_embed = self.class_embed + + # two stage + self.two_stage_type = two_stage_type + assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( + two_stage_type + ) + if two_stage_type != "no": + if two_stage_bbox_embed_share: + assert dec_pred_bbox_embed_share + self.transformer.enc_out_bbox_embed = _bbox_embed + else: + self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) + + if two_stage_class_embed_share: + assert dec_pred_bbox_embed_share + self.transformer.enc_out_class_embed = _class_embed + else: + self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed) + + self.refpoint_embed = None + + self._reset_parameters() + + def _reset_parameters(self): + # init input_proj + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + def init_ref_points(self, use_num_queries): + self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim) + + def forward(self, samples: NestedTensor, targets: List = None, **kw): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x num_classes] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, width, height). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if targets is None: + captions = kw["captions"] + else: + captions = [t["caption"] for t in targets] + len(captions) + + # encoder texts + tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to( + samples.device + ) + ( + text_self_attention_masks, + position_ids, + cate_to_token_mask_list, + ) = generate_masks_with_special_tokens_and_transfer_map( + tokenized, self.specical_tokens, self.tokenizer + ) + + if text_self_attention_masks.shape[1] > self.max_text_len: + text_self_attention_masks = text_self_attention_masks[ + :, : self.max_text_len, : self.max_text_len + ] + position_ids = position_ids[:, : self.max_text_len] + tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] + tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len] + tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len] + + # extract text embeddings + if self.sub_sentence_present: + tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"} + tokenized_for_encoder["attention_mask"] = text_self_attention_masks + tokenized_for_encoder["position_ids"] = position_ids + else: + # import ipdb; ipdb.set_trace() + tokenized_for_encoder = tokenized + + bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768 + + encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model + text_token_mask = tokenized.attention_mask.bool() # bs, 195 + # text_token_mask: True for nomask, False for mask + # text_self_attention_masks: True for nomask, False for mask + + if encoded_text.shape[1] > self.max_text_len: + encoded_text = encoded_text[:, : self.max_text_len, :] + text_token_mask = text_token_mask[:, : self.max_text_len] + position_ids = position_ids[:, : self.max_text_len] + text_self_attention_masks = text_self_attention_masks[ + :, : self.max_text_len, : self.max_text_len + ] + + text_dict = { + "encoded_text": encoded_text, # bs, 195, d_model + "text_token_mask": text_token_mask, # bs, 195 + "position_ids": position_ids, # bs, 195 + "text_self_attention_masks": text_self_attention_masks, # bs, 195,195 + } + + # import ipdb; ipdb.set_trace() + + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, poss = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + poss.append(pos_l) + + input_query_bbox = input_query_label = attn_mask = dn_meta = None + hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( + srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict + ) + + # deformable-detr-like anchor update + outputs_coord_list = [] + for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate( + zip(reference[:-1], self.bbox_embed, hs) + ): + layer_delta_unsig = layer_bbox_embed(layer_hs) + layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) + layer_outputs_unsig = layer_outputs_unsig.sigmoid() + outputs_coord_list.append(layer_outputs_unsig) + outputs_coord_list = torch.stack(outputs_coord_list) + + # output + outputs_class = torch.stack( + [ + layer_cls_embed(layer_hs, text_dict) + for layer_cls_embed, layer_hs in zip(self.class_embed, hs) + ] + ) + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]} + + # # for intermediate outputs + # if self.aux_loss: + # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list) + + # # for encoder output + # if hs_enc is not None: + # # prepare intermediate outputs + # interm_coord = ref_enc[-1] + # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict) + # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord} + # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal} + + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + + +def build_groundingdino(args): + + backbone = build_backbone(args) + + transformer = build_transformer(args) + + dn_labelbook_size = args.dn_labelbook_size + dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share + sub_sentence_present = args.sub_sentence_present + + model = GroundingDINO( + backbone, + transformer, + num_queries=args.num_queries, + aux_loss=True, + iter_update=True, + query_dim=4, + num_feature_levels=args.num_feature_levels, + nheads=args.nheads, + dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, + two_stage_type=args.two_stage_type, + two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, + two_stage_class_embed_share=args.two_stage_class_embed_share, + num_patterns=args.num_patterns, + dn_number=0, + dn_box_noise_scale=args.dn_box_noise_scale, + dn_label_noise_ratio=args.dn_label_noise_ratio, + dn_labelbook_size=dn_labelbook_size, + text_encoder_type=args.text_encoder_type, + sub_sentence_present=sub_sentence_present, + max_text_len=args.max_text_len, + ) + + return model diff --git a/groundingdino/models/GroundingDINO/ms_deform_attn.py b/groundingdino/models/GroundingDINO/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..828945afa48593e9b756d90419ba995bb33d20de --- /dev/null +++ b/groundingdino/models/GroundingDINO/ms_deform_attn.py @@ -0,0 +1,417 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from: +# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py +# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py +# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py +# ------------------------------------------------------------------------------------------------ + +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.init import constant_, xavier_uniform_ + +try: + # from groundingdino import _C + from mmcv.utils import ext_loader + + _C = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) +except: + warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!") + + +# helpers +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MultiScaleDeformableAttnFunction(Function): + @staticmethod + def forward( + ctx, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + ctx.im2col_step = im2col_step + output = _C.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step, + ) + ctx.save_for_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def multi_scale_deformable_attn_pytorch( + value: torch.Tensor, + value_spatial_shapes: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, +) -> torch.Tensor: + + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = ( + value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) + ) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(bs, num_heads * embed_dims, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class MultiScaleDeformableAttention(nn.Module): + """Multi-Scale Deformable Attention Module used in Deformable-DETR + + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + `_. + + Args: + embed_dim (int): The embedding dimension of Attention. Default: 256. + num_heads (int): The number of attention heads. Default: 8. + num_levels (int): The number of feature map used in Attention. Default: 4. + num_points (int): The number of sampling points for each query + in each head. Default: 4. + img2col_steps (int): The step used in image_to_column. Defualt: 64. + dropout (float): Dropout layer used in output. Default: 0.1. + batch_first (bool): if ``True``, then the input and output tensor will be + provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)` + """ + + def __init__( + self, + embed_dim: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 4, + img2col_step: int = 64, + batch_first: bool = False, + ): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError( + "embed_dim must be divisible by num_heads, but got {} and {}".format( + embed_dim, num_heads + ) + ) + head_dim = embed_dim // num_heads + + self.batch_first = batch_first + + if not _is_power_of_2(head_dim): + warnings.warn( + """ + You'd better set d_model in MSDeformAttn to make sure that + each dim of the attention head a power of 2, which is more efficient. + """ + ) + + self.im2col_step = img2col_step + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2) + self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + self.init_weights() + + def _reset_parameters(self): + return self.init_weights() + + def init_weights(self): + """ + Default initialization for Parameters of Module. + """ + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.num_heads, dtype=torch.float32) * ( + 2.0 * math.pi / self.num_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat(1, self.num_levels, self.num_points, 1) + ) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def freeze_sampling_offsets(self): + print("Freeze sampling offsets") + self.sampling_offsets.weight.requires_grad = False + self.sampling_offsets.bias.requires_grad = False + + def freeze_attention_weights(self): + print("Freeze attention weights") + self.attention_weights.weight.requires_grad = False + self.attention_weights.bias.requires_grad = False + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + reference_points: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + level_start_index: Optional[torch.Tensor] = None, + **kwargs + ) -> torch.Tensor: + + """Forward Function of MultiScaleDeformableAttention + + Args: + query (torch.Tensor): Query embeddings with shape + `(num_query, bs, embed_dim)` + key (torch.Tensor): Key embeddings with shape + `(num_key, bs, embed_dim)` + value (torch.Tensor): Value embeddings with shape + `(num_key, bs, embed_dim)` + query_pos (torch.Tensor): The position embedding for `query`. Default: None. + key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`, + indicating which elements within `key` to be ignored in attention. + reference_points (torch.Tensor): The normalized reference points + with shape `(bs, num_query, num_levels, 2)`, + all elements is range in [0, 1], top-left (0, 0), + bottom-right (1, 1), including padding are. + or `(N, Length_{query}, num_levels, 4)`, add additional + two dimensions `(h, w)` to form reference boxes. + spatial_shapes (torch.Tensor): Spatial shape of features in different levels. + With shape `(num_levels, 2)`, last dimension represents `(h, w)`. + level_start_index (torch.Tensor): The start index of each level. A tensor with + shape `(num_levels, )` which can be represented as + `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`. + + Returns: + torch.Tensor: forward results with shape `(num_query, bs, embed_dim)` + """ + + if value is None: + value = query + + if query_pos is not None: + query = query + query_pos + + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], float(0)) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 + ) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points + ) + attention_weights = attention_weights.softmax(-1) + attention_weights = attention_weights.view( + bs, + num_query, + self.num_heads, + self.num_levels, + self.num_points, + ) + + # bs, num_query, num_heads, num_levels, num_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.num_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) + + if torch.cuda.is_available() and value.is_cuda: + halffloat = False + if value.dtype == torch.float16: + halffloat = True + value = value.float() + sampling_locations = sampling_locations.float() + attention_weights = attention_weights.float() + + output = MultiScaleDeformableAttnFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + if halffloat: + output = output.half() + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights + ) + + output = self.output_proj(output) + + if not self.batch_first: + output = output.permute(1, 0, 2) + + return output + + +def create_dummy_class(klass, dependency, message=""): + """ + When a dependency of a class is not available, create a dummy class which throws ImportError + when used. + + Args: + klass (str): name of the class. + dependency (str): name of the dependency. + message: extra message to print + Returns: + class: a class object + """ + err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass) + if message: + err = err + " " + message + + class _DummyMetaClass(type): + # throw error on class attribute access + def __getattr__(_, __): # noqa: B902 + raise ImportError(err) + + class _Dummy(object, metaclass=_DummyMetaClass): + # throw error on constructor + def __init__(self, *args, **kwargs): + raise ImportError(err) + + return _Dummy + + +def create_dummy_func(func, dependency, message=""): + """ + When a dependency of a function is not available, create a dummy function which throws + ImportError when used. + + Args: + func (str): name of the function. + dependency (str or list[str]): name(s) of the dependency. + message: extra message to print + Returns: + function: a function object + """ + err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func) + if message: + err = err + " " + message + + if isinstance(dependency, (list, tuple)): + dependency = ",".join(dependency) + + def _dummy(*args, **kwargs): + raise ImportError(err) + + return _dummy diff --git a/groundingdino/models/GroundingDINO/transformer.py b/groundingdino/models/GroundingDINO/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb8742dbdde6e80fd38b11d064211f6935aae76 --- /dev/null +++ b/groundingdino/models/GroundingDINO/transformer.py @@ -0,0 +1,959 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DINO +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Conditional DETR Transformer class. +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +from typing import Optional + +import torch +import torch.utils.checkpoint as checkpoint +from torch import Tensor, nn + +from groundingdino.util.misc import inverse_sigmoid + +from .fuse_modules import BiAttentionBlock +from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn +from .transformer_vanilla import TransformerEncoderLayer +from .utils import ( + MLP, + _get_activation_fn, + _get_clones, + gen_encoder_output_proposals, + gen_sineembed_for_position, + get_sine_pos_embed, +) + + +class Transformer(nn.Module): + def __init__( + self, + d_model=256, + nhead=8, + num_queries=300, + num_encoder_layers=6, + num_unicoder_layers=0, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.0, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + query_dim=4, + num_patterns=0, + # for deformable encoder + num_feature_levels=1, + enc_n_points=4, + dec_n_points=4, + # init query + learnable_tgt_init=False, + # two stage + two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'] + embed_init_tgt=False, + # for text + use_text_enhancer=False, + use_fusion_layer=False, + use_checkpoint=False, + use_transformer_ckpt=False, + use_text_cross_attention=False, + text_dropout=0.1, + fusion_dropout=0.1, + fusion_droppath=0.0, + ): + super().__init__() + self.num_feature_levels = num_feature_levels + self.num_encoder_layers = num_encoder_layers + self.num_unicoder_layers = num_unicoder_layers + self.num_decoder_layers = num_decoder_layers + self.num_queries = num_queries + assert query_dim == 4 + + # choose encoder layer type + encoder_layer = DeformableTransformerEncoderLayer( + d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points + ) + + if use_text_enhancer: + text_enhance_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead // 2, + dim_feedforward=dim_feedforward // 2, + dropout=text_dropout, + ) + else: + text_enhance_layer = None + + if use_fusion_layer: + feature_fusion_layer = BiAttentionBlock( + v_dim=d_model, + l_dim=d_model, + embed_dim=dim_feedforward // 2, + num_heads=nhead // 2, + dropout=fusion_dropout, + drop_path=fusion_droppath, + ) + else: + feature_fusion_layer = None + + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + assert encoder_norm is None + self.encoder = TransformerEncoder( + encoder_layer, + num_encoder_layers, + d_model=d_model, + num_queries=num_queries, + text_enhance_layer=text_enhance_layer, + feature_fusion_layer=feature_fusion_layer, + use_checkpoint=use_checkpoint, + use_transformer_ckpt=use_transformer_ckpt, + ) + + # choose decoder layer type + decoder_layer = DeformableTransformerDecoderLayer( + d_model, + dim_feedforward, + dropout, + activation, + num_feature_levels, + nhead, + dec_n_points, + use_text_cross_attention=use_text_cross_attention, + ) + + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + d_model=d_model, + query_dim=query_dim, + num_feature_levels=num_feature_levels, + ) + + self.d_model = d_model + self.nhead = nhead + self.dec_layers = num_decoder_layers + self.num_queries = num_queries # useful for single stage model only + self.num_patterns = num_patterns + if not isinstance(num_patterns, int): + Warning("num_patterns should be int but {}".format(type(num_patterns))) + self.num_patterns = 0 + + if num_feature_levels > 1: + if self.num_encoder_layers > 0: + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + else: + self.level_embed = None + + self.learnable_tgt_init = learnable_tgt_init + assert learnable_tgt_init, "why not learnable_tgt_init" + self.embed_init_tgt = embed_init_tgt + if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"): + self.tgt_embed = nn.Embedding(self.num_queries, d_model) + nn.init.normal_(self.tgt_embed.weight.data) + else: + self.tgt_embed = None + + # for two stage + self.two_stage_type = two_stage_type + assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( + two_stage_type + ) + if two_stage_type == "standard": + # anchor selection at the output of encoder + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.two_stage_wh_embedding = None + + if two_stage_type == "no": + self.init_ref_points(num_queries) # init self.refpoint_embed + + self.enc_out_class_embed = None + self.enc_out_bbox_embed = None + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if self.num_feature_levels > 1 and self.level_embed is not None: + nn.init.normal_(self.level_embed) + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def init_ref_points(self, use_num_queries): + self.refpoint_embed = nn.Embedding(use_num_queries, 4) + + def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None): + """ + Input: + - srcs: List of multi features [bs, ci, hi, wi] + - masks: List of multi masks [bs, hi, wi] + - refpoint_embed: [bs, num_dn, 4]. None in infer + - pos_embeds: List of multi pos embeds [bs, ci, hi, wi] + - tgt: [bs, num_dn, d_model]. None in infer + + """ + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + mask = mask.flatten(1) # bs, hw + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + if self.num_feature_levels > 1 and self.level_embed is not None: + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + else: + lvl_pos_embed = pos_embed + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device + ) + level_start_index = torch.cat( + (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) + ) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # two stage + enc_topk_proposals = enc_refpoint_embed = None + + ######################################################### + # Begin Encoder + ######################################################### + memory, memory_text = self.encoder( + src_flatten, + pos=lvl_pos_embed_flatten, + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + valid_ratios=valid_ratios, + key_padding_mask=mask_flatten, + memory_text=text_dict["encoded_text"], + text_attention_mask=~text_dict["text_token_mask"], + # we ~ the mask . False means use the token; True means pad the token + position_ids=text_dict["position_ids"], + text_self_attention_masks=text_dict["text_self_attention_masks"], + ) + ######################################################### + # End Encoder + # - memory: bs, \sum{hw}, c + # - mask_flatten: bs, \sum{hw} + # - lvl_pos_embed_flatten: bs, \sum{hw}, c + # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c) + # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c) + ######################################################### + text_dict["encoded_text"] = memory_text + # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': + # if memory.isnan().any() | memory.isinf().any(): + # import ipdb; ipdb.set_trace() + + if self.two_stage_type == "standard": + output_memory, output_proposals = gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + + if text_dict is not None: + enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict) + else: + enc_outputs_class_unselected = self.enc_out_class_embed(output_memory) + + topk_logits = enc_outputs_class_unselected.max(-1)[0] + enc_outputs_coord_unselected = ( + self.enc_out_bbox_embed(output_memory) + output_proposals + ) # (bs, \sum{hw}, 4) unsigmoid + topk = self.num_queries + + topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq + + # gather boxes + refpoint_embed_undetach = torch.gather( + enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ) # unsigmoid + refpoint_embed_ = refpoint_embed_undetach.detach() + init_box_proposal = torch.gather( + output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ).sigmoid() # sigmoid + + # gather tgt + tgt_undetach = torch.gather( + output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ) + if self.embed_init_tgt: + tgt_ = ( + self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) + ) # nq, bs, d_model + else: + tgt_ = tgt_undetach.detach() + + if refpoint_embed is not None: + refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1) + tgt = torch.cat([tgt, tgt_], dim=1) + else: + refpoint_embed, tgt = refpoint_embed_, tgt_ + + elif self.two_stage_type == "no": + tgt_ = ( + self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) + ) # nq, bs, d_model + refpoint_embed_ = ( + self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) + ) # nq, bs, 4 + + if refpoint_embed is not None: + refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1) + tgt = torch.cat([tgt, tgt_], dim=1) + else: + refpoint_embed, tgt = refpoint_embed_, tgt_ + + if self.num_patterns > 0: + tgt_embed = tgt.repeat(1, self.num_patterns, 1) + refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1) + tgt_pat = self.patterns.weight[None, :, :].repeat_interleave( + self.num_queries, 1 + ) # 1, n_q*n_pat, d_model + tgt = tgt_embed + tgt_pat + + init_box_proposal = refpoint_embed_.sigmoid() + + else: + raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type)) + ######################################################### + # End preparing tgt + # - tgt: bs, NQ, d_model + # - refpoint_embed(unsigmoid): bs, NQ, d_model + ######################################################### + + ######################################################### + # Begin Decoder + ######################################################### + hs, references = self.decoder( + tgt=tgt.transpose(0, 1), + memory=memory.transpose(0, 1), + memory_key_padding_mask=mask_flatten, + pos=lvl_pos_embed_flatten.transpose(0, 1), + refpoints_unsigmoid=refpoint_embed.transpose(0, 1), + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + valid_ratios=valid_ratios, + tgt_mask=attn_mask, + memory_text=text_dict["encoded_text"], + text_attention_mask=~text_dict["text_token_mask"], + # we ~ the mask . False means use the token; True means pad the token + ) + ######################################################### + # End Decoder + # hs: n_dec, bs, nq, d_model + # references: n_dec+1, bs, nq, query_dim + ######################################################### + + ######################################################### + # Begin postprocess + ######################################################### + if self.two_stage_type == "standard": + hs_enc = tgt_undetach.unsqueeze(0) + ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0) + else: + hs_enc = ref_enc = None + ######################################################### + # End postprocess + # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None + # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None + ######################################################### + + return hs, references, hs_enc, ref_enc, init_box_proposal + # hs: (n_dec, bs, nq, d_model) + # references: sigmoid coordinates. (n_dec+1, bs, bq, 4) + # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None + # ref_enc: sigmoid coordinates. \ + # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None + + +class TransformerEncoder(nn.Module): + def __init__( + self, + encoder_layer, + num_layers, + d_model=256, + num_queries=300, + enc_layer_share=False, + text_enhance_layer=None, + feature_fusion_layer=None, + use_checkpoint=False, + use_transformer_ckpt=False, + ): + """_summary_ + + Args: + encoder_layer (_type_): _description_ + num_layers (_type_): _description_ + norm (_type_, optional): _description_. Defaults to None. + d_model (int, optional): _description_. Defaults to 256. + num_queries (int, optional): _description_. Defaults to 300. + enc_layer_share (bool, optional): _description_. Defaults to False. + + """ + super().__init__() + # prepare layers + self.layers = [] + self.text_layers = [] + self.fusion_layers = [] + if num_layers > 0: + self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share) + + if text_enhance_layer is not None: + self.text_layers = _get_clones( + text_enhance_layer, num_layers, layer_share=enc_layer_share + ) + if feature_fusion_layer is not None: + self.fusion_layers = _get_clones( + feature_fusion_layer, num_layers, layer_share=enc_layer_share + ) + else: + self.layers = [] + del encoder_layer + + if text_enhance_layer is not None: + self.text_layers = [] + del text_enhance_layer + if feature_fusion_layer is not None: + self.fusion_layers = [] + del feature_fusion_layer + + self.query_scale = None + self.num_queries = num_queries + self.num_layers = num_layers + self.d_model = d_model + + self.use_checkpoint = use_checkpoint + self.use_transformer_ckpt = use_transformer_ckpt + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + # for images + src: Tensor, + pos: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + key_padding_mask: Tensor, + # for texts + memory_text: Tensor = None, + text_attention_mask: Tensor = None, + pos_text: Tensor = None, + text_self_attention_masks: Tensor = None, + position_ids: Tensor = None, + ): + """ + Input: + - src: [bs, sum(hi*wi), 256] + - pos: pos embed for src. [bs, sum(hi*wi), 256] + - spatial_shapes: h,w of each level [num_level, 2] + - level_start_index: [num_level] start point of level in sum(hi*wi). + - valid_ratios: [bs, num_level, 2] + - key_padding_mask: [bs, sum(hi*wi)] + + - memory_text: bs, n_text, 256 + - text_attention_mask: bs, n_text + False for no padding; True for padding + - pos_text: bs, n_text, 256 + + - position_ids: bs, n_text + Intermedia: + - reference_points: [bs, sum(hi*wi), num_level, 2] + Outpus: + - output: [bs, sum(hi*wi), 256] + """ + + output = src + + # preparation and reshape + if self.num_layers > 0: + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=src.device + ) + + if self.text_layers: + # generate pos_text + bs, n_text, text_dim = memory_text.shape + if pos_text is None and position_ids is None: + pos_text = ( + torch.arange(n_text, device=memory_text.device) + .float() + .unsqueeze(0) + .unsqueeze(-1) + .repeat(bs, 1, 1) + ) + pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False) + if position_ids is not None: + pos_text = get_sine_pos_embed( + position_ids[..., None], num_pos_feats=256, exchange_xy=False + ) + + # main process + for layer_id, layer in enumerate(self.layers): + # if output.isnan().any() or memory_text.isnan().any(): + # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': + # import ipdb; ipdb.set_trace() + if self.fusion_layers: + if self.use_checkpoint: + output, memory_text = checkpoint.checkpoint( + self.fusion_layers[layer_id], + output, + memory_text, + key_padding_mask, + text_attention_mask, + ) + else: + output, memory_text = self.fusion_layers[layer_id]( + v=output, + l=memory_text, + attention_mask_v=key_padding_mask, + attention_mask_l=text_attention_mask, + ) + + if self.text_layers: + memory_text = self.text_layers[layer_id]( + src=memory_text.transpose(0, 1), + src_mask=~text_self_attention_masks, # note we use ~ for mask here + src_key_padding_mask=text_attention_mask, + pos=(pos_text.transpose(0, 1) if pos_text is not None else None), + ).transpose(0, 1) + + # main process + if self.use_transformer_ckpt: + output = checkpoint.checkpoint( + layer, + output, + pos, + reference_points, + spatial_shapes, + level_start_index, + key_padding_mask, + ) + else: + output = layer( + src=output, + pos=pos, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=key_padding_mask, + ) + + return output, memory_text + + +class TransformerDecoder(nn.Module): + def __init__( + self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False, + d_model=256, + query_dim=4, + num_feature_levels=1, + ): + super().__init__() + if num_layers > 0: + self.layers = _get_clones(decoder_layer, num_layers) + else: + self.layers = [] + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + assert return_intermediate, "support return_intermediate only" + self.query_dim = query_dim + assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) + self.num_feature_levels = num_feature_levels + + self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) + self.query_pos_sine_scale = None + + self.query_scale = None + self.bbox_embed = None + self.class_embed = None + + self.d_model = d_model + + self.ref_anchor_head = None + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2 + # for memory + level_start_index: Optional[Tensor] = None, # num_levels + spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 + valid_ratios: Optional[Tensor] = None, + # for text + memory_text: Optional[Tensor] = None, + text_attention_mask: Optional[Tensor] = None, + ): + """ + Input: + - tgt: nq, bs, d_model + - memory: hw, bs, d_model + - pos: hw, bs, d_model + - refpoints_unsigmoid: nq, bs, 2/4 + - valid_ratios/spatial_shapes: bs, nlevel, 2 + """ + output = tgt + + intermediate = [] + reference_points = refpoints_unsigmoid.sigmoid() + ref_points = [reference_points] + + for layer_id, layer in enumerate(self.layers): + + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] + * torch.cat([valid_ratios, valid_ratios], -1)[None, :] + ) # nq, bs, nlevel, 4 + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * valid_ratios[None, :] + query_sine_embed = gen_sineembed_for_position( + reference_points_input[:, :, 0, :] + ) # nq, bs, 256*2 + + # conditional query + raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 + pos_scale = self.query_scale(output) if self.query_scale is not None else 1 + query_pos = pos_scale * raw_query_pos + # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': + # if query_pos.isnan().any() | query_pos.isinf().any(): + # import ipdb; ipdb.set_trace() + + # main process + output = layer( + tgt=output, + tgt_query_pos=query_pos, + tgt_query_sine_embed=query_sine_embed, + tgt_key_padding_mask=tgt_key_padding_mask, + tgt_reference_points=reference_points_input, + memory_text=memory_text, + text_attention_mask=text_attention_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + memory_level_start_index=level_start_index, + memory_spatial_shapes=spatial_shapes, + memory_pos=pos, + self_attn_mask=tgt_mask, + cross_attn_mask=memory_mask, + ) + if output.isnan().any() | output.isinf().any(): + print(f"output layer_id {layer_id} is nan") + try: + num_nan = output.isnan().sum().item() + num_inf = output.isinf().sum().item() + print(f"num_nan {num_nan}, num_inf {num_inf}") + except Exception as e: + print(e) + # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': + # import ipdb; ipdb.set_trace() + + # iter update + if self.bbox_embed is not None: + # box_holder = self.bbox_embed(output) + # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points) + # new_reference_points = box_holder[..., :self.query_dim].sigmoid() + + reference_before_sigmoid = inverse_sigmoid(reference_points) + delta_unsig = self.bbox_embed[layer_id](output) + outputs_unsig = delta_unsig + reference_before_sigmoid + new_reference_points = outputs_unsig.sigmoid() + + reference_points = new_reference_points.detach() + # if layer_id != self.num_layers - 1: + ref_points.append(new_reference_points) + + intermediate.append(self.norm(output)) + + return [ + [itm_out.transpose(0, 1) for itm_out in intermediate], + [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points], + ] + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + ): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn( + embed_dim=d_model, + num_levels=n_levels, + num_heads=n_heads, + num_points=n_points, + batch_first=True, + ) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation, d_model=d_ffn) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward( + self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None + ): + # self attention + # import ipdb; ipdb.set_trace() + src2 = self.self_attn( + query=self.with_pos_embed(src, pos), + reference_points=reference_points, + value=src, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=key_padding_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + use_text_feat_guide=False, + use_text_cross_attention=False, + ): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn( + embed_dim=d_model, + num_levels=n_levels, + num_heads=n_heads, + num_points=n_points, + batch_first=True, + ) + self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm1 = nn.LayerNorm(d_model) + + # cross attention text + if use_text_cross_attention: + self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.catext_norm = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1) + self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm3 = nn.LayerNorm(d_model) + + self.key_aware_proj = None + self.use_text_feat_guide = use_text_feat_guide + assert not use_text_feat_guide + self.use_text_cross_attention = use_text_cross_attention + + def rm_self_attn_modules(self): + self.self_attn = None + self.dropout2 = None + self.norm2 = None + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + with torch.cuda.amp.autocast(enabled=False): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + # for tgt + tgt: Optional[Tensor], # nq, bs, d_model + tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) + tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) + tgt_key_padding_mask: Optional[Tensor] = None, + tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 + memory_text: Optional[Tensor] = None, # bs, num_token, d_model + text_attention_mask: Optional[Tensor] = None, # bs, num_token + # for memory + memory: Optional[Tensor] = None, # hw, bs, d_model + memory_key_padding_mask: Optional[Tensor] = None, + memory_level_start_index: Optional[Tensor] = None, # num_levels + memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 + memory_pos: Optional[Tensor] = None, # pos for memory + # sa + self_attn_mask: Optional[Tensor] = None, # mask used for self-attention + cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention + ): + """ + Input: + - tgt/tgt_query_pos: nq, bs, d_model + - + """ + assert cross_attn_mask is None + + # self attention + if self.self_attn is not None: + # import ipdb; ipdb.set_trace() + q = k = self.with_pos_embed(tgt, tgt_query_pos) + tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + if self.use_text_cross_attention: + tgt2 = self.ca_text( + self.with_pos_embed(tgt, tgt_query_pos), + memory_text.transpose(0, 1), + memory_text.transpose(0, 1), + key_padding_mask=text_attention_mask, + )[0] + tgt = tgt + self.catext_dropout(tgt2) + tgt = self.catext_norm(tgt) + + tgt2 = self.cross_attn( + query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), + reference_points=tgt_reference_points.transpose(0, 1).contiguous(), + value=memory.transpose(0, 1), + spatial_shapes=memory_spatial_shapes, + level_start_index=memory_level_start_index, + key_padding_mask=memory_key_padding_mask, + ).transpose(0, 1) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + num_queries=args.num_queries, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + query_dim=args.query_dim, + activation=args.transformer_activation, + num_patterns=args.num_patterns, + num_feature_levels=args.num_feature_levels, + enc_n_points=args.enc_n_points, + dec_n_points=args.dec_n_points, + learnable_tgt_init=True, + # two stage + two_stage_type=args.two_stage_type, # ['no', 'standard', 'early'] + embed_init_tgt=args.embed_init_tgt, + use_text_enhancer=args.use_text_enhancer, + use_fusion_layer=args.use_fusion_layer, + use_checkpoint=args.use_checkpoint, + use_transformer_ckpt=args.use_transformer_ckpt, + use_text_cross_attention=args.use_text_cross_attention, + text_dropout=args.text_dropout, + fusion_dropout=args.fusion_dropout, + fusion_droppath=args.fusion_droppath, + ) diff --git a/groundingdino/models/GroundingDINO/transformer_vanilla.py b/groundingdino/models/GroundingDINO/transformer_vanilla.py new file mode 100644 index 0000000000000000000000000000000000000000..10c0920c1a217af5bb3e1b13077568035ab3b7b5 --- /dev/null +++ b/groundingdino/models/GroundingDINO/transformer_vanilla.py @@ -0,0 +1,123 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .utils import ( + MLP, + _get_activation_fn, + _get_clones, + gen_encoder_output_proposals, + gen_sineembed_for_position, + sigmoid_focal_loss, +) + + +class TextTransformer(nn.Module): + def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): + super().__init__() + self.num_layers = num_layers + self.d_model = d_model + self.nheads = nheads + self.dim_feedforward = dim_feedforward + self.norm = None + + single_encoder_layer = TransformerEncoderLayer( + d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout + ) + self.layers = _get_clones(single_encoder_layer, num_layers) + + def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor): + """ + + Args: + text_attention_mask: bs, num_token + memory_text: bs, num_token, d_model + + Raises: + RuntimeError: _description_ + + Returns: + output: bs, num_token, d_model + """ + + output = memory_text.transpose(0, 1) + + for layer in self.layers: + output = layer(output, src_key_padding_mask=text_attention_mask) + + if self.norm is not None: + output = self.norm(output) + + return output.transpose(0, 1) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + self.nhead = nhead + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + # repeat attn mask + if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: + # bs, num_q, num_k + src_mask = src_mask.repeat(self.nhead, 1, 1) + + q = k = self.with_pos_embed(src, pos) + + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] + + # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src diff --git a/groundingdino/models/GroundingDINO/utils.py b/groundingdino/models/GroundingDINO/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd18f70225e12b2e27fdb4eabcde91d959f8e31 --- /dev/null +++ b/groundingdino/models/GroundingDINO/utils.py @@ -0,0 +1,268 @@ +# ------------------------------------------------------------------------ +# Grounding DINO +# url: https://github.com/IDEA-Research/GroundingDINO +# Copyright (c) 2023 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import copy +import math + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +def _get_clones(module, N, layer_share=False): + # import ipdb; ipdb.set_trace() + if layer_share: + return nn.ModuleList([module for i in range(N)]) + else: + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def get_sine_pos_embed( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, + exchange_xy: bool = True, +): + """generate sine position embedding from a position tensor + Args: + pos_tensor (torch.Tensor): shape: [..., n]. + num_pos_feats (int): projected shape for each float in the tensor. + temperature (int): temperature in the sine/cosine function. + exchange_xy (bool, optional): exchange pos x and pos y. \ + For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True. + Returns: + pos_embed (torch.Tensor): shape: [..., n*num_pos_feats]. + """ + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) + + def sine_func(x: torch.Tensor): + sin_x = x * scale / dim_t + sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2) + return sin_x + + pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)] + if exchange_xy: + pos_res[0], pos_res[1] = pos_res[1], pos_res[0] + pos_res = torch.cat(pos_res, dim=-1) + return pos_res + + +def gen_encoder_output_proposals( + memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None +): + """ + Input: + - memory: bs, \sum{hw}, d_model + - memory_padding_mask: bs, \sum{hw} + - spatial_shapes: nlevel, 2 + - learnedwh: 2 + Output: + - output_memory: bs, \sum{hw}, d_model + - output_proposals: bs, \sum{hw}, 4 + """ + N_, S_, C_ = memory.shape + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + # import ipdb; ipdb.set_trace() + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + + if learnedwh is not None: + # import ipdb; ipdb.set_trace() + wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl) + else: + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + + # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1) + # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + # wh = torch.ones_like(grid) / scale + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + # import ipdb; ipdb.set_trace() + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( + -1, keepdim=True + ) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + + # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf')) + + return output_memory, output_proposals + + +class RandomBoxPerturber: + def __init__( + self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2 + ) -> None: + self.noise_scale = torch.Tensor( + [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale] + ) + + def __call__(self, refanchors: Tensor) -> Tensor: + nq, bs, query_dim = refanchors.shape + device = refanchors.device + + noise_raw = torch.rand_like(refanchors) + noise_scale = self.noise_scale.to(device)[:query_dim] + + new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale) + return new_refanchors.clamp_(0, 1) + + +def sigmoid_focal_loss( + inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if no_reduction: + return loss + + return loss.mean(1).sum() / num_boxes + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def _get_activation_fn(activation, d_model=256, batch_dim=0): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + if activation == "prelu": + return nn.PReLU() + if activation == "selu": + return F.selu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def gen_sineembed_for_position(pos_tensor): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +class ContrastiveEmbed(nn.Module): + def __init__(self, max_text_len=256): + """ + Args: + max_text_len: max length of text. + """ + super().__init__() + self.max_text_len = max_text_len + + def forward(self, x, text_dict): + """_summary_ + + Args: + x (_type_): _description_ + text_dict (_type_): _description_ + { + 'encoded_text': encoded_text, # bs, 195, d_model + 'text_token_mask': text_token_mask, # bs, 195 + # True for used tokens. False for padding tokens + } + Returns: + _type_: _description_ + """ + assert isinstance(text_dict, dict) + + y = text_dict["encoded_text"] + text_token_mask = text_dict["text_token_mask"] + + res = x @ y.transpose(-1, -2) + res.masked_fill_(~text_token_mask[:, None, :], float("-inf")) + + # padding to max_text_len + new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device) + new_res[..., : res.shape[-1]] = res + + return new_res diff --git a/groundingdino/models/GroundingDINO/version.py b/groundingdino/models/GroundingDINO/version.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc1f76bc69e3f559bee6253b24fc93acee9e1f9 --- /dev/null +++ b/groundingdino/models/GroundingDINO/version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/groundingdino/models/__init__.py b/groundingdino/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/groundingdino/util/__init__.py b/groundingdino/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec --- /dev/null +++ b/groundingdino/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/groundingdino/util/box_ops.py b/groundingdino/util/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..781068d294e576954edb4bd07b6e0f30e4e1bcd9 --- /dev/null +++ b/groundingdino/util/box_ops.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + # import ipdb; ipdb.set_trace() + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / (union + 1e-6) + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + # except: + # import ipdb; ipdb.set_trace() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / (area + 1e-6) + + +# modified from torchvision to also return the union +def box_iou_pairwise(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] + + wh = (rb - lt).clamp(min=0) # [N,2] + inter = wh[:, 0] * wh[:, 1] # [N] + + union = area1 + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou_pairwise(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + Input: + - boxes1, boxes2: N,4 + Output: + - giou: N, 4 + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + assert boxes1.shape == boxes2.shape + iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 + + lt = torch.min(boxes1[:, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,2] + area = wh[:, 0] * wh[:, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) + + +if __name__ == "__main__": + x = torch.rand(5, 4) + y = torch.rand(3, 4) + iou, union = box_iou(x, y) + import ipdb + + ipdb.set_trace() diff --git a/groundingdino/util/get_tokenlizer.py b/groundingdino/util/get_tokenlizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2d972b4278e04a1ebef7d5e77aecd4eaf4205b --- /dev/null +++ b/groundingdino/util/get_tokenlizer.py @@ -0,0 +1,29 @@ +from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast +import os + +def get_tokenlizer(text_encoder_type): + if not isinstance(text_encoder_type, str): + # print("text_encoder_type is not a str") + if hasattr(text_encoder_type, "text_encoder_type"): + text_encoder_type = text_encoder_type.text_encoder_type + elif text_encoder_type.get("text_encoder_type", False): + text_encoder_type = text_encoder_type.get("text_encoder_type") + elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type): + pass + else: + raise ValueError( + "Unknown type of text_encoder_type: {}".format(type(text_encoder_type)) + ) + print("final text_encoder_type: {}".format(text_encoder_type)) + + tokenizer = AutoTokenizer.from_pretrained(text_encoder_type) + return tokenizer + + +def get_pretrained_language_model(text_encoder_type): + if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)): + return BertModel.from_pretrained(text_encoder_type) + if text_encoder_type == "roberta-base": + return RobertaModel.from_pretrained(text_encoder_type) + + raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) diff --git a/groundingdino/util/logger.py b/groundingdino/util/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..18145f54c927abd59b95f3fa6e6da8002bc2ce97 --- /dev/null +++ b/groundingdino/util/logger.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import functools +import logging +import os +import sys + +from termcolor import colored + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +# so that calling setup_logger multiple times won't add many handlers +@functools.lru_cache() +def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None): + """ + Initialize the detectron2 logger and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + if abbrev_name is None: + abbrev_name = name + + plain_formatter = logging.Formatter( + "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S" + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + f".rank{distributed_rank}" + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + return logger + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return open(filename, "a") diff --git a/groundingdino/util/misc.py b/groundingdino/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d64b84ef24bea0c98e76824feb1903f6bfebe7a5 --- /dev/null +++ b/groundingdino/util/misc.py @@ -0,0 +1,717 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import colorsys +import datetime +import functools +import io +import json +import os +import pickle +import subprocess +import time +from collections import OrderedDict, defaultdict, deque +from typing import List, Optional + +import numpy as np +import torch +import torch.distributed as dist + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + +__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7 +if __torchvision_need_compat_flag: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + if d.shape[0] == 0: + return 0 + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + if os.environ.get("SHILONG_AMP", None) == "1": + eps = 1e-4 + else: + eps = 1e-6 + return self.total / (self.count + eps) + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + + return dist.group.WORLD + + +def all_gather_cpu(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer) + data_list.append(obj) + + return data_list + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + if os.getenv("CPU_REDUCE") == "1": + return all_gather_cpu(data) + + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + # print(name, str(meter)) + # import ipdb;ipdb.set_trace() + if meter.count > 0: + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, logger=None): + if logger is None: + print_func = print + else: + print_func = logger.info + + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + # import ipdb; ipdb.set_trace() + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print_func( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print_func( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print_func( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommited changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + # import ipdb; ipdb.set_trace() + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + if mask == "auto": + self.mask = torch.zeros_like(tensors).to(tensors.device) + if self.mask.dim() == 3: + self.mask = self.mask.sum(0).to(bool) + elif self.mask.dim() == 4: + self.mask = self.mask.sum(1).to(bool) + else: + raise ValueError( + "tensors dim must be 3 or 4 but {}({})".format( + self.tensors.dim(), self.tensors.shape + ) + ) + + def imgsize(self): + res = [] + for i in range(self.tensors.shape[0]): + mask = self.mask[i] + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + res.append(torch.Tensor([maxH, maxW])) + return res + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def to_img_list_single(self, tensor, mask): + assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim()) + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + img = tensor[:, :maxH, :maxW] + return img + + def to_img_list(self): + """remove the padding and convert to img list + + Returns: + [type]: [description] + """ + if self.tensors.dim() == 3: + return self.to_img_list_single(self.tensors, self.mask) + else: + res = [] + for i in range(self.tensors.shape[0]): + tensor_i = self.tensors[i] + mask_i = self.mask[i] + res.append(self.to_img_list_single(tensor_i, mask_i)) + return res + + @property + def device(self): + return self.tensors.device + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + @property + def shape(self): + return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape} + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"]) + + # launch by torch.distributed.launch + # Single node + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... + # Multi nodes + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... + # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) + # local_world_size = int(os.environ['GPU_PER_NODE_COUNT']) + # args.world_size = args.world_size * local_world_size + # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) + # args.rank = args.rank * local_world_size + args.local_rank + print( + "world size: {}, rank: {}, local rank: {}".format( + args.world_size, args.rank, args.local_rank + ) + ) + print(json.dumps(dict(os.environ), indent=2)) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"]) + args.world_size = int(os.environ["SLURM_NPROCS"]) + + print( + "world size: {}, world rank: {}, local rank: {}, device_count: {}".format( + args.world_size, args.rank, args.local_rank, torch.cuda.device_count() + ) + ) + else: + print("Not using distributed mode") + args.distributed = False + args.world_size = 1 + args.rank = 0 + args.local_rank = 0 + return + + print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) + args.distributed = True + torch.cuda.set_device(args.local_rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + + torch.distributed.init_process_group( + backend=args.dist_backend, + world_size=args.world_size, + rank=args.rank, + init_method=args.dist_url, + ) + + print("Before torch.distributed.barrier()") + torch.distributed.barrier() + print("End torch.distributed.barrier()") + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +@torch.no_grad() +def accuracy_onehot(pred, gt): + """_summary_ + + Args: + pred (_type_): n, c + gt (_type_): n, c + """ + tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum() + acc = tp / gt.shape[0] * 100 + return acc + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if __torchvision_need_compat_flag < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +class color_sys: + def __init__(self, num_colors) -> None: + self.num_colors = num_colors + colors = [] + for i in np.arange(0.0, 360.0, 360.0 / num_colors): + hue = i / 360.0 + lightness = (50 + np.random.rand() * 10) / 100.0 + saturation = (90 + np.random.rand() * 10) / 100.0 + colors.append( + tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)]) + ) + self.colors = colors + + def __call__(self, idx): + return self.colors[idx] + + +def inverse_sigmoid(x, eps=1e-3): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def clean_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k[:7] == "module.": + k = k[7:] # remove `module.` + new_state_dict[k] = v + return new_state_dict diff --git a/groundingdino/util/predict.py b/groundingdino/util/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a165360112b7d8266f86c112be943b46584987 --- /dev/null +++ b/groundingdino/util/predict.py @@ -0,0 +1,46 @@ +from typing import Tuple, List + +import torch + +from groundingdino.util.utils import get_phrases_from_posmap + + +def preprocess_caption(caption: str) -> str: + result = caption.lower().strip() + if result.endswith("."): + return result + return result + "." + +def predict( + model, + image: torch.Tensor, + caption: str, + box_threshold: float, + text_threshold: float, + device: str = "cuda" +) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + caption = preprocess_caption(caption=caption) + + model = model.to(device) + image = image.to(device) + + with torch.no_grad(): + outputs = model(image[None], captions=[caption]) + + prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) + prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) + + mask = prediction_logits.max(dim=1)[0] > box_threshold + logits = prediction_logits[mask] # logits.shape = (n, 256) + boxes = prediction_boxes[mask] # boxes.shape = (n, 4) + + tokenizer = model.tokenizer + tokenized = tokenizer(caption) + print(tokenized) + phrases = [ + get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') + for logit + in logits + ] + + return boxes, logits.max(dim=1)[0], phrases \ No newline at end of file diff --git a/groundingdino/util/slconfig.py b/groundingdino/util/slconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..672e72ed0b68a54c13ade66c9f146d2d542e97c6 --- /dev/null +++ b/groundingdino/util/slconfig.py @@ -0,0 +1,427 @@ +# ========================================================== +# Modified from mmcv +# ========================================================== +import ast +import os +import os.path as osp +import shutil +import sys +import tempfile +from argparse import Action +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +BASE_KEY = "_base_" +DELETE_KEY = "_delete_" +RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"] + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +class ConfigDict(Dict): + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +class SLConfig(object): + """ + config files. + only support .py file as config now. + + ref: mmcv.utils.config + + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename) as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError: + raise SyntaxError("There are syntax errors in config " f"file {filename}") + + @staticmethod + def _file2dict(filename): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + if filename.lower().endswith(".py"): + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py") + temp_config_name = osp.basename(temp_config_file.name) + if os.name == 'nt': + temp_config_file.close() + shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name)) + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + SLConfig._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value for name, value in mod.__dict__.items() if not name.startswith("__") + } + # delete imported module + del sys.modules[temp_module_name] + # close temp file + temp_config_file.close() + elif filename.lower().endswith((".yml", ".yaml", ".json")): + from .slio import slload + + cfg_dict = slload(filename) + else: + raise IOError("Only py/yml/yaml/json type are supported now!") + + cfg_text = filename + "\n" + with open(filename, "r") as f: + cfg_text += f.read() + + # parse the base file + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance(base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + if len(base_cfg_dict.keys() & c.keys()) > 0: + raise KeyError("Duplicate key is not allowed among bases") + # TODO Allow the duplicate key while warnning user + base_cfg_dict.update(c) + + base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = "\n".join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b): + """merge dict `a` into dict `b` (non-inplace). + values in `a` will overwrite `b`. + copy first to avoid inplace modification + + Args: + a ([type]): [description] + b ([type]): [description] + + Returns: + [dict]: [description] + """ + # import ipdb; ipdb.set_trace() + if not isinstance(a, dict): + return a + + b = b.copy() + for k, v in a.items(): + if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): + + if not isinstance(b[k], dict) and not isinstance(b[k], list): + # if : + # import ipdb; ipdb.set_trace() + raise TypeError( + f"{k}={v} in child config cannot inherit from base " + f"because {k} is a dict in the child config but is of " + f"type {type(b[k])} in base config. You may set " + f"`{DELETE_KEY}=True` to ignore the base config" + ) + b[k] = SLConfig._merge_a_into_b(v, b[k]) + elif isinstance(b, list): + try: + _ = int(k) + except: + raise TypeError( + f"b is a list, " f"index {k} should be an int when input but {type(k)}" + ) + b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)]) + else: + b[k] = v + + return b + + @staticmethod + def fromfile(filename): + cfg_dict, cfg_text = SLConfig._file2dict(filename) + return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename) + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}") + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f"{key} is reserved for config file") + + super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict)) + super(SLConfig, self).__setattr__("_filename", filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, "r") as f: + text = f.read() + else: + text = "" + super(SLConfig, self).__setattr__("_text", text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = "[\n" + v_str += "\n".join( + f"dict({_indent(_format_dict(v_), indent)})," for v_ in v + ).rstrip(",") + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + "]" + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= not str(key_name).isidentifier() + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = "" + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += "{" + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = "" if outest_level or is_last else "," + if isinstance(v, dict): + v_str = "\n" + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: dict({v_str}" + else: + attr_str = f"{str(k)}=dict({v_str}" + attr_str = _indent(attr_str, indent) + ")" + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += "\n".join(s) + if use_mapping: + r += "}" + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style="pep8", + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True, + ) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + # # debug + # print('+'*15) + # print('name=%s' % name) + # print("addr:", id(self)) + # # print('type(self):', type(self)) + # print(self.__dict__) + # print('+'*15) + # if self.__dict__ == {}: + # raise ValueError + + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def dump(self, file=None): + # import ipdb; ipdb.set_trace() + if file is None: + return self.pretty_text + else: + with open(file, "w") as f: + f.write(self.pretty_text) + + def merge_from_dict(self, options): + """Merge list into cfg_dict + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(depth=50, with_cp=True))) + + Args: + options (dict): dict of configs to merge from. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split(".") + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict") + super(SLConfig, self).__setattr__( + "_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict) + ) + + # for multiprocess + def __setstate__(self, state): + self.__init__(state) + + def copy(self): + return SLConfig(self._cfg_dict.copy()) + + def deepcopy(self): + return SLConfig(self._cfg_dict.deepcopy()) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options should + be passed as comma separated values, i.e KEY=V1,V2,V3 + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ["true", "false"]: + return True if val.lower() == "true" else False + if val.lower() in ["none", "null"]: + return None + return val + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split("=", maxsplit=1) + val = [self._parse_int_float_bool(v) for v in val.split(",")] + if len(val) == 1: + val = val[0] + options[key] = val + setattr(namespace, self.dest, options) diff --git a/groundingdino/util/slio.py b/groundingdino/util/slio.py new file mode 100644 index 0000000000000000000000000000000000000000..72c1f0f7b82cdc931d381feef64fe15815ba657e --- /dev/null +++ b/groundingdino/util/slio.py @@ -0,0 +1,177 @@ +# ========================================================== +# Modified from mmcv +# ========================================================== + +import json +import pickle +from abc import ABCMeta, abstractmethod +from pathlib import Path + +import yaml + +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + + +# =========================== +# Rigister handler +# =========================== + + +class BaseFileHandler(metaclass=ABCMeta): + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode="r", **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) + + +class JsonHandler(BaseFileHandler): + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + return json.dumps(obj, **kwargs) + + +class PickleHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("protocol", 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("protocol", 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs) + + +class YamlHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault("Loader", Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("Dumper", Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("Dumper", Dumper) + return yaml.dump(obj, **kwargs) + + +file_handlers = { + "json": JsonHandler(), + "yaml": YamlHandler(), + "yml": YamlHandler(), + "pickle": PickleHandler(), + "pkl": PickleHandler(), +} + +# =========================== +# load and dump +# =========================== + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def slload(file, file_format=None, **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and is_str(file): + file_format = file.split(".")[-1] + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + handler = file_handlers[file_format] + if is_str(file): + obj = handler.load_from_path(file, **kwargs) + elif hasattr(file, "read"): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def sldump(obj, file=None, file_format=None, **kwargs): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dump to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if is_str(file): + file_format = file.split(".")[-1] + elif file is None: + raise ValueError("file_format must be specified since file is None") + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif is_str(file): + handler.dump_to_path(obj, file, **kwargs) + elif hasattr(file, "write"): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') diff --git a/groundingdino/util/time_counter.py b/groundingdino/util/time_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..0aedb2e4d61bfbe7571dca9d50053f0fedaa1359 --- /dev/null +++ b/groundingdino/util/time_counter.py @@ -0,0 +1,62 @@ +import json +import time + + +class TimeCounter: + def __init__(self) -> None: + pass + + def clear(self): + self.timedict = {} + self.basetime = time.perf_counter() + + def timeit(self, name): + nowtime = time.perf_counter() - self.basetime + self.timedict[name] = nowtime + self.basetime = time.perf_counter() + + +class TimeHolder: + def __init__(self) -> None: + self.timedict = {} + + def update(self, _timedict: dict): + for k, v in _timedict.items(): + if k not in self.timedict: + self.timedict[k] = AverageMeter(name=k, val_only=True) + self.timedict[k].update(val=v) + + def final_res(self): + return {k: v.avg for k, v in self.timedict.items()} + + def __str__(self): + return json.dumps(self.final_res(), indent=2) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f", val_only=False): + self.name = name + self.fmt = fmt + self.val_only = val_only + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + if self.val_only: + fmtstr = "{name} {val" + self.fmt + "}" + else: + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) diff --git a/groundingdino/util/transforms.py b/groundingdino/util/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..980002328abcdca80953bcb37383891c54c7c432 --- /dev/null +++ b/groundingdino/util/transforms.py @@ -0,0 +1,312 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Transforms and data augmentation for both image + bbox. +""" +import os +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from groundingdino.util.box_ops import box_xyxy_to_cxcywh +from groundingdino.util.misc import interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd", "positive_map"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target["masks"] = target["masks"][:, i : i + h, j : j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target["boxes"].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target["masks"].flatten(1).any(1) + + for field in fields: + if field in target: + target[field] = target[field][keep] + + if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO": + # for debug and visualization only. + if "strings_positive" in target: + target["strings_positive"] = [ + _i for _i, _j in zip(target["strings_positive"], keep) if _j + ] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( + [w, 0, w, 0] + ) + target["boxes"] = boxes + + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target["masks"] = ( + interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + ) + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image.size[::-1]) + if "masks" in target: + target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class ResizeDebug(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + return resize(img, target, self.size) + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False): + # respect_boxes: True to keep all boxes + # False to tolerence box filter + self.min_size = min_size + self.max_size = max_size + self.respect_boxes = respect_boxes + + def __call__(self, img: PIL.Image.Image, target: dict): + init_boxes = len(target["boxes"]) + max_patience = 10 + for i in range(max_patience): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + result_img, result_target = crop(img, target, region) + if ( + not self.respect_boxes + or len(result_target["boxes"]) == init_boxes + or i == max_patience - 1 + ): + return result_img, result_target + return result_img, result_target + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + + return format_string diff --git a/groundingdino/util/utils.py b/groundingdino/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..390d43b6c41213fa844913e038005398ca5f4809 --- /dev/null +++ b/groundingdino/util/utils.py @@ -0,0 +1,607 @@ +import argparse +import json +import warnings +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Dict, List + +import numpy as np +import torch +from transformers import AutoTokenizer + + + +def slprint(x, name="x"): + if isinstance(x, (torch.Tensor, np.ndarray)): + print(f"{name}.shape:", x.shape) + elif isinstance(x, (tuple, list)): + print("type x:", type(x)) + for i in range(min(10, len(x))): + slprint(x[i], f"{name}[{i}]") + elif isinstance(x, dict): + for k, v in x.items(): + slprint(v, f"{name}[{k}]") + else: + print(f"{name}.type:", type(x)) + + +def clean_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k[:7] == "module.": + k = k[7:] # remove `module.` + new_state_dict[k] = v + return new_state_dict + + +def renorm( + img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +) -> torch.FloatTensor: + # img: tensor(3,H,W) or tensor(B,3,H,W) + # return: same as img + assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() + if img.dim() == 3: + assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % ( + img.size(0), + str(img.size()), + ) + img_perm = img.permute(1, 2, 0) + mean = torch.Tensor(mean) + std = torch.Tensor(std) + img_res = img_perm * std + mean + return img_res.permute(2, 0, 1) + else: # img.dim() == 4 + assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % ( + img.size(1), + str(img.size()), + ) + img_perm = img.permute(0, 2, 3, 1) + mean = torch.Tensor(mean) + std = torch.Tensor(std) + img_res = img_perm * std + mean + return img_res.permute(0, 3, 1, 2) + + +class CocoClassMapper: + def __init__(self) -> None: + self.category_map_str = { + "1": 1, + "2": 2, + "3": 3, + "4": 4, + "5": 5, + "6": 6, + "7": 7, + "8": 8, + "9": 9, + "10": 10, + "11": 11, + "13": 12, + "14": 13, + "15": 14, + "16": 15, + "17": 16, + "18": 17, + "19": 18, + "20": 19, + "21": 20, + "22": 21, + "23": 22, + "24": 23, + "25": 24, + "27": 25, + "28": 26, + "31": 27, + "32": 28, + "33": 29, + "34": 30, + "35": 31, + "36": 32, + "37": 33, + "38": 34, + "39": 35, + "40": 36, + "41": 37, + "42": 38, + "43": 39, + "44": 40, + "46": 41, + "47": 42, + "48": 43, + "49": 44, + "50": 45, + "51": 46, + "52": 47, + "53": 48, + "54": 49, + "55": 50, + "56": 51, + "57": 52, + "58": 53, + "59": 54, + "60": 55, + "61": 56, + "62": 57, + "63": 58, + "64": 59, + "65": 60, + "67": 61, + "70": 62, + "72": 63, + "73": 64, + "74": 65, + "75": 66, + "76": 67, + "77": 68, + "78": 69, + "79": 70, + "80": 71, + "81": 72, + "82": 73, + "84": 74, + "85": 75, + "86": 76, + "87": 77, + "88": 78, + "89": 79, + "90": 80, + } + self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()} + self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()} + + def origin2compact(self, idx): + return self.origin2compact_mapper[int(idx)] + + def compact2origin(self, idx): + return self.compact2origin_mapper[int(idx)] + + +def to_device(item, device): + if isinstance(item, torch.Tensor): + return item.to(device) + elif isinstance(item, list): + return [to_device(i, device) for i in item] + elif isinstance(item, dict): + return {k: to_device(v, device) for k, v in item.items()} + else: + raise NotImplementedError( + "Call Shilong if you use other containers! type: {}".format(type(item)) + ) + + +# +def get_gaussian_mean(x, axis, other_axis, softmax=True): + """ + + Args: + x (float): Input images(BxCxHxW) + axis (int): The index for weighted mean + other_axis (int): The other index + + Returns: weighted index for axis, BxC + + """ + mat2line = torch.sum(x, axis=other_axis) + # mat2line = mat2line / mat2line.mean() * 10 + if softmax: + u = torch.softmax(mat2line, axis=2) + else: + u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6) + size = x.shape[axis] + ind = torch.linspace(0, 1, size).to(x.device) + batch = x.shape[0] + channel = x.shape[1] + index = ind.repeat([batch, channel, 1]) + mean_position = torch.sum(index * u, dim=2) + return mean_position + + +def get_expected_points_from_map(hm, softmax=True): + """get_gaussian_map_from_points + B,C,H,W -> B,N,2 float(0, 1) float(0, 1) + softargmax function + + Args: + hm (float): Input images(BxCxHxW) + + Returns: + weighted index for axis, BxCx2. float between 0 and 1. + + """ + # hm = 10*hm + B, C, H, W = hm.shape + y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C + x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C + # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2) + return torch.stack([x_mean, y_mean], dim=2) + + +# Positional encoding (section 5.1) +# borrow from nerf +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs["input_dims"] + out_dim = 0 + if self.kwargs["include_input"]: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs["max_freq_log2"] + N_freqs = self.kwargs["num_freqs"] + + if self.kwargs["log_sampling"]: + freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs["periodic_fns"]: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, i=0): + import torch.nn as nn + + if i == -1: + return nn.Identity(), 3 + + embed_kwargs = { + "include_input": True, + "input_dims": 3, + "max_freq_log2": multires - 1, + "num_freqs": multires, + "log_sampling": True, + "periodic_fns": [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj: eo.embed(x) + return embed, embedder_obj.out_dim + + +class APOPMeter: + def __init__(self) -> None: + self.tp = 0 + self.fp = 0 + self.tn = 0 + self.fn = 0 + + def update(self, pred, gt): + """ + Input: + pred, gt: Tensor() + """ + assert pred.shape == gt.shape + self.tp += torch.logical_and(pred == 1, gt == 1).sum().item() + self.fp += torch.logical_and(pred == 1, gt == 0).sum().item() + self.tn += torch.logical_and(pred == 0, gt == 0).sum().item() + self.tn += torch.logical_and(pred == 1, gt == 0).sum().item() + + def update_cm(self, tp, fp, tn, fn): + self.tp += tp + self.fp += fp + self.tn += tn + self.tn += fn + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def get_raw_dict(args): + """ + return the dicf contained in args. + + e.g: + >>> with open(path, 'w') as f: + json.dump(get_raw_dict(args), f, indent=2) + """ + if isinstance(args, argparse.Namespace): + return vars(args) + elif isinstance(args, dict): + return args + # elif isinstance(args, SLConfig): + # return args._cfg_dict + else: + raise NotImplementedError("Unknown type {}".format(type(args))) + + +def stat_tensors(tensor): + assert tensor.dim() == 1 + tensor_sm = tensor.softmax(0) + entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum() + + return { + "max": tensor.max(), + "min": tensor.min(), + "mean": tensor.mean(), + "var": tensor.var(), + "std": tensor.var() ** 0.5, + "entropy": entropy, + } + + +class NiceRepr: + """Inherit from this class and define ``__nice__`` to "nicely" print your + objects. + + Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function + Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``. + If the inheriting class has a ``__len__``, method then the default + ``__nice__`` method will return its length. + + Example: + >>> class Foo(NiceRepr): + ... def __nice__(self): + ... return 'info' + >>> foo = Foo() + >>> assert str(foo) == '' + >>> assert repr(foo).startswith('>> class Bar(NiceRepr): + ... pass + >>> bar = Bar() + >>> import pytest + >>> with pytest.warns(None) as record: + >>> assert 'object at' in str(bar) + >>> assert 'object at' in repr(bar) + + Example: + >>> class Baz(NiceRepr): + ... def __len__(self): + ... return 5 + >>> baz = Baz() + >>> assert str(baz) == '' + """ + + def __nice__(self): + """str: a "nice" summary string describing this module""" + if hasattr(self, "__len__"): + # It is a common pattern for objects to use __len__ in __nice__ + # As a convenience we define a default __nice__ for these objects + return str(len(self)) + else: + # In all other cases force the subclass to overload __nice__ + raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}") + + def __repr__(self): + """str: the string of the module""" + try: + nice = self.__nice__() + classname = self.__class__.__name__ + return f"<{classname}({nice}) at {hex(id(self))}>" + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) + + def __str__(self): + """str: the string of the module""" + try: + classname = self.__class__.__name__ + nice = self.__nice__() + return f"<{classname}({nice})>" + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) + + +def ensure_rng(rng=None): + """Coerces input into a random number generator. + + If the input is None, then a global random state is returned. + + If the input is a numeric value, then that is used as a seed to construct a + random state. Otherwise the input is returned as-is. + + Adapted from [1]_. + + Args: + rng (int | numpy.random.RandomState | None): + if None, then defaults to the global rng. Otherwise this can be an + integer or a RandomState class + Returns: + (numpy.random.RandomState) : rng - + a numpy random number generator + + References: + .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501 + """ + + if rng is None: + rng = np.random.mtrand._rand + elif isinstance(rng, int): + rng = np.random.RandomState(rng) + else: + rng = rng + return rng + + +def random_boxes(num=1, scale=1, rng=None): + """Simple version of ``kwimage.Boxes.random`` + + Returns: + Tensor: shape (n, 4) in x1, y1, x2, y2 format. + + References: + https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 + + Example: + >>> num = 3 + >>> scale = 512 + >>> rng = 0 + >>> boxes = random_boxes(num, scale, rng) + >>> print(boxes) + tensor([[280.9925, 278.9802, 308.6148, 366.1769], + [216.9113, 330.6978, 224.0446, 456.5878], + [405.3632, 196.3221, 493.3953, 270.7942]]) + """ + rng = ensure_rng(rng) + + tlbr = rng.rand(num, 4).astype(np.float32) + + tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2]) + tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3]) + br_x = np.maximum(tlbr[:, 0], tlbr[:, 2]) + br_y = np.maximum(tlbr[:, 1], tlbr[:, 3]) + + tlbr[:, 0] = tl_x * scale + tlbr[:, 1] = tl_y * scale + tlbr[:, 2] = br_x * scale + tlbr[:, 3] = br_y * scale + + boxes = torch.from_numpy(tlbr) + return boxes + + +class ModelEma(torch.nn.Module): + def __init__(self, model, decay=0.9997, device=None): + super(ModelEma, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + + # import ipdb; ipdb.set_trace() + + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip( + self.module.state_dict().values(), model.state_dict().values() + ): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) + + +class BestMetricSingle: + def __init__(self, init_res=0.0, better="large") -> None: + self.init_res = init_res + self.best_res = init_res + self.best_ep = -1 + + self.better = better + assert better in ["large", "small"] + + def isbetter(self, new_res, old_res): + if self.better == "large": + return new_res > old_res + if self.better == "small": + return new_res < old_res + + def update(self, new_res, ep): + if self.isbetter(new_res, self.best_res): + self.best_res = new_res + self.best_ep = ep + return True + return False + + def __str__(self) -> str: + return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep) + + def __repr__(self) -> str: + return self.__str__() + + def summary(self) -> dict: + return { + "best_res": self.best_res, + "best_ep": self.best_ep, + } + + +class BestMetricHolder: + def __init__(self, init_res=0.0, better="large", use_ema=False) -> None: + self.best_all = BestMetricSingle(init_res, better) + self.use_ema = use_ema + if use_ema: + self.best_ema = BestMetricSingle(init_res, better) + self.best_regular = BestMetricSingle(init_res, better) + + def update(self, new_res, epoch, is_ema=False): + """ + return if the results is the best. + """ + if not self.use_ema: + return self.best_all.update(new_res, epoch) + else: + if is_ema: + self.best_ema.update(new_res, epoch) + return self.best_all.update(new_res, epoch) + else: + self.best_regular.update(new_res, epoch) + return self.best_all.update(new_res, epoch) + + def summary(self): + if not self.use_ema: + return self.best_all.summary() + + res = {} + res.update({f"all_{k}": v for k, v in self.best_all.summary().items()}) + res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()}) + res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()}) + return res + + def __repr__(self) -> str: + return json.dumps(self.summary(), indent=2) + + def __str__(self) -> str: + return self.__repr__() + + +def targets_to(targets: List[Dict[str, Any]], device): + """Moves the target dicts to the given device.""" + excluded_keys = [ + "questionId", + "tokens_positive", + "strings_positive", + "tokens", + "dataset_name", + "sentence_id", + "original_img_id", + "nb_eval", + "task_id", + "original_id", + "token_span", + "caption", + "dataset_type", + ] + return [ + {k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets + ] + + +def get_phrases_from_posmap( + posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer +): + assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" + if posmap.dim() == 1: + non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() + token_ids = [tokenized["input_ids"][i] for i in non_zero_idx] + return tokenizer.decode(token_ids) + else: + raise NotImplementedError("posmap must be 1-dim") diff --git a/groundingdino/util/visualizer.py b/groundingdino/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1b7b101e9b73f75f9136bc67f2063c7c1cf1c1 --- /dev/null +++ b/groundingdino/util/visualizer.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +""" +@File : visualizer.py +@Time : 2022/04/05 11:39:33 +@Author : Shilong Liu +@Contact : slongliu86@gmail.com +""" + +import datetime +import os + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import transforms +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon +from pycocotools import mask as maskUtils + + +def renorm( + img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +) -> torch.FloatTensor: + # img: tensor(3,H,W) or tensor(B,3,H,W) + # return: same as img + assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() + if img.dim() == 3: + assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % ( + img.size(0), + str(img.size()), + ) + img_perm = img.permute(1, 2, 0) + mean = torch.Tensor(mean) + std = torch.Tensor(std) + img_res = img_perm * std + mean + return img_res.permute(2, 0, 1) + else: # img.dim() == 4 + assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % ( + img.size(1), + str(img.size()), + ) + img_perm = img.permute(0, 2, 3, 1) + mean = torch.Tensor(mean) + std = torch.Tensor(std) + img_res = img_perm * std + mean + return img_res.permute(0, 3, 1, 2) + + +class ColorMap: + def __init__(self, basergb=[255, 255, 0]): + self.basergb = np.array(basergb) + + def __call__(self, attnmap): + # attnmap: h, w. np.uint8. + # return: h, w, 4. np.uint8. + assert attnmap.dtype == np.uint8 + h, w = attnmap.shape + res = self.basergb.copy() + res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3 + attn1 = attnmap.copy()[..., None] # h, w, 1 + res = np.concatenate((res, attn1), axis=-1).astype(np.uint8) + return res + + +def rainbow_text(x, y, ls, lc, **kw): + """ + Take a list of strings ``ls`` and colors ``lc`` and place them next to each + other, with text ls[i] being shown in color lc[i]. + + This example shows how to do both vertical and horizontal text, and will + pass all keyword arguments to plt.text, so you can set the font size, + family, etc. + """ + t = plt.gca().transData + fig = plt.gcf() + plt.show() + + # horizontal version + for s, c in zip(ls, lc): + text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw) + text.draw(fig.canvas.get_renderer()) + ex = text.get_window_extent() + t = transforms.offset_copy(text._transform, x=ex.width, units="dots") + + # #vertical version + # for s,c in zip(ls,lc): + # text = plt.text(x,y," "+s+" ",color=c, transform=t, + # rotation=90,va='bottom',ha='center',**kw) + # text.draw(fig.canvas.get_renderer()) + # ex = text.get_window_extent() + # t = transforms.offset_copy(text._transform, y=ex.height, units='dots') + + +class COCOVisualizer: + def __init__(self, coco=None, tokenlizer=None) -> None: + self.coco = coco + + def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"): + """ + img: tensor(3, H, W) + tgt: make sure they are all on cpu. + must have items: 'image_id', 'boxes', 'size' + """ + plt.figure(dpi=dpi) + plt.rcParams["font.size"] = "5" + ax = plt.gca() + img = renorm(img).permute(1, 2, 0) + # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': + # import ipdb; ipdb.set_trace() + ax.imshow(img) + + self.addtgt(tgt) + + if tgt is None: + image_id = 0 + elif "image_id" not in tgt: + image_id = 0 + else: + image_id = tgt["image_id"] + + if caption is None: + savename = "{}/{}-{}.png".format( + savedir, int(image_id), str(datetime.datetime.now()).replace(" ", "-") + ) + else: + savename = "{}/{}-{}-{}.png".format( + savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-") + ) + print("savename: {}".format(savename)) + os.makedirs(os.path.dirname(savename), exist_ok=True) + plt.savefig(savename) + plt.close() + + def addtgt(self, tgt): + """ """ + if tgt is None or not "boxes" in tgt: + ax = plt.gca() + + if "caption" in tgt: + ax.set_title(tgt["caption"], wrap=True) + + ax.set_axis_off() + return + + ax = plt.gca() + H, W = tgt["size"] + numbox = tgt["boxes"].shape[0] + + color = [] + polygons = [] + boxes = [] + for box in tgt["boxes"].cpu(): + unnormbbox = box * torch.Tensor([W, H, W, H]) + unnormbbox[:2] -= unnormbbox[2:] / 2 + [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist() + boxes.append([bbox_x, bbox_y, bbox_w, bbox_h]) + poly = [ + [bbox_x, bbox_y], + [bbox_x, bbox_y + bbox_h], + [bbox_x + bbox_w, bbox_y + bbox_h], + [bbox_x + bbox_w, bbox_y], + ] + np_poly = np.array(poly).reshape((4, 2)) + polygons.append(Polygon(np_poly)) + c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0] + color.append(c) + + p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1) + ax.add_collection(p) + p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) + ax.add_collection(p) + + if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0: + assert ( + len(tgt["strings_positive"]) == numbox + ), f"{len(tgt['strings_positive'])} = {numbox}, " + for idx, strlist in enumerate(tgt["strings_positive"]): + cate_id = int(tgt["labels"][idx]) + _string = str(cate_id) + ":" + " ".join(strlist) + bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx] + # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1}) + ax.text( + bbox_x, + bbox_y, + _string, + color="black", + bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1}, + ) + + if "box_label" in tgt: + assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, " + for idx, bl in enumerate(tgt["box_label"]): + _string = str(bl) + bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx] + # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1}) + ax.text( + bbox_x, + bbox_y, + _string, + color="black", + bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1}, + ) + + if "caption" in tgt: + ax.set_title(tgt["caption"], wrap=True) + # plt.figure() + # rainbow_text(0.0,0.0,"all unicorns poop rainbows ! ! !".split(), + # ['red', 'orange', 'brown', 'green', 'blue', 'purple', 'black']) + + if "attn" in tgt: + # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': + # import ipdb; ipdb.set_trace() + if isinstance(tgt["attn"], tuple): + tgt["attn"] = [tgt["attn"]] + for item in tgt["attn"]: + attn_map, basergb = item + attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3) + attn_map = (attn_map * 255).astype(np.uint8) + cm = ColorMap(basergb) + heatmap = cm(attn_map) + ax.imshow(heatmap) + ax.set_axis_off() + + def showAnns(self, anns, draw_bbox=False): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + if "segmentation" in anns[0] or "keypoints" in anns[0]: + datasetType = "instances" + elif "caption" in anns[0]: + datasetType = "captions" + else: + raise Exception("datasetType not supported") + if datasetType == "instances": + ax = plt.gca() + ax.set_autoscale_on(False) + polygons = [] + color = [] + for ann in anns: + c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0] + if "segmentation" in ann: + if type(ann["segmentation"]) == list: + # polygon + for seg in ann["segmentation"]: + poly = np.array(seg).reshape((int(len(seg) / 2), 2)) + polygons.append(Polygon(poly)) + color.append(c) + else: + # mask + t = self.imgs[ann["image_id"]] + if type(ann["segmentation"]["counts"]) == list: + rle = maskUtils.frPyObjects( + [ann["segmentation"]], t["height"], t["width"] + ) + else: + rle = [ann["segmentation"]] + m = maskUtils.decode(rle) + img = np.ones((m.shape[0], m.shape[1], 3)) + if ann["iscrowd"] == 1: + color_mask = np.array([2.0, 166.0, 101.0]) / 255 + if ann["iscrowd"] == 0: + color_mask = np.random.random((1, 3)).tolist()[0] + for i in range(3): + img[:, :, i] = color_mask[i] + ax.imshow(np.dstack((img, m * 0.5))) + if "keypoints" in ann and type(ann["keypoints"]) == list: + # turn skeleton into zero-based index + sks = np.array(self.loadCats(ann["category_id"])[0]["skeleton"]) - 1 + kp = np.array(ann["keypoints"]) + x = kp[0::3] + y = kp[1::3] + v = kp[2::3] + for sk in sks: + if np.all(v[sk] > 0): + plt.plot(x[sk], y[sk], linewidth=3, color=c) + plt.plot( + x[v > 0], + y[v > 0], + "o", + markersize=8, + markerfacecolor=c, + markeredgecolor="k", + markeredgewidth=2, + ) + plt.plot( + x[v > 1], + y[v > 1], + "o", + markersize=8, + markerfacecolor=c, + markeredgecolor=c, + markeredgewidth=2, + ) + + if draw_bbox: + [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"] + poly = [ + [bbox_x, bbox_y], + [bbox_x, bbox_y + bbox_h], + [bbox_x + bbox_w, bbox_y + bbox_h], + [bbox_x + bbox_w, bbox_y], + ] + np_poly = np.array(poly).reshape((4, 2)) + polygons.append(Polygon(np_poly)) + color.append(c) + + # p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) + # ax.add_collection(p) + p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) + ax.add_collection(p) + elif datasetType == "captions": + for ann in anns: + print(ann["caption"]) diff --git a/groundingdino/util/vl_utils.py b/groundingdino/util/vl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c91bb02f584398f08a28e6b7719e2b99f6e28616 --- /dev/null +++ b/groundingdino/util/vl_utils.py @@ -0,0 +1,100 @@ +import os +import random +from typing import List + +import torch + + +def create_positive_map_from_span(tokenized, token_span, max_text_len=256): + """construct a map such that positive_map[i,j] = True iff box i is associated to token j + Input: + - tokenized: + - input_ids: Tensor[1, ntokens] + - attention_mask: Tensor[1, ntokens] + - token_span: list with length num_boxes. + - each item: [start_idx, end_idx] + """ + positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float) + for j, tok_list in enumerate(token_span): + for (beg, end) in tok_list: + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE": + positive_map[j, beg_pos] = 1 + break + else: + positive_map[j, beg_pos : end_pos + 1].fill_(1) + + return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) + + +def build_captions_and_token_span(cat_list, force_lowercase): + """ + Return: + captions: str + cat2tokenspan: dict + { + 'dog': [[0, 2]], + ... + } + """ + + cat2tokenspan = {} + captions = "" + for catname in cat_list: + class_name = catname + if force_lowercase: + class_name = class_name.lower() + if "/" in class_name: + class_name_list: List = class_name.strip().split("/") + class_name_list.append(class_name) + class_name: str = random.choice(class_name_list) + + tokens_positive_i = [] + subnamelist = [i.strip() for i in class_name.strip().split(" ")] + for subname in subnamelist: + if len(subname) == 0: + continue + if len(captions) > 0: + captions = captions + " " + strat_idx = len(captions) + end_idx = strat_idx + len(subname) + tokens_positive_i.append([strat_idx, end_idx]) + captions = captions + subname + + if len(tokens_positive_i) > 0: + captions = captions + " ." + cat2tokenspan[class_name] = tokens_positive_i + + return captions, cat2tokenspan + + +def build_id2posspan_and_caption(category_dict: dict): + """Build id2pos_span and caption from category_dict + + Args: + category_dict (dict): category_dict + """ + cat_list = [item["name"].lower() for item in category_dict] + id2catname = {item["id"]: item["name"].lower() for item in category_dict} + caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True) + id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()} + return id2posspan, caption diff --git a/imagebind/__init__.py b/imagebind/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagebind/data/__init__.py b/imagebind/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagebind/data/data_utils.py b/imagebind/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e04e6a8966bb7ac57ea64e44313fda8cc7cb3a --- /dev/null +++ b/imagebind/data/data_utils.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torchaudio +import logging + +import torchvision + +from imagebind.models.multimodal_preprocessors import SimpleTokenizer +from PIL import Image +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler +from pytorchvideo.data.encoded_video import EncodedVideo + +from torchvision import transforms +from torchvision.transforms._transforms_video import NormalizeVideo + +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds + +BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz" + + +def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): + # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102 + waveform -= waveform.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform, + htk_compat=True, + sample_frequency=sample_rate, + use_energy=False, + window_type="hanning", + num_mel_bins=num_mel_bins, + dither=0.0, + frame_length=25, + frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, + ) + # Convert to [mel_bins, num_frames] shape + fbank = fbank.transpose(0, 1) + # Pad to target_length + n_frames = fbank.size(1) + p = target_length - n_frames + # if p is too large (say >20%), flash a warning + # if abs(p) / n_frames > 0.2: + # logging.warning( + # "Large gap between audio n_frames(%d) and " + # "target_length (%d). Is the audio_target_length " + # "setting correct?", + # n_frames, + # target_length, + # ) + # cut and pad + if p > 0: + fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0) + fbank = fbank.unsqueeze(0) + elif p < 0: + # fbank = fbank[:, 0:target_length] + # NOTE: Modified to compatible with longer clips + fbank = fbank.unsqueeze(0) + fbank = torchvision.transforms.Resize(size=[num_mel_bins, target_length])(fbank) + # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image + return fbank + + +def load_and_transform_vision_data(image_paths, device): + if image_paths is None: + return None + + image_ouputs = [] + for image_path in image_paths: + data_transform = transforms.Compose( + [ + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + with open(image_path, "rb") as fopen: + image = Image.open(fopen).convert("RGB") + + image = data_transform(image).to(device) + image_ouputs.append(image) + return torch.stack(image_ouputs, dim=0) + + +def load_and_transform_text(text, device): + if text is None: + return None + tokenizer = SimpleTokenizer(bpe_path=BPE_PATH) + tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text] + tokens = torch.cat(tokens, dim=0) + return tokens + + +def load_and_transform_audio_data( + audio_paths, + device, + num_mel_bins=128, + target_length=204, + sample_rate=16000, + clip_duration=2, + clips_per_video=3, + mean=-4.268, + std=9.138, +): + if audio_paths is None: + return None + + audio_outputs = [] + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + + for audio_path in audio_paths: + waveform, sr = torchaudio.load(audio_path) + if sample_rate != sr: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sr, new_freq=sample_rate + ) + all_clips_timepoints = get_constant_clip_timepoints( + clip_sampler, waveform.size(1) / sample_rate + ) + all_clips = [] + for clip_timepoints in all_clips_timepoints: + waveform_clip = waveform[ + :, + int(clip_timepoints[0] * sample_rate): int( + clip_timepoints[1] * sample_rate + ), + ] + waveform_melspec = waveform2melspec( + waveform_clip, sample_rate, num_mel_bins, target_length + ) + all_clips.append(waveform_melspec) + + normalize = transforms.Normalize(mean=mean, std=std) + all_clips = [normalize(ac).to(device) for ac in all_clips] + + all_clips = torch.stack(all_clips, dim=0) + audio_outputs.append(all_clips) + + return torch.stack(audio_outputs, dim=0) + + +def get_constant_clip_timepoints(clip_sampler, duration): + assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!" + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + +def get_random_clip_timepoints(clip_sampler, duration): + assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!" + starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None) + all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0]) + return all_clips_timepoints + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Perform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to perform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +class SpatialCrop(nn.Module): + """ + Convert the video into 3 smaller clips spatially. Must be used after the + temporal crops to get spatial crops, and should be used with + -2 in the spatial crop at the slowfast augmentation stage (so full + frames are passed in here). Will return a larger list with the + 3x spatial crops as well. + """ + + def __init__(self, crop_size: int = 224, num_crops: int = 3): + super().__init__() + self.crop_size = crop_size + if num_crops == 3: + self.crops_to_ext = [0, 1, 2] + self.flipped_crops_to_ext = [] + elif num_crops == 1: + self.crops_to_ext = [1] + self.flipped_crops_to_ext = [] + else: + raise NotImplementedError("Nothing else supported yet") + + def forward(self, videos): + """ + Args: + videos: A list of C, T, H, W videos. + Returns: + videos: A list with 3x the number of elements. Each video converted + to C, T, H', W' by spatial cropping. + """ + assert isinstance(videos, list), "Must be a list of videos after temporal crops" + assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" + res = [] + for video in videos: + for spatial_idx in self.crops_to_ext: + res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) + if not self.flipped_crops_to_ext: + continue + flipped_video = transforms.functional.hflip(video) + for spatial_idx in self.flipped_crops_to_ext: + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) + return res + + +def load_and_transform_video_data( + video_paths, + device, + clip_duration=2, + clips_per_video=5, + sample_rate=16000, +): + if video_paths is None: + return None + + video_outputs = [] + video_transform = transforms.Compose( + [ + pv_transforms.ShortSideScale(224), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) + + for video_path in video_paths: + video = EncodedVideo.from_path( + video_path, + decoder="decord", + decode_audio=False, + **{"sample_rate": sample_rate}, + ) + + all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + + all_video = [video_transform(clip) for clip in all_video] + all_video = SpatialCrop(224, num_crops=3)(all_video) + + all_video = torch.stack(all_video, dim=0) + video_outputs.append(all_video) + + return torch.stack(video_outputs, dim=0).to(device) diff --git a/imagebind/models/__init__.py b/imagebind/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagebind/models/helper.py b/imagebind/models/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..514ea46ceeef212cf642d5362cff2670ac344a59 --- /dev/null +++ b/imagebind/models/helper.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import einops +import numpy as np +import torch + +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=self.dim, p=2) + + +class LearnableLogitScaling(nn.Module): + def __init__( + self, + logit_scale_init: float = 1 / 0.07, + learnable: bool = True, + max_logit_scale: float = 100, + ) -> None: + super().__init__() + self.max_logit_scale = max_logit_scale + self.logit_scale_init = logit_scale_init + self.learnable = learnable + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, x): + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x + + def extra_repr(self): + st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" + return st + + +class EinOpsRearrange(nn.Module): + def __init__(self, rearrange_expr: str, **kwargs) -> None: + super().__init__() + self.rearrange_expr = rearrange_expr + self.kwargs = kwargs + + def forward(self, x): + assert isinstance(x, torch.Tensor) + return einops.rearrange(x, self.rearrange_expr, **self.kwargs) + + +class VerboseNNModule(nn.Module): + """ + Wrapper around nn.Module that prints registered buffers and parameter names. + """ + + @staticmethod + def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: + st = ( + "(" + + name + + "): " + + "tensor(" + + str(tuple(tensor[1].shape)) + + ", requires_grad=" + + str(tensor[1].requires_grad) + + ")\n" + ) + return st + + def extra_repr(self) -> str: + named_modules = set() + for p in self.named_modules(): + named_modules.update([p[0]]) + named_modules = list(named_modules) + + string_repr = "" + for p in self.named_parameters(): + name = p[0].split(".")[0] + if name not in named_modules: + string_repr += self.get_readable_tensor_repr(name, p) + + for p in self.named_buffers(): + name = p[0].split(".")[0] + string_repr += self.get_readable_tensor_repr(name, p) + + return string_repr + + +def cast_if_src_dtype( + tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype +): + updated = False + if tensor.dtype == src_dtype: + tensor = tensor.to(dtype=tgt_dtype) + updated = True + return tensor, updated + + +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class SelectElement(nn.Module): + def __init__(self, index) -> None: + super().__init__() + self.index = index + + def forward(self, x): + assert x.ndim >= 3 + return x[:, self.index, ...] + + +class SelectEOSAndProject(nn.Module): + """ + Text Pooling used in OpenCLIP + """ + + def __init__(self, proj: nn.Module) -> None: + super().__init__() + self.proj = proj + + def forward(self, x, seq_len): + assert x.ndim == 3 + # x is of shape B x L x D + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), seq_len] + x = self.proj(x) + return x \ No newline at end of file diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py new file mode 100644 index 0000000000000000000000000000000000000000..37ba2e2b1e4716e69d7f146de742270bbd31b0e1 --- /dev/null +++ b/imagebind/models/image_bind.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from functools import partial +from types import SimpleNamespace +from typing import Union, Optional, Tuple, Dict, List + +import torch +import torch.nn as nn +from torch import Tensor +from omegaconf import DictConfig + +from imagebind.models.helper import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train +from imagebind.models.multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, + BlipPreprocessor, +) +from imagebind.models.multimodal_projectors import create_projectors + +from imagebind.models.transformer import MultiheadAttention, SimpleTransformer + +ModalityType = SimpleNamespace( + VISION="vision", + TEXT="text", + AUDIO="audio", + THERMAL="thermal", + DEPTH="depth", + IMU="imu", +) + + +class ImageBindJoiner(nn.Module): + def __init__(self, cfg: DictConfig, output_dim: int): + super().__init__() + """ + cfg: + - share_key: Optional[str] + - modality_key: DictConfig, the modality cfg for the corresponding key + - feat_dim: int, defaults to 1024, the input dimension to the qformer + - post_dims: tuple, defaults to (768,), layers for post-qformer projection + - pre_dims: tuple, defaults to (), layers for pre-qformer projection + - num_query_token: int, defaults to 32, the numbher of query tokens in qformer + - freeze_qformer: bool, defaults to true, keeping the qformer frozen or not + - qformer_model: str, defaults to "", path to the checkpoint of a qformer, "" for not loading + - modality_key ... + """ + # vision_qformer_model is always "" + # assert not (vision_qformer_frozen and vision_qformer_model == "") + self.share_key = share_key = cfg.get('share_key', None) + self.use_pre_ln = cfg.pop('use_pre_ln') if 'use_pre_ln' in cfg else False + if share_key is not None and isinstance(share_key, str): + self.share_joiner = True + cfg.pop("share_key") + assert share_key in cfg, "The modality key to share does not exist." + # assert len(cfg) == 1, "Only one config is needed for shared joiner." + else: + self.share_joiner = False + + for modality_cfg in cfg.values(): + modality_cfg.pre_dims = modality_cfg.get("pre_dims", ()) + modality_cfg.post_dims = modality_cfg.get("post_dims", (768,)) + modality_cfg.num_query_token = modality_cfg.get("num_query_token", 32) + modality_cfg.freeze_qformer = modality_cfg.get("freeze_qformer", True) + modality_cfg.qformer_model = modality_cfg.get("qformer_model", "") + modality_cfg.freeze_post = modality_cfg.get("freeze_post", False) + + if self.use_pre_ln: + self.modality_pre_lns = self._create_modality_pre_lns(cfg) + self.modality_pre_projectors = self._create_modality_pre_projectors(cfg) + self.modality_qformers = self._create_modality_qformers(cfg) + self.modality_post_projectors = self._create_modality_post_projectors(cfg, output_dim) + + def _create_modality_pre_lns(cfg): + lns = {} + for modality, modality_cfg in cfg.items(): + lns[modality] = nn.LayerNorm(cfg.feat_dim) + return nn.ModuleDict(lns) + + def _create_modality_pre_projectors(self, cfg): + projectors = {} + for modality, modality_cfg in cfg.items(): + projectors[modality] = create_projectors(tuple(modality_cfg.pre_dims)) + return nn.ModuleDict(projectors) + + def _create_modality_post_projectors(self, cfg, output_dim): + projectors = {} + for modality, modality_cfg in cfg.items(): + projectors[modality] = create_projectors(tuple(modality_cfg.post_dims) + (output_dim,)) + if modality_cfg.freeze_post: + for p in projectors[modality].parameters(): + p.requires_grad = False + return nn.ModuleDict(projectors) + + def _create_modality_qformers(self, cfg): + modality_qformers = {} + for modality, modality_cfg in cfg.items(): + modality_qformers[modality] = SequenceGenericQFormer( + num_query_token=modality_cfg.num_query_token, + freeze_qformer=modality_cfg.freeze_qformer, + encoder_width=modality_cfg.feat_dim, + q_former_model=modality_cfg.get("qformer_model", ""), + ) + return nn.ModuleDict(modality_qformers) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + outputs = {} + for modality_key, modality_value in inputs.items(): + model_key = self.share_key if self.share_joiner else modality_key + if modality_value is not None: + if self.use_pre_ln: + modality_value = self.modality_pre_lns[modality_key](modality_value) + modality_value = self.modality_pre_projectors[model_key](modality_value) + modality_value = self.modality_qformers[model_key](modality_value) + modality_value = self.modality_post_projectors[model_key](modality_value) + outputs[modality_key] = modality_value + return outputs + + +class ImageBindModel(nn.Module): + def __init__( + self, + video_frames=2, + kernel_size=(2, 14, 14), + audio_kernel_size=16, + audio_stride=10, + out_embed_dim=768, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_num_mel_bins=128, + audio_target_len=204, + audio_drop_path=0.1, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + depth_embed_dim=384, + depth_kernel_size=16, + depth_num_blocks=12, + depth_num_heads=8, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_kernel_size=16, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_kernel_size=8, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + with_head=True, + ): + super().__init__() + self.with_head = with_head + + self.modality_preprocessors = self._create_modality_preprocessors( + video_frames, + vision_embed_dim, + kernel_size, + text_embed_dim, + audio_embed_dim, + audio_kernel_size, + audio_stride, + audio_num_mel_bins, + audio_target_len, + depth_embed_dim, + depth_kernel_size, + thermal_embed_dim, + thermal_kernel_size, + imu_embed_dim, + ) + + self.modality_trunks = self._create_modality_trunks( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + text_embed_dim, + text_num_blocks, + text_num_heads, + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + audio_drop_path, + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + depth_drop_path, + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + thermal_drop_path, + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + imu_drop_path, + ) + + self.modality_heads = self._create_modality_heads( + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ) + + self.modality_postprocessors = self._create_modality_postprocessors( + out_embed_dim + ) + + def _create_modality_preprocessors( + self, + video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = { + ModalityType.VISION: rgbt_preprocessor, + ModalityType.TEXT: text_preprocessor, + ModalityType.AUDIO: audio_preprocessor, + ModalityType.DEPTH: depth_preprocessor, + ModalityType.THERMAL: thermal_preprocessor, + ModalityType.IMU: imu_preprocessor, + } + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity(), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + use_selection=False, + ): + modality_heads = {} + + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0) if use_selection else nn.Identity(), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0) if use_selection else nn.Identity(), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0) if use_selection else nn.Identity(), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0) if use_selection else nn.Identity(), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0) if use_selection else nn.Identity(), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + return nn.ModuleDict(modality_postprocessors) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + head_inputs = modality_value["head"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + + # NOTE: No heads are needed any more. + if self.with_head: + modality_value = self.modality_heads[modality_key]( + modality_value, **head_inputs + ) + + modality_value = self.modality_postprocessors[modality_key]( + modality_value + ) + + # NOTE: The reduction operation has been modified. + if reduce_list: + modality_value = modality_value.reshape(B, S, *modality_value.shape[1:]) + modality_value = modality_value.mean(dim=1) + + outputs[modality_key] = modality_value + + return outputs + + +def imagebind_huge(pretrained=False, freeze_imagebind=False, with_head=True, use_blip_vision=False): + model = ImageBindModel( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + with_head=with_head, + ) + + if pretrained: + if not os.path.exists(".checkpoints/imagebind_huge.pth"): + print( + "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..." + ) + os.makedirs(".checkpoints", exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + ".checkpoints/imagebind_huge.pth", + progress=True, + ) + + model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth")) + + if use_blip_vision: + from bubogpt.models.eva_vit import create_eva_vit_g + visual_encoder = create_eva_vit_g( + img_size=224, drop_path_rate=0., use_checkpoint=False, precision='fp16' + ) + vision_ln = LayerNorm(visual_encoder.num_features) + vision_ln.load_state_dict(load_ln_params()) + model.modality_preprocessors[ModalityType.VISION] = BlipPreprocessor() + model.modality_trunks[ModalityType.VISION] = visual_encoder + model.modality_postprocessors[ModalityType.VISION] = vision_ln + + if freeze_imagebind: + for name, param in model.named_parameters(): + param.requires_grad = False + model = model.eval() + model.train = disabled_train + + return model + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def load_ln_params(path="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"): + state_dict = torch.load(path, map_location="cpu")["model"] + params = type(state_dict)() + params["weight"] = state_dict["ln_vision.weight"] + params["bias"] = state_dict["ln_vision.bias"] + return params + + +def replace_joiner_vision(joiner, q_former_model, proj_model): + assert isinstance(joiner.modality_pre_projectors.vision, nn.Identity) + + joiner.modality_qformers[ModalityType.VISION].load_Qformer(q_former_model) + + state_dict = torch.load(proj_model, map_location="cpu")["model"] + params = type(state_dict)() + params["fc.weight"] = state_dict["llama_proj.weight"] + params["fc.bias"] = state_dict["llama_proj.bias"] + joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False) diff --git a/imagebind/models/multimodal_formers.py b/imagebind/models/multimodal_formers.py new file mode 100644 index 0000000000000000000000000000000000000000..f571fac4f6b2d4b351086eb70352bbc473bad4fa --- /dev/null +++ b/imagebind/models/multimodal_formers.py @@ -0,0 +1,110 @@ +import logging +import os + +import torch +from torch import nn, Tensor + +from bubogpt.common.dist_utils import download_cached_file +from bubogpt.common.utils import is_url +from bubogpt.models.Qformer import BertConfig, BertLMHeadModel + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class BaseQFormer(nn.Module): + def __init__(self, freeze_qformer=False): + super().__init__() + self.freeze_qformer = freeze_qformer + self.Qformer = None + + def check_and_freeze(self): + assert self.Qformer is not None + if self.freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + self.query_tokens.requires_grad = False + logging.info("Freeze This QFormer") + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +class SequenceGenericQFormer(BaseQFormer): + def __init__(self, + num_query_token: int, + encoder_width: int = 768, + freeze_qformer: bool = False, + q_former_model: str = "", + cross_attention_freq: int = 2 + ): + super().__init__(freeze_qformer) + self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq) + if q_former_model != "": + self.load_Qformer(q_former_model) + self.check_and_freeze() + + def set_Qformer(self): + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + def load_Qformer(self, q_former_model): + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + + @classmethod + def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = encoder_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + def forward(self, input_embeds: Tensor) -> Tensor: + input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) + query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=input_embeds, + encoder_attention_mask=input_atts, + return_dict=True, + ) + return query_output.last_hidden_state diff --git a/imagebind/models/multimodal_preprocessors.py b/imagebind/models/multimodal_preprocessors.py new file mode 100644 index 0000000000000000000000000000000000000000..0938d6be3e87b37f407a7949b7952655d8e1a083 --- /dev/null +++ b/imagebind/models/multimodal_preprocessors.py @@ -0,0 +1,698 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional + +import ftfy + +import numpy as np +import regex as re +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ +from imagebind.models.helper import VerboseNNModule, cast_if_src_dtype + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + return pos_embed + + assert ( + patches_layout[-1] == patches_layout[-2] + ), "Interpolation of pos embed not supported for non-square layouts" + + class_emb = pos_embed[:, :first_patch_idx] + pos_embed = pos_embed[:, first_patch_idx:] + + if input_shape is None or patches_layout[0] == 1: + # simple 2D pos embedding, no temporal component + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) + elif patches_layout[0] > 1: + # pos embed has a temporal component + assert len(input_shape) == 4, "temporal interpolation not supported" + # we only support 2D interpolation in this case + num_frames = patches_layout[0] + num_spatial_tokens = patches_layout[1] * patches_layout[2] + pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) + # interpolate embedding for zeroth frame + pos_embed = interpolate_pos_encoding_2d( + npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) + ) + else: + raise ValueError("This type of interpolation isn't implemented") + + return torch.cat((class_emb, pos_embed), dim=1) + + +def _get_pos_embedding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape, + first_patch_idx=1, +): + pos_embed = interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=input_shape, + first_patch_idx=first_patch_idx, + ) + return pos_embed + + +class PatchEmbedGeneric(nn.Module): + """ + PatchEmbed from Hydra + """ + + def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): + super().__init__() + + if len(proj_stem) > 1: + self.proj = nn.Sequential(*proj_stem) + else: + # Special case to be able to load pre-trained models that were + # trained with a standard stem + self.proj = proj_stem[0] + self.norm_layer = norm_layer + + def get_patch_layout(self, img_size): + with torch.no_grad(): + dummy_img = torch.zeros( + [ + 1, + ] + + img_size + ) + dummy_out = self.proj(dummy_img) + embed_dim = dummy_out.shape[1] + patches_layout = tuple(dummy_out.shape[2:]) + num_patches = np.prod(patches_layout) + return patches_layout, num_patches, embed_dim + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + x = x.flatten(2).transpose(1, 2) + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): + def __init__( + self, + patches_layout: List, + num_patches: int, + num_cls_tokens: int, + embed_dim: int, + learnable: bool, + ) -> None: + super().__init__() + self.num_cls_tokens = num_cls_tokens + self.patches_layout = patches_layout + self.num_patches = num_patches + self.num_tokens = num_cls_tokens + num_patches + self.learnable = learnable + if self.learnable: + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + else: + self.register_buffer( + "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) + ) + + def get_pos_embedding(self, vision_input, all_vision_tokens): + input_shape = vision_input.shape + pos_embed = _get_pos_embedding( + all_vision_tokens.size(1) - self.num_cls_tokens, + pos_embed=self.pos_embed, + patches_layout=self.patches_layout, + input_shape=input_shape, + first_patch_idx=self.num_cls_tokens, + ) + return pos_embed + +class BlipPreprocessor(VerboseNNModule): + def __init__(self) -> None: + super().__init__() + + def forward(self, vision=None): + return_dict = { + "trunk": { + "x": vision, + }, + "head": {}, + } + return return_dict + +class RGBDTPreprocessor(VerboseNNModule): + def __init__( + self, + rgbt_stem: PatchEmbedGeneric, + depth_stem: PatchEmbedGeneric, + img_size: List = (3, 224, 224), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + use_type_embed: bool = False, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = rgbt_stem if rgbt_stem is not None else depth_stem + ( + self.patches_layout, + self.num_patches, + self.embed_dim, + ) = stem.get_patch_layout(img_size) + self.rgbt_stem = rgbt_stem + self.depth_stem = depth_stem + self.use_pos_embed = pos_embed_fn is not None + self.use_type_embed = use_type_embed + self.num_cls_tokens = num_cls_tokens + + if self.use_pos_embed: + self.pos_embedding_helper = pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if vision is not None and depth is not None: + final_tokens = vision_tokens + depth_tokens + else: + final_tokens = vision_tokens if vision is not None else depth_tokens + return_dict = { + "trunk": { + "tokens": final_tokens, + }, + "head": {}, + } + return return_dict + + +class AudioPreprocessor(RGBDTPreprocessor): + def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) + + def forward(self, audio=None): + return super().forward(vision=audio) + + +class ThermalPreprocessor(RGBDTPreprocessor): + def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) + + def forward(self, thermal=None): + return super().forward(vision=thermal) + + +def build_causal_attention_mask(context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length, requires_grad=False) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + +class TextPreprocessor(VerboseNNModule): + def __init__( + self, + vocab_size: int, + context_length: int, + embed_dim: int, + causal_masking: bool, + supply_seq_len_to_head: bool = True, + num_cls_tokens: int = 0, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.pos_embed = nn.Parameter( + torch.empty(1, self.context_length + num_cls_tokens, embed_dim) + ) + self.causal_masking = causal_masking + if self.causal_masking: + mask = build_causal_attention_mask(self.context_length) + # register the mask as a buffer so it can be moved to the right device + self.register_buffer("mask", mask) + + self.supply_seq_len_to_head = supply_seq_len_to_head + self.num_cls_tokens = num_cls_tokens + self.embed_dim = embed_dim + if num_cls_tokens > 0: + assert self.causal_masking is False, "Masking + CLS token isn't implemented" + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style="openclip"): + # OpenCLIP style initialization + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def forward(self, text): + # text tokens are of shape B x L x D + text_tokens = self.token_embedding(text) + # concat CLS tokens if any + if self.num_cls_tokens > 0: + B = text_tokens.shape[0] + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + text_tokens = torch.cat((class_tokens, text_tokens), dim=1) + text_tokens = text_tokens + self.pos_embed + return_dict = { + "trunk": { + "tokens": text_tokens, + }, + "head": {}, + } + # Compute sequence length after adding CLS tokens + if self.supply_seq_len_to_head: + text_lengths = text.argmax(dim=-1) + return_dict["head"] = { + "seq_len": text_lengths, + } + if self.causal_masking: + return_dict["trunk"].update({"attn_mask": self.mask}) + return return_dict + + +class Im2Video(nn.Module): + """Convert an image into a trivial video.""" + + def __init__(self, time_dim=2): + super().__init__() + self.time_dim = time_dim + + def forward(self, x): + if x.ndim == 4: + # B, C, H, W -> B, C, T, H, W + return x.unsqueeze(self.time_dim) + elif x.ndim == 5: + return x + else: + raise ValueError(f"Dimension incorrect {x.shape}") + + +class PadIm2Video(Im2Video): + def __init__(self, ntimes, pad_type, time_dim=2): + super().__init__(time_dim=time_dim) + assert ntimes > 0 + assert pad_type in ["zero", "repeat"] + self.ntimes = ntimes + self.pad_type = pad_type + + def forward(self, x): + x = super().forward(x) + if x.shape[self.time_dim] == 1: + if self.pad_type == "repeat": + new_shape = [1] * len(x.shape) + new_shape[self.time_dim] = self.ntimes + x = x.repeat(new_shape) + elif self.pad_type == "zero": + padarg = [0, 0] * len(x.shape) + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] + x = nn.functional.pad(x, padarg) + return x + + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class IMUPreprocessor(VerboseNNModule): + def __init__( + self, + kernel_size: int, + imu_stem: PatchEmbedGeneric, + embed_dim: int, + img_size: List = (6, 2000), + num_cls_tokens: int = 1, + pos_embed_fn: Callable = None, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = imu_stem + self.imu_stem = imu_stem + self.embed_dim = embed_dim + self.use_pos_embed = pos_embed_fn is not None + self.num_cls_tokens = num_cls_tokens + self.kernel_size = kernel_size + self.pos_embed = nn.Parameter( + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict \ No newline at end of file diff --git a/imagebind/models/multimodal_projectors.py b/imagebind/models/multimodal_projectors.py new file mode 100644 index 0000000000000000000000000000000000000000..420a00c6b5803688462e523d987ad714bbc37f3e --- /dev/null +++ b/imagebind/models/multimodal_projectors.py @@ -0,0 +1,45 @@ +from torch import nn, Tensor + +from typing import Union, Optional, Tuple + + +class BaseProjector(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + raise NotImplementedError + + +class LinearProjector(BaseProjector): + def __init__(self, in_dim, out_dim): + super().__init__() + self.fc = nn.Linear(in_dim, out_dim) + + def forward(self, x: Tensor) -> Tensor: + return self.fc(x) + + +class AdapterProjector(BaseProjector): + def __init__(self, in_dim, mid_dim, out_dim): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(in_dim, mid_dim, bias=False), + nn.ReLU(inplace=True), + nn.Linear(mid_dim, out_dim, bias=False), + nn.ReLU(inplace=True) + ) + + def forward(self, x: Tensor) -> Tensor: + return self.fc(x) + + +def create_projectors(dims): + if len(dims) == 0: + return nn.Identity() + elif len(dims) == 2: + return LinearProjector(*dims) + elif len(dims) == 3: + return AdapterProjector(*dims) + else: + raise NotImplementedError diff --git a/imagebind/models/transformer.py b/imagebind/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb3cfde43f2317df61290e6e3b5a210243578ce --- /dev/null +++ b/imagebind/models/transformer.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Code modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; +# https://github.com/facebookresearch/deit/blob/main/models.py +# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py + + +import copy +import fnmatch +import logging +from functools import partial +from typing import Callable, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import DropPath, trunc_normal_ + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, + # can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiheadAttention(nn.MultiheadAttention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + +class ViTAttention(Attention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + assert attn_mask is None + return super().forward(x) + + +class BlockWithMasking(nn.Module): + def __init__( + self, + dim: int, + attn_target: Callable, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ffn_dropout_rate: float = 0.0, + drop_path: float = 0.0, + layer_scale_type: str = None, + layer_scale_init_value: float = 1e-4, + ): + super().__init__() + + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" + self.attn = attn_target() + if drop_path > 0.0: + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm_1 = norm_layer(dim) + mlp_hidden_dim = int(mlp_ratio * dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=ffn_dropout_rate, + ) + self.norm_2 = norm_layer(dim) + self.layer_scale_type = layer_scale_type + if self.layer_scale_type is not None: + assert self.layer_scale_type in [ + "per_channel", + "scalar", + ], f"Found Layer scale type {self.layer_scale_type}" + if self.layer_scale_type == "per_channel": + # one gamma value per channel + gamma_shape = [1, 1, dim] + elif self.layer_scale_type == "scalar": + # single gamma value for all channels + gamma_shape = [1, 1, 1] + # two gammas: for each part of the fwd in the encoder + self.layer_scale_gamma1 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + self.layer_scale_gamma2 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + if self.layer_scale_type is None: + x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path(self.mlp(self.norm_2(x))) + else: + x = ( + x + + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + * self.layer_scale_gamma1 + ) + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 + return x + + +_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) + + +class SimpleTransformer(nn.Module): + def __init__( + self, + attn_target: Callable, + embed_dim: int, + num_blocks: int, + block: Callable = BlockWithMasking, + pre_transformer_layer: Callable = None, + post_transformer_layer: Callable = None, + drop_path_rate: float = 0.0, + drop_path_type: str = "progressive", + norm_layer: Callable = _LAYER_NORM, + mlp_ratio: int = 4, + ffn_dropout_rate: float = 0.0, + layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar" + layer_scale_init_value: float = 1e-4, # from cait; float + weight_init_style: str = "jax", # possible values jax or pytorch + ): + """ + Simple Transformer with the following features + 1. Supports masked attention + 2. Supports DropPath + 3. Supports LayerScale + 4. Supports Dropout in Attention and FFN + 5. Makes few assumptions about the input except that it is a Tensor + """ + super().__init__() + self.pre_transformer_layer = pre_transformer_layer + if drop_path_type == "progressive": + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] + elif drop_path_type == "uniform": + dpr = [drop_path_rate for i in range(num_blocks)] + else: + raise ValueError(f"Unknown drop_path_type: {drop_path_type}") + + self.blocks = nn.Sequential( + *[ + block( + dim=embed_dim, + attn_target=attn_target, + mlp_ratio=mlp_ratio, + ffn_dropout_rate=ffn_dropout_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + layer_scale_type=layer_scale_type, + layer_scale_init_value=layer_scale_init_value, + ) + for i in range(num_blocks) + ] + ) + self.post_transformer_layer = post_transformer_layer + self.weight_init_style = weight_init_style + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if self.weight_init_style == "jax": + # Based on MAE and official Jax ViT implementation + torch.nn.init.xavier_uniform_(m.weight) + elif self.weight_init_style == "pytorch": + # PyTorch ViT uses trunc_normal_ + trunc_normal_(m.weight, std=0.02) + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + tokens: torch.Tensor, + attn_mask: torch.Tensor = None, + use_checkpoint: bool = False, + checkpoint_every_n: int = 1, + checkpoint_blk_ids: List[int] = None, + ): + """ + Inputs + - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) + - attn: mask of shape L x L + + Output + - x: data of shape N x L x D (or L x N x D depending on the attention implementation) + """ + if self.pre_transformer_layer: + tokens = self.pre_transformer_layer(tokens) + if use_checkpoint and checkpoint_blk_ids is None: + checkpoint_blk_ids = [ + blk_id + for blk_id in range(len(self.blocks)) + if blk_id % checkpoint_every_n == 0 + ] + if checkpoint_blk_ids: + checkpoint_blk_ids = set(checkpoint_blk_ids) + for blk_id, blk in enumerate(self.blocks): + if use_checkpoint and blk_id in checkpoint_blk_ids: + tokens = checkpoint.checkpoint( + blk, tokens, attn_mask, use_reentrant=False + ) + else: + tokens = blk(tokens, attn_mask=attn_mask) + if self.post_transformer_layer: + tokens = self.post_transformer_layer(tokens) + return tokens \ No newline at end of file diff --git a/match.py b/match.py new file mode 100644 index 0000000000000000000000000000000000000000..56a2d4fbf9db0fa8f04e19d98f0a1dcaefdd24ec --- /dev/null +++ b/match.py @@ -0,0 +1,184 @@ +import os +import torch +import torch.nn as nn +import openai +import random +import contextlib + +from constants.constant import COLORS + + +@contextlib.contextmanager +def auto_proxy(): + use_proxy = "OPENAI_PROXY" in os.environ + if use_proxy: + os.environ['http_proxy'] = os.environ["OPENAI_PROXY"] + os.environ['https_proxy'] = os.environ["OPENAI_PROXY"] + + yield + + if use_proxy: + os.unsetenv('http_proxy') + os.unsetenv('https_proxy') + + +class MatchModule(nn.Module): + def __init__(self, device='cpu', model="gpt-3.5-turbo"): + super().__init__() + self.device = device + self.model = model + if "OPENAI_API_KEY" not in os.environ: + raise RuntimeError("Please specify your openai API key with the environment variable OPENAI_API_KEY") + openai.api_key = os.environ["OPENAI_API_KEY"] + self.examples = [ + ( + "['dog', 'sheepdog', 'grass', 'chase sheepdog', 'field', 'field park', 'grassy', 'corgi', 'brown dog', 'brown', 'park']" + "A brown dog running in the grassy field", + 'brown dog - brown dog\n' + 'grassy field - field' + ), + ( + "['man', 'ride', 'bicycle', 'red', 'passenger train', 'track']" + "A man riding a bicycle next to a red passenger train on the tracks.", + "man - man\n" + "bicycle - bicycle\n" + "red passenger train - passenger train\n" + "tracks - track" + ), + ( + "['horse', 'herd', 'dust', 'grassy', 'field']" + "The image shows a large herd of wild horses running across a wide, open field . " + "There are many horses running in different directions, with some running towards the camera " + "and others running towards the edge of the field. " + "The horses are brown and white, with some having manes and tails", + "herd - herd\n" + "wild horses - horse\n" + "field - field" + ), + ( + "['man', 'plate platter', 'sandwich', 'tablening table', 'saucer', 'coffee coffee cup', 'coffee', 'bean chip fry', 'chip fry', 'coffee cup', 'bean', 'food', 'table', 'restaurant']" + "The image shows a man sitting at a table , surrounded by a large amount of food and drinks . There is a chicken sandwich on the table, as well as a bowl of soup, potato wedges, and several fried potatoes. The man is holding a spoon, which he is expected to use to eat one of the wedges or possibly a piece of the chicken sandwich. The other items on the table, such as drinks and a bowl of soup, appear to be for those accompanying the man at the table. The scene takes place in a dining establishment , likely a restaurant , based on the presence of a spoon and food items on the table, along with a tablecloth and table setting. Additionally, the presence of several chairs and the overall setup suggest this is a formal, sit-down setting rather than a fast food or take-out restaurant. The amount of food on the table suggests that this is a hearty, satisfying meal, providing a range of flavors and textures that satisfy the palate.", + "man - man\n" + "table - table\n" + "food - food\n" + "chicken sandwich - sandwich\n" + "restaurant - restaurant\n" + "fried potatoes - chip fry\n" + "drinks - coffee" + ), + ( + "['bacon', 'silverware utensil', 'fork', 'coffee', 'table dinning table', 'plate platter', 'beverage', 'napkin', 'bread french toast pan', 'pine cone', 'coffee cup cup mug', 'fruit', 'breakfast food fruit', 'bacon', 'gravy', 'bread pancake']" + "The image presents a delicious breakfast setting on a wooden dining table. The main course is a white plate with French toast and bacon . Adding to the meal are a bottle of maple syrup and a cup of coffee , both placed next to the plate. The table is set with a fork , a knife, and a spoon, all arranged neatly around the plate. There are also a few apples scattered across the table, possibly serving as a healthy addition to the meal. Overall, the scene is inviting and warmly lit, making the breakfast look especially appetizing.", + "wooden dinning table - table dinning table\n" + "fork - fork\n" + "coffee - coffee\n" + "apples - fruit\n" + "white plate - plate platter\n" + "french toast - bread french toast pan\n" + "bacon - bacon" + ), + ( + "['woman', 'canopy', 'man', 'dog pet', 'dog', 'canopy', 'bicycle', 'person', 'leash', " + "'dog pet', 'leash', 'stall', 'person woman', 'dog pet', 'city street road', 'street scene']" + "The image captures a lively street scene with several people walking and riding bikes. " + "There are two bicycles in the picture, one located in the middle of the scene and the other towards " + "the right side. Among the people, some are walking close to the bicycles, while others are scattered" + "throughout the scene. In addition to the bicycles and people, there are four dogs in the picture, " + "adding to the liveliness of the scene. The dogs are walking around the street, mingling with the " + "pedestrians and bikers. The street is bustling with activity, as people, bikes, and dogs all " + "share the space and enjoy the day.", + "street scene - street scene\n" + "the street - city street road\n" + "bicycles - bicycle\n" + "four dogs - dog\n" + "people - person" + ) + ] + self.system_prompt = "You are a helpful assistant. Now I will give you a list of entities and give you a " \ + "paragraph or sentence. " \ + "you need to first extract the entity given in the text and then" \ + "find the corresponding entity having similar or identical meanings in the given list. " \ + "Find all the pairs." \ + "Are you clear? let us think step by step. " \ + "The extracted entities must come from the given text and the corresponding entity must " \ + "come from the given list. " \ + "If multiple entities can be linked to the same span of text or vice versa, " \ + "just keep one and do not merge them." \ + "Here is an example: ['dog', 'sheepdog', 'grass', 'chase sheepdog', 'field', " \ + "'field park', 'grassy', 'corgi', 'brown dog', 'brown', 'park'] " \ + "A brown dog running in the grassy field" \ + "The answer is: brown dog — brown dog \n grassy field — field" + + @torch.no_grad() + def forward(self, text, entity_state): + entity_list = list(entity_state['grounding']['local'].keys()) + message = [ + {"role": "system", "content": self.system_prompt}, + ] + for q, a in self.examples: + message.append({"role": "user", "content": q}) + message.append({"role": "system", "content": a}) + message.append({ + "role": "user", + "content": '{}{}'.format(entity_state['grounding']['local'].keys(), text) + }) + + print('==> Sending request to ChatGPT...') + with auto_proxy(): + resp = openai.ChatCompletion.create( + model=self.model, + messages=message + ) + ans = resp['choices'][0]['message']['content'] + print("===> In the matching module.") + print('==> Response from ChatGPT received: {}.'.format(ans)) + # print(resp) + items = ans.split('\n') + res = [] + match_state = {} + for i in items: + if ' - ' not in i: + continue + name, ref = i.split(' - ', maxsplit=1) + name, ref = name.lower(), ref.lower() + # NOTE: ref may not be contained in the original text, double check later. + if ref in entity_list: + color_name = entity_state['grounding']['local'][ref]['color'] + else: + print('pair {} - {} not found'.format(name, ref)) + # color_name = "grey" + continue + match_state[name] = ref + entity_idx = text.lower().find(name) + if entity_idx == -1: + entity_idx = text.lower().find(name.lower()) + ref = name + if entity_idx == -1: + continue + + res.append((name, ref, entity_idx, color_name)) + res = sorted(res, key=lambda x: x[2]) + # TODO: Bug to fix + highlight_output = [] + prev = 0 + color_map = {} + + for i, r in enumerate(res): + if r[2] < prev: + continue + # to avoid one-vs-many alignments + if r[2] != prev: + highlight_output.append((text[prev:r[2]], None)) + highlight_output.append((text[r[2]:r[2] + len(r[0])], f'{i + 1}')) + color_map[f'{i + 1}'] = r[-1] + prev = r[2] + len(r[0]) + if prev != len(text) - 1: + highlight_output.append((text[prev:], None)) + print("=======> Highlight Output: ", highlight_output) + return highlight_output, match_state, color_map + + +if __name__ == '__main__': + ner = MatchModule(model='gpt-4') + print( + ner('The image shows a resort with a large swimming pool surrounded by lounge chairs and umbrellas. There are several buildings in the background with white walls and blue roofs. There are sand dunes and palm trees in the background indicating that the resort is located in a desert area. The sky is clear and blue with a few fluffy clouds in the distance.')) diff --git a/ram/__init__.py b/ram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..137c4ae4731a8e9765a082bd9b83edf6eb4896d7 --- /dev/null +++ b/ram/__init__.py @@ -0,0 +1 @@ +from .inference import inference_tag2text, inference_ram, inference_ram_openset \ No newline at end of file diff --git a/ram/configs/med_config.json b/ram/configs/med_config.json new file mode 100644 index 0000000000000000000000000000000000000000..49d64f890cc38d558c4fd3bab048cc521a69a2be --- /dev/null +++ b/ram/configs/med_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true + } \ No newline at end of file diff --git a/ram/configs/q2l_config.json b/ram/configs/q2l_config.json new file mode 100644 index 0000000000000000000000000000000000000000..a8eba56c27769cadd1506c8e88fe33aced92668f --- /dev/null +++ b/ram/configs/q2l_config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": true, + "add_tag_cross_attention": false + } \ No newline at end of file diff --git a/ram/configs/swin/config_swinB_384.json b/ram/configs/swin/config_swinB_384.json new file mode 100644 index 0000000000000000000000000000000000000000..d2f3e0724319655e7a084d602db3712abda746ee --- /dev/null +++ b/ram/configs/swin/config_swinB_384.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 384, + "window_size": 12, + "embed_dim": 128, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] + } \ No newline at end of file diff --git a/ram/configs/swin/config_swinL_384.json b/ram/configs/swin/config_swinL_384.json new file mode 100644 index 0000000000000000000000000000000000000000..e6443a2d209fef96f4a183b7499323976f3a88e5 --- /dev/null +++ b/ram/configs/swin/config_swinL_384.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth", + "vision_width": 1536, + "image_res": 384, + "window_size": 12, + "embed_dim": 192, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 6, 12, 24, 48 ] + } \ No newline at end of file diff --git a/ram/data/ram_tag_list.txt b/ram/data/ram_tag_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..49c840b71915f639fb79cd83ac4a3e313cfbc2b1 --- /dev/null +++ b/ram/data/ram_tag_list.txt @@ -0,0 +1,4585 @@ +3D CG rendering +3D glasses +abacus +abalone +monastery +belly +academy +accessory +accident +accordion +acorn +acrylic paint +act +action +action film +activity +actor +adaptation +add +adhesive tape +adjust +adult +adventure +advertisement +antenna +aerobics +spray can +afro +agriculture +aid +air conditioner +air conditioning +air sock +aircraft cabin +aircraft model +air field +air line +airliner +airman +plane +airplane window +airport +airport runway +airport terminal +airship +airshow +aisle +alarm +alarm clock +mollymawk +album +album cover +alcohol +alcove +algae +alley +almond +aloe vera +alp +alpaca +alphabet +german shepherd +altar +amber +ambulance +bald eagle +American shorthair +amethyst +amphitheater +amplifier +amusement park +amusement ride +anchor +ancient +anemone +angel +angle +animal +animal sculpture +animal shelter +animation +animation film +animator +anime +ankle +anklet +anniversary +trench coat +ant +antelope +antique +antler +anvil +apartment +ape +app +app icon +appear +appearance +appetizer +applause +apple +apple juice +apple pie +apple tree +applesauce +appliance +appointment +approach +apricot +apron +aqua +aquarium +aquarium fish +aqueduct +arcade +arcade machine +arch +arch bridge +archaelogical excavation +archery +archipelago +architect +architecture +archive +archway +area +arena +argument +arm +armadillo +armband +armchair +armoire +armor +army +army base +army tank +array +arrest +arrow +art +art exhibition +art gallery +art print +art school +art studio +art vector illustration +artichoke +article +artifact +artist +artists loft +ash +ashtray +asia temple +asparagus +asphalt road +assemble +assembly +assembly line +association +astronaut +astronomer +athlete +athletic +atlas +atm +atmosphere +atrium +attach +fighter jet +attend +attraction +atv +eggplant +auction +audi +audio +auditorium +aurora +author +auto factory +auto mechanic +auto part +auto show +auto showroom +car battery +automobile make +automobile model +motor vehicle +autumn +autumn forest +autumn leave +autumn park +autumn tree +avatar +avenue +aviator sunglasses +avocado +award +award ceremony +award winner +shed +ax +azalea +baboon +baby +baby bottle +baby carriage +baby clothe +baby elephant +baby food +baby seat +baby shower +back +backdrop +backlight +backpack +backyard +bacon +badge +badger +badlands +badminton +badminton racket +bag +bagel +bagpipe +baguette +bait +baked goods +baker +bakery +baking +baking sheet +balance +balance car +balcony +ball +ball pit +ballerina +ballet +ballet dancer +ballet skirt +balloon +balloon arch +baseball player +ballroom +bamboo +bamboo forest +banana +banana bread +banana leaf +banana tree +band +band aid +bandage +headscarf +bandeau +bangs +bracelet +balustrade +banjo +bank +bank card +bank vault +banknote +banner +banquet +banquet hall +banyan tree +baozi +baptism +bar +bar code +bar stool +barbecue +barbecue grill +barbell +barber +barber shop +barbie +barge +barista +bark +barley +barn +barn owl +barn door +barrel +barricade +barrier +handcart +bartender +baseball +baseball base +baseball bat +baseball hat +baseball stadium +baseball game +baseball glove +baseball pitcher +baseball team +baseball uniform +basement +basil +basin +basket +basket container +basketball +basketball backboard +basketball coach +basketball court +basketball game +basketball hoop +basketball player +basketball stadium +basketball team +bass +bass guitar +bass horn +bassist +bat +bath +bath heater +bath mat +bath towel +swimwear +bathrobe +bathroom +bathroom accessory +bathroom cabinet +bathroom door +bathroom mirror +bathroom sink +toilet paper +bathroom window +batman +wand +batter +battery +battle +battle rope +battleship +bay +bay bridge +bay window +bayberry +bazaar +beach +beach ball +beach chair +beach house +beach hut +beach towel +beach volleyball +lighthouse +bead +beagle +beak +beaker +beam +bean +bean bag chair +beanbag +bear +bear cub +beard +beast +beat +beautiful +beauty +beauty salon +beaver +bed +bedcover +bed frame +bedroom +bedding +bedpan +bedroom window +bedside lamp +bee +beech tree +beef +beekeeper +beeper +beer +beer bottle +beer can +beer garden +beer glass +beer hall +beet +beetle +beige +clock +bell pepper +bell tower +belt +belt buckle +bench +bend +bengal tiger +bento +beret +berry +berth +beverage +bib +bibimbap +bible +bichon +bicycle +bicycle helmet +bicycle wheel +biker +bidet +big ben +bike lane +bike path +bike racing +bike ride +bikini +bikini top +bill +billard +billboard +billiard table +bin +binder +binocular +biology laboratory +biplane +birch +birch tree +bird +bird bath +bird feeder +bird house +bird nest +birdbath +bird cage +birth +birthday +birthday cake +birthday candle +birthday card +birthday party +biscuit +bishop +bison +bit +bite +black +black sheep +blackberry +blackbird +blackboard +blacksmith +blade +blanket +sports coat +bleacher +blender +blessing +blind +eye mask +flasher +snowstorm +block +blog +blood +bloom +blossom +blouse +blow +hair drier +blowfish +blue +blue artist +blue jay +blue sky +blueberry +bluebird +pig +board +board eraser +board game +boardwalk +boat +boat deck +boat house +paddle +boat ride +bobfloat +bobcat +body +bodyboard +bodybuilder +boiled egg +boiler +bolo tie +bolt +bomb +bomber +bonasa umbellu +bone +bonfire +bonnet +bonsai +book +book cover +bookcase +folder +bookmark +bookshelf +bookstore +boom microphone +boost +boot +border +Border collie +botanical garden +bottle +bottle cap +bottle opener +bottle screw +bougainvillea +boulder +bouquet +boutique +boutique hotel +bow +bow tie +bow window +bowl +bowling +bowling alley +bowling ball +bowling equipment +box +box girder bridge +box turtle +boxer +underdrawers +boxing +boxing glove +boxing ring +boy +brace +bracket +braid +brain +brake +brake light +branch +brand +brandy +brass +brass plaque +bread +breadbox +break +breakfast +seawall +chest +brewery +brick +brick building +wall +brickwork +wedding dress +bride +groom +bridesmaid +bridge +bridle +briefcase +bright +brim +broach +broadcasting +broccoli +bronze +bronze medal +bronze sculpture +bronze statue +brooch +creek +broom +broth +brown +brown bear +brownie +brunch +brunette +brush +coyote +brussels sprout +bubble +bubble gum +bubble tea +bucket cabinet +shield +bud +buddha +buffalo +buffet +bug +build +builder +building +building block +building facade +building material +lamp +bull +bulldog +bullet +bullet train +bulletin board +bulletproof vest +bullfighting +megaphone +bullring +bumblebee +bumper +roll +bundle +bungee +bunk bed +bunker +bunny +buoy +bureau +burial chamber +burn +burrito +bus +bus driver +bus interior +bus station +bus stop +bus window +bush +business +business card +business executive +business suit +business team +business woman +businessman +bust +butcher +butchers shop +butte +butter +cream +butterfly +butterfly house +button +buttonwood +buy +taxi +cabana +cabbage +cabin +cabin car +cabinet +cabinetry +cable +cable car +cactus +cafe +canteen +cage +cake +cake stand +calculator +caldron +calendar +calf +call +phone box +calligraphy +calm +camcorder +camel +camera +camera lens +camouflage +camp +camper +campfire +camping +campsite +campus +can +can opener +canal +canary +cancer +candle +candle holder +candy +candy bar +candy cane +candy store +cane +jar +cannon +canopy +canopy bed +cantaloupe +cantilever bridge +canvas +canyon +cap +cape +cape cod +cappuccino +capsule +captain +capture +car +car dealership +car door +car interior +car logo +car mirror +parking lot +car seat +car show +car wash +car window +caramel +card +card game +cardboard +cardboard box +cardigan +cardinal +cargo +cargo aircraft +cargo ship +caribbean +carnation +carnival +carnivore +carousel +carp +carpenter +carpet +slipper +house finch +coach +dalmatian +aircraft carrier +carrot +carrot cake +carry +cart +carton +cartoon +cartoon character +cartoon illustration +cartoon style +carve +case +cash +cashew +casino +casserole +cassette +cassette deck +plaster bandage +casting +castle +cat +cat bed +cat food +cat furniture +cat tree +catacomb +catamaran +catamount +catch +catcher +caterpillar +catfish +cathedral +cattle +catwalk +catwalk show +cauliflower +cave +caviar +CD +CD player +cedar +ceiling +ceiling fan +celebrate +celebration +celebrity +celery +cello +smartphone +cement +graveyard +centerpiece +centipede +ceramic +ceramic tile +cereal +ceremony +certificate +chain +chain saw +chair +chairlift +daybed +chalet +chalice +chalk +chamber +chameleon +champagne +champagne flute +champion +championship +chandelier +changing table +channel +chap +chapel +character sculpture +charcoal +charge +charger +chariot +charity +charity event +charm +graph +chase +chassis +check +checkbook +chessboard +checklist +cheer +cheerlead +cheese +cheeseburger +cheesecake +cheetah +chef +chemical compound +chemist +chemistry +chemistry lab +cheongsam +cherry +cherry blossom +cherry tomato +cherry tree +chess +chestnut +chicken +chicken breast +chicken coop +chicken salad +chicken wing +garbanzo +chiffonier +chihuahua +child +child actor +childs room +chile +chili dog +chimney +chimpanzee +chinaware +chinese cabbage +chinese garden +chinese knot +chinese rose +chinese tower +chip +chipmunk +chisel +chocolate +chocolate bar +chocolate cake +chocolate chip +chocolate chip cookie +chocolate milk +chocolate mousse +truffle +choir +kitchen knife +cutting board +chopstick +christmas +christmas ball +christmas card +christmas decoration +christmas dinner +christmas eve +christmas hat +christmas light +christmas market +christmas ornament +christmas tree +chrysanthemum +church +church tower +cider +cigar +cigar box +cigarette +cigarette case +waistband +cinema +photographer +cinnamon +circle +circuit +circuit board +circus +water tank +citrus fruit +city +city bus +city hall +city nightview +city park +city skyline +city square +city street +city wall +city view +clam +clarinet +clasp +class +classic +classroom +clavicle +claw +clay +pottery +clean +clean room +cleaner +cleaning product +clear +cleat +clementine +client +cliff +climb +climb mountain +climber +clinic +clip +clip art +clipboard +clipper +clivia +cloak +clogs +close-up +closet +cloth +clothe +clothing +clothespin +clothesline +clothing store +cloud +cloud forest +cloudy +clover +joker +clown fish +club +clutch +clutch bag +coal +coast +coat +coatrack +cob +cock +cockatoo +cocker +cockpit +roach +cocktail +cocktail dress +cocktail shaker +cocktail table +cocoa +coconut +coconut tree +coffee +coffee bean +coffee cup +coffee machine +coffee shop +coffeepot +coffin +cognac +spiral +coin +coke +colander +cold +slaw +collaboration +collage +collection +college student +sheepdog +crash +color +coloring book +coloring material +pony +pillar +comb +combination lock +comic +comedy +comedy film +comet +comfort +comfort food +comic book +comic book character +comic strip +commander +commentator +community +commuter +company +compass +compete +contest +competitor +composer +composition +compost +computer +computer box +computer chair +computer desk +keyboard +computer monitor +computer room +computer screen +computer tower +concept car +concert +concert hall +conch +concrete +condiment +condom +condominium +conductor +cone +meeting +conference center +conference hall +meeting room +confetti +conflict +confluence +connect +connector +conservatory +constellation +construction site +construction worker +contain +container +container ship +continent +profile +contract +control +control tower +convenience store +convention +conversation +converter +convertible +transporter +cook +cooking +cooking spray +cooker +cool +cooler +copper +copy +coral +coral reef +rope +corded phone +liquor +corgi +cork +corkboard +cormorant +corn +corn field +cornbread +corner +trumpet +cornice +cornmeal +corral +corridor +corset +cosmetic +cosmetics brush +cosmetics mirror +cosplay +costume +costumer film designer +infant bed +cottage +cotton +cotton candy +couch +countdown +counter +counter top +country artist +country house +country lane +country pop artist +countryside +coupe +couple +couple photo +courgette +course +court +courthouse +courtyard +cousin +coverall +cow +cowbell +cowboy +cowboy boot +cowboy hat +crab +crabmeat +crack +cradle +craft +craftsman +cranberry +crane +crape +crapper +crate +crater lake +lobster +crayon +cream cheese +cream pitcher +create +creature +credit card +crescent +croissant +crest +crew +cricket +cricket ball +cricket team +cricketer +crochet +crock pot +crocodile +crop +crop top +cross +crossbar +crossroad +crosstalk +crosswalk +crouton +crow +crowbar +crowd +crowded +crown +crt screen +crucifix +cruise +cruise ship +cruiser +crumb +crush +crutch +crystal +cub +cube +cucumber +cue +cuff +cufflink +cuisine +farmland +cup +cupcake +cupid +curb +curl +hair roller +currant +currency +curry +curtain +curve +pad +customer +cut +cutlery +cycle +cycling +cyclone +cylinder +cymbal +cypress +cypress tree +dachshund +daffodil +dagger +dahlia +daikon +dairy +daisy +dam +damage +damp +dance +dance floor +dance room +dancer +dandelion +dark +darkness +dart +dartboard +dashboard +date +daughter +dawn +day bed +daylight +deadbolt +death +debate +debris +decanter +deck +decker bus +decor +decorate +decorative picture +deer +defender +deity +delicatessen +deliver +demolition +monster +demonstration +den +denim jacket +dentist +department store +depression +derby +dermopathy +desert +desert road +design +designer +table +table lamp +desktop +desktop computer +dessert +destruction +detective +detergent +dew +dial +diamond +diaper +diaper bag +journal +die +diet +excavator +number +digital clock +dill +dinner +rowboat +dining room +dinner party +dinning table +dinosaur +dip +diploma +direct +director +dirt +dirt bike +dirt field +dirt road +dirt track +disaster +disciple +disco +disco ball +discotheque +disease +plate +dish antenna +dish washer +dishrag +dishes +dishsoap +Disneyland +dispenser +display +display window +trench +dive +diver +diving board +paper cup +dj +doberman +dock +doctor +document +documentary +dog +dog bed +dog breed +dog collar +dog food +dog house +doll +dollar +dollhouse +dolly +dolphin +dome +domicile +domino +donkey +donut +doodle +door +door handle +doormat +doorplate +doorway +dormitory +dough +downtown +dozer +drag +dragon +dragonfly +drain +drama +drama film +draw +drawer +drawing +drawing pin +pigtail +dress +dress hat +dress shirt +dress shoe +dress suit +dresser +dressing room +dribble +drift +driftwood +drill +drink +drinking water +drive +driver +driveway +drone +drop +droplight +dropper +drought +medicine +pharmacy +drum +drummer +drumstick +dry +duchess +duck +duckbill +duckling +duct tape +dude +duet +duffel +canoe +dumbbell +dumpling +dune +dunk +durian +dusk +dust +garbage truck +dustpan +duvet +DVD +dye +eagle +ear +earmuff +earphone +earplug +earring +earthquake +easel +easter +easter bunny +easter egg +eat +restaurant +eclair +eclipse +ecosystem +edit +education +educator +eel +egg +egg roll +egg tart +eggbeater +egret +Eiffel tower +elastic band +senior +electric chair +electric drill +electrician +electricity +electron +electronic +elephant +elevation map +elevator +elevator car +elevator door +elevator lobby +elevator shaft +embankment +embassy +embellishment +ember +emblem +embroidery +emerald +emergency +emergency service +emergency vehicle +emotion +Empire State Building +enamel +enclosure +side table +energy +engagement +engagement ring +engine +engine room +engineer +engineering +english shorthair +ensemble +enter +entertainer +entertainment +entertainment center +entrance +entrance hall +envelope +equestrian +equipment +eraser +erhu +erosion +escalator +escargot +espresso +estate +estuary +eucalyptus tree +evening +evening dress +evening light +evening sky +evening sun +event +evergreen +ewe +excavation +exercise +exhaust hood +exhibition +exit +explorer +explosion +extension cord +extinguisher +extractor +extrude +eye +eye shadow +eyebrow +eyeliner +fabric +fabric store +facade +face +face close-up +face powder +face towel +facial tissue holder +facility +factory +factory workshop +fair +fairground +fairy +falcon +fall +family +family car +family photo +family room +fan +fang +farm +farmer +farmer market +farmhouse +fashion +fashion accessory +fashion designer +fashion girl +fashion illustration +fashion look +fashion model +fashion show +fast food +fastfood restaurant +father +faucet +fault +fauna +fawn +fax +feast +feather +fedora +feed +feedbag +feeding +feeding chair +feline +mountain lion +fence +fender +fern +ferret +ferris wheel +ferry +fertilizer +festival +fiber +fiction +fiction book +field +field road +fig +fight +figure skater +figurine +file +file photo +file cabinet +fill +film camera +film director +film format +film premiere +film producer +filming +filter +fin +hand +finish line +fir +fir tree +fire +fire alarm +fire department +fire truck +fire escape +fire hose +fire pit +fire station +firecracker +fireman +fireplace +firework +firework display +first-aid kit +fish +fish boat +fish market +fish pond +fishbowl +fisherman +fishing +fishing boat +fishing net +fishing pole +fishing village +fitness +fitness course +five +fixture +fjord +flag +flag pole +flake +flame +flamingo +flannel +flap +flare +flash +flask +flat +flatfish +flavor +flea +flea market +fleet +flight +flight attendant +flip +flip-flop +flipchart +float +flock +flood +floor +floor fan +floor mat +floor plan +floor window +floral arrangement +florist +floss +flour +flow +flower +flower basket +flower bed +flower box +flower field +flower girl +flower market +fluid +flush +flute +fly +fly fishing +flyer +horse +foam +fog +foggy +foie gra +foil +folding chair +leaf +folk artist +folk dance +folk rock artist +fondant +hotpot +font +food +food coloring +food court +food processor +food stand +food truck +foosball +foot +foot bridge +football +football coach +football college game +football match +football field +football game +football helmet +football player +football stadium +football team +path +footprint +footrest +footstall +footwear +forbidden city +ford +forehead +forest +forest fire +forest floor +forest path +forest road +forge +fork +forklift +form +formal garden +formation +formula 1 +fort +fortification +forward +fossil +foundation +fountain +fountain pen +fox +frame +freckle +highway +lorry +French +French bulldog +French fries +French toast +freshener +fridge +fried chicken +fried egg +fried rice +friendship +frisbee +frog +frost +frosting +frosty +frozen +fruit +fruit cake +fruit dish +fruit market +fruit salad +fruit stand +fruit tree +fruits shop +fry +frying pan +fudge +fuel +fume hood +fun +funeral +fungi +funnel +fur +fur coat +furniture +futon +gadget +muzzle +galaxy +gallery +game +game board +game controller +ham +gang +garage +garage door +garage kit +garbage +garden +garden asparagus +garden hose +garden spider +gardener +gardening +garfield +gargoyle +wreath +garlic +garment +gas +gas station +gas stove +gasmask +collect +gathering +gauge +gazebo +gear +gecko +geisha +gel +general store +generator +geranium +ghost +gift +gift bag +gift basket +gift box +gift card +gift shop +gift wrap +gig +gin +ginger +gingerbread +gingerbread house +ginkgo tree +giraffe +girl +give +glacier +gladiator +glass bead +glass bottle +glass bowl +glass box +glass building +glass door +glass floor +glass house +glass jar +glass plate +glass table +glass vase +glass wall +glass window +glasses +glaze +glider +earth +glove +glow +glue pudding +go +go for +goal +goalkeeper +goat +goat cheese +gobi +goggles +gold +gold medal +Golden Gate Bridge +golden retriever +goldfish +golf +golf cap +golf cart +golf club +golf course +golfer +goose +gorilla +gothic +gourd +government +government agency +gown +graduate +graduation +grain +grampus +grand prix +grandfather +grandmother +grandparent +granite +granola +grape +grapefruit +wine +grass +grasshopper +grassland +grassy +grater +grave +gravel +gravestone +gravy +gravy boat +gray +graze +grazing +green +greenery +greet +greeting +greeting card +greyhound +grid +griddle +grill +grille +grilled eel +grind +grinder +grits +grocery bag +grotto +ground squirrel +group +group photo +grove +grow +guacamole +guard +guard dog +guest house +guest room +guide +guinea pig +guitar +guitarist +gulf +gull +gun +gundam +gurdwara +guzheng +gym +gymnast +habitat +hacker +hail +hair +hair color +hair spray +hairbrush +haircut +hairgrip +hairnet +hairpin +hairstyle +half +hall +halloween +halloween costume +halloween pumpkin +halter top +hamburg +hamburger +hami melon +hammer +hammock +hamper +hamster +hand dryer +hand glass +hand towel +handbag +handball +handcuff +handgun +handkerchief +handle +handsaw +handshake +handstand +handwriting +hanfu +hang +hangar +hanger +happiness +harbor +harbor seal +hard rock artist +hardback book +safety helmet +hardware +hardware store +hardwood +hardwood floor +mouth organ +pipe organ +harpsichord +harvest +harvester +hassock +hat +hatbox +hautboy +hawthorn +hay +hayfield +hazelnut +head +head coach +headlight +headboard +headdress +headland +headquarter +hearing +heart +heart shape +heat +heater +heather +hedge +hedgehog +heel +helicopter +heliport +helmet +help +hen +henna +herb +herd +hermit crab +hero +heron +hibiscus +hibiscus flower +hide +high bar +high heel +highland +highlight +hike +hiker +hiking boot +hiking equipment +hill +hill country +hill station +hillside +hindu temple +hinge +hip +hip hop artist +hippo +historian +historic +history +hockey +hockey arena +hockey game +hockey player +hockey stick +hoe +hole +vacation +holly +holothurian +home +home appliance +home base +home decor +home interior +home office +home theater +homework +hummus +honey +beehive +honeymoon +hood +hoodie +hook +jump +horizon +hornbill +horned cow +hornet +horror +horror film +horse blanket +horse cart +horse farm +horse ride +horseback +horseshoe +hose +hospital +hospital bed +hospital room +host +inn +hot +hot air balloon +hot dog +hot sauce +hot spring +hotel +hotel lobby +hotel room +hotplate +hourglass +house +house exterior +houseplant +hoverboard +howler +huddle +hug +hula hoop +person +humidifier +hummingbird +humpback whale +hunt +hunting lodge +hurdle +hurricane +husky +hut +hyaena +hybrid +hydrangea +hydrant +seaplane +ice +ice bag +polar bear +ice cave +icecream +ice cream cone +ice cream parlor +ice cube +ice floe +ice hockey player +ice hockey team +lollipop +ice maker +rink +ice sculpture +ice shelf +skate +ice skating +iceberg +icicle +icing +icon +id photo +identity card +igloo +light +iguana +illuminate +illustration +image +impala +incense +independence day +individual +indoor +indoor rower +induction cooker +industrial area +industry +infantry +inflatable boat +information desk +infrastructure +ingredient +inhalator +injection +injury +ink +inking pad +inlet +inscription +insect +install +instrument +insulated cup +interaction +interior design +website +intersection +interview +invertebrate +invitation +ipad +iphone +ipod +iris +iron +ironing board +irrigation system +island +islet +isopod +ivory +ivy +izakaya +jack +jackcrab +jacket +jacuzzi +jade +jaguar +jail cell +jam +japanese garden +jasmine +jaw +jay +jazz +jazz artist +jazz fusion artist +jeans +jeep +jelly +jelly bean +jellyfish +jet +motorboat +jewel +jewellery +jewelry shop +jigsaw puzzle +rickshaw +jockey +jockey cap +jog +joint +journalist +joystick +judge +jug +juggle +juice +juicer +jujube +jump rope +jumpsuit +jungle +junkyard +kale +kaleidoscope +kangaroo +karaoke +karate +karting +kasbah +kayak +kebab +key +keycard +khaki +kick +kilt +kimono +kindergarden classroom +kindergarten +king +king crab +kiss +kit +kitchen +kitchen cabinet +kitchen counter +kitchen floor +kitchen hood +kitchen island +kitchen sink +kitchen table +kitchen utensil +kitchen window +kitchenware +kite +kiwi +knee pad +kneel +knife +rider +knit +knitting needle +knob +knocker +knot +koala +koi +ktv +laboratory +lab coat +label +labrador +maze +lace +lace dress +ladder +ladle +ladybird +lagoon +lake +lake district +lake house +lakeshore +lamb +lamb chop +lamp post +lamp shade +spear +land +land vehicle +landfill +landing +landing deck +landmark +landscape +landslide +lanyard +lantern +lap +laptop +laptop keyboard +larva +lasagne +laser +lash +lasso +latch +latex +latte +laugh +launch +launch event +launch party +laundromat +laundry +laundry basket +laundry room +lava +lavender +lawn +lawn wedding +lawyer +lay +lead +lead singer +lead to +leader +leak +lean +learn +leash +leather +leather jacket +leather shoe +speech +lecture hall +lecture room +ledge +leftover +leg +legend +legging +legislative chamber +lego +legume +lemon +lemon juice +lemonade +lemur +lens +lens flare +lentil +leopard +leotard +tights +leprechaun +lesson +letter +mailbox +letter logo +lettering +lettuce +level +library +license +license plate +lichen +lick +lid +lie +life belt +life jacket +lifeboat +lifeguard +lift +light fixture +light show +light switch +lighting +lightning +lightning rod +lilac +lily +limb +lime +limestone +limo +line +line art +line up +linen +liner +lion +lip balm +lipstick +liquid +liquor store +list +litchi +live +livestock +living room +living space +lizard +load +loading dock +loafer +hallway +locate +lock +lock chamber +locker +loft +log +log cabin +logo +loki +long hair +longboard +loom +loop +lose +lottery +lotus +love +loveseat +luggage +lumber +lumberjack +lunch +lunch box +lush +luxury +luxury yacht +mac +macadamia +macaque +macaroni +macaw +machete +machine +machine gun +magazine +magic +magician +magnet +magnifying glass +magnolia +magpie +mahjong +mahout +maid +chain mail +mail slot +make +makeover +makeup artist +makeup tool +mallard +mallard duck +mallet +mammal +mammoth +man +management +manager +manatee +mandala +mandarin orange +mandarine +mane +manga +manger +mango +mangosteen +mangrove +manhattan +manhole +manhole cover +manicure +mannequin +manor house +mansion +mantid +mantle +manufactured home +manufacturing +manuscript +map +maple +maple leaf +maple syrup +maraca +marathon +marble +march +marching band +mare +marigold +marine +marine invertebrate +marine mammal +puppet +mark +market +market square +market stall +marriage +martial +martial artist +martial arts gym +martini +martini glass +mascara +mascot +mashed potato +masher +mask +massage +mast +mat +matador +match +matchbox +material +mattress +mausoleum +maxi dress +meal +measuring cup +measuring tape +meat +meatball +mechanic +mechanical fan +medal +media +medical equipment +medical image +medical staff +medicine cabinet +medieval +medina +meditation +meerkat +meet +melon +monument +menu +mermaid +net +mess +messenger bag +metal +metal artist +metal detector +meter +mezzanine +microphone +microscope +microwave +midnight +milestone +military uniform +milk +milk can +milk tea +milkshake +mill +mine +miner +mineral +mineral water +miniskirt +miniature +minibus +minister +minivan +mint +mint candy +mirror +miss +missile +mission +mistletoe +mix +mixer +mixing bowl +mixture +moat +mobility scooter +model +model car +modern +modern tower +moisture +mold +molding +mole +monarch +money +monitor +monk +monkey +monkey wrench +monochrome +monocycle +monster truck +moon +moon cake +moonlight +moor +moose +swab +moped +morning +morning fog +morning light +morning sun +mortar +mosaic +mosque +mosquito +moss +motel +moth +mother +motherboard +motif +sport +motor +motorbike +motorcycle +motorcycle helmet +motorcycle racer +motorcyclist +motorsport +mound +mountain +mountain bike +mountain biker +mountain biking +mountain gorilla +mountain lake +mountain landscape +mountain pass +mountain path +mountain range +mountain river +mountain snowy +mountain stream +mountain view +mountain village +mountaineer +mountaineering bag +mouse +mousepad +mousetrap +mouth +mouthwash +move +movie poster +movie ticket +mower +mp3 player +mr +mud +muffin +mug +mulberry +mulch +mule +municipality +mural +muscle +muscle car +museum +mushroom +music +music festival +music stool +music studio +music video performer +musical keyboard +musician +mussel +mustard +mythology +nacho +nail polish +nailfile +nanny +napkin +narrow +national flag +nativity scene +natural history museum +nature +nature reserve +navigation +navratri +navy +nebula +neck +neckband +necklace +neckline +nectar +nectarine +needle +neighbor +neighbourhood +neon +neon light +nerve +nest +new year +newborn +newfoundland +newlywed +news +news conference +newsstand +night +night market +night sky +night view +nightclub +nightstand +noodle +nose +noseband +note +notebook +notepad +notepaper +notice +number icon +nun +nurse +nursery +nursing home +nut +nutcracker +oak +oak tree +oar +oasis +oast house +oatmeal +oats +obelisk +observation tower +observatory +obstacle course +sea +octopus +offer +office +office building +office chair +office cubicle +office desk +office supply +office window +officer +official +oil +oil lamp +oil painting +oilrig +okra +old photo +olive +olive oil +olive tree +omelet +onion +onion ring +opal +open +opening +opening ceremony +opera +opera house +operate +operating room +operation +optical shop +orangutan +orange +orange juice +orange tree +orangery +orbit +orchard +orchestra pit +orchid +order +organization +origami +ornament +osprey +ostrich +otter +out +outcrop +outdoor +outhouse +electric outlet +outline +oval +oven +overall +overcoat +overpass +owl +oyster +teething ring +pack +package +paddock +police van +padlock +paella +pagoda +pain +paint brush +painter +paisley bandanna +palace +palette +paling +pall +palm tree +pan +pancake +panda +panel +panorama +pansy +pant +pantry +pants +pantyhose +papaya +paper +paper bag +paper cutter +paper lantern +paper plate +paper towel +paperback book +paperweight +parachute +parade +paradise +parrot +paramedic +paraquet +parasail +paratrooper +parchment +parish +park +park bench +parking +parking garage +parking meter +parking sign +parliament +parsley +participant +partner +partridge +party +party hat +pass +passage +passbook +passenger +passenger ship +passenger train +passion fruit +passport +pasta +paste +pastry +pasture +patch +patient +pattern +pavement +pavilion +paw +pay +payphone +pea +peace +peach +peacock +peak +peanut +peanut butter +pear +pearl +pebble +pecan +pedestrian +pedestrian bridge +pedestrian street +peel +peeler +pegboard +pegleg +pelican +pen +penalty kick +pencil +pencil case +pencil sharpener +pencil skirt +pendant +pendulum +penguin +peninsula +pennant +penny +piggy bank +peony +pepper +pepper grinder +peppercorn +pepperoni +perch +perform +performance +performance arena +perfume +pergola +persian cat +persimmon +personal care +personal flotation device +pest +pet +pet shop +pet store +petal +petunia +church bench +pheasant +phenomenon +philosopher +phone +phonebook +record player +photo +photo booth +photo frame +photography +physicist +physics laboratory +pianist +piano +plectrum +pick up +pickle +picnic +picnic area +picnic basket +picnic table +picture +picture frame +pie +pigeon +pilgrim +tablet +pillow +pilot +pilot boat +pin +pine +pine cone +pine forest +pine nut +pineapple +table tennis table +table tennis +pink +pint +pipa +pipe +pipe bowl +pirate +pirate flag +pirate ship +pistachio +ski slope +pocket bread +pitaya +pitbull +pitch +pitcher +pitcher plant +pitchfork +pizza +pizza cutter +pizza pan +pizzeria +placard +place +place mat +plaid +plain +plan +planet +planet earth +plank +plant +plantation +planting +plaque +plaster +plastic +plasticine +plateau +platform +platinum +platter +play +play badminton +play baseball +play basketball +play billiard +play football +play pong +play tennis +play volleyball +player +playground +playhouse +playing card +playing chess +playing golf +playing mahjong +playingfield +playpen +playroom +plaza +plier +plot +plow +plug +plug hat +plum +plumber +plumbing fixture +plume +plywood +pocket +pocket watch +pocketknife +pod +podium +poetry +poinsettia +point +pointer +poker card +poker chip +poker table +pole +polecat +police +police car +police dog +police station +politician +polka dot +pollen +pollution +polo +polo neck +polo shirt +pomegranate +pomeranian +poncho +pond +ponytail +poodle +pool +pop +pop artist +popcorn +pope +poppy +porcelain +porch +pork +porridge +portable battery +portal +portfolio +porthole +portrait +portrait session +pose +possum +post +post office +stamp +postcard +poster +poster page +pot +potato +potato chip +potato salad +potholder +potty +pouch +poultry +pound +pour +powder +power line +power plugs and sockets +power see +power station +practice +Prague Castle +prayer +preacher +premiere +prescription +show +presentation +president +press room +pressure cooker +pretzel +prince +princess +print +printed page +printer +printing +prison +produce +product +profession +professional +professor +project picture +projection screen +projector +prom +promenade +propeller +prophet +proposal +protective suit +protest +protester +publication +publicity portrait +ice hockey +pudding +puddle +puff +puffin +pug +pull +pulpit +pulse +pump +pumpkin +pumpkin pie +pumpkin seed +punch bag +punch +student +purple +push +putt +puzzle +tower +pyramid +python +qr code +quail +quarry +quarter +quartz +queen +quesadilla +queue +quiche +quilt +quilting +quote +rabbit +raccoon +race +race track +raceway +race car +racket +radar +radiator +radio +raft +rag doll +rail +railcar +railroad +railroad bridge +railway line +railway station +rain +rain boot +rainbow +rainbow trout +raincoat +rainforest +rainy +raisin +rake +ram +ramp +rapeseed +rapid +rapper +raspberry +rat +ratchet +raven +ravine +ray +razor +razor blade +read +reading +reamer +rear +rear light +rear view +rearview mirror +receipt +receive +reception +recipe +record +record producer +recorder +recording studio +recreation room +recreational vehicle +rectangle +recycling +recycling bin +red +red carpet +red flag +red panda +red wine +redwood +reed +reef +reel +referee +reflect +reflection +reflector +register +rein +reindeer +relax +release +relief +religion +religious +relish +remain +remodel +remote +remove +repair +repair shop +reptile +rescue +rescuer +research +researcher +reservoir +residence +residential neighborhood +resin +resort +resort town +restaurant kitchen +restaurant patio +restroom +retail +retriever +retro +reveal +rhinoceros +rhododendron +rib +ribbon +rice +rice cooker +rice field +ride +ridge +riding +rifle +rim +ring +riot +ripple +rise +rise building +river +river bank +river boat +river valley +riverbed +road +road sign +road trip +roadside +roast chicken +robe +robin +robot +stone +rock arch +rock artist +rock band +rock climber +rock climbing +rock concert +rock face +rock formation +rocker +rocket +rocking chair +rocky +rodent +rodeo +rodeo arena +roe +roe deer +roller +coaster +roller skate +roller skates +rolling pin +romance +romantic +roof +roof garden +room +room divider +root +root beer +rope bridge +rosary +rose +rosemary +rosy cloud +rottweiler +round table +router +row +rowan +royal +rubber stamp +rubble +rubik's cube +ruby +ruffle +rugby +rugby ball +rugby player +ruins +ruler +rum +run +runner +running shoe +rural +rust +rustic +rye +sack +saddle +saddlebag +safari +safe +safety vest +sage +sail +sailboat +sailing +sailor +squirrel monkey +sake +salad +salad bowl +salamander +salami +sale +salmon +salon +salsa +salt +salt and pepper shakers +salt lake +salt marsh +salt shaker +salute +samoyed +samurai +sand +sand bar +sand box +sand castle +sand sculpture +sandal +sandwich +sanitary napkin +santa claus +sapphire +sardine +sari +sashimi +satay +satchel +satellite +satin +sauce +saucer +sauna +sausage +savanna +saw +sawbuck +sax +saxophonist +scaffold +scale +scale model +scallop +scar +strawman +scarf +scene +scenery +schnauzer +school +school bus +school uniform +schoolhouse +schooner +science +science fiction film +science museum +scientist +scissors +wall lamp +scone +scoop +scooter +score +scoreboard +scorpion +scout +scrambled egg +scrap +scraper +scratch +screen +screen door +screenshot +screw +screwdriver +scroll +scrub +scrubbing brush +sculptor +sculpture +sea cave +sea ice +sea lion +sea turtle +sea urchin +seabass +seabed +seabird +seafood +seahorse +seal +sea view +seashell +seaside resort +season +seat +seat belt +seaweed +secretary +security +sedan +see +seed +seesaw +segway +selfie +sell +seminar +sense +sensor +server +server room +service +set +sewing machine +shadow +shake +shaker +shampoo +shape +share +shark +sharpener +sharpie +shaver +shaving cream +shawl +shear +shears +sheep +sheet +sheet music +shelf +shell +shellfish +shelter +shelve +shepherd +sherbert +shiba inu +shine +shipping +shipping container +shipwreck +shipyard +shirt +shirtless +shoal +shoe +shoe box +shoe shop +shoe tree +shoot +shooting basketball guard +shop window +shopfront +shopper +shopping +shopping bag +shopping basket +shopping cart +mall +shopping street +shore +shoreline +short +short hair +shorts +shot glass +shotgun +shoulder +shoulder bag +shovel +showcase +shower +shower cap +shower curtain +shower door +shower head +shredder +shrew +shrimp +shrine +shrub +shutter +siamese +siberia +sibling +side +side cabinet +side dish +sidecar +sideline +siding +sign +signage +signal +signature +silk +silk stocking +silo +silver +silver medal +silverware +sing +singe +singer +sink +sip +sit +sitting +skate park +skateboard +skateboarder +skater +skating rink +skeleton +sketch +skewer +ski +ski boot +ski equipment +ski jacket +ski lift +ski pole +ski resort +snowboard +skier +skiing shoes +skin +skull +skullcap +sky +sky tower +skylight +skyline +skyscraper +slalom +slate +sleigh +sleep +sleeping bag +sleepwear +sleeve +slice +slide +slider +sling +slope +slot +slot machine +sloth +slow cooker +slug +slum +smell +smile +smoke +snack +snail +snake +snapper +snapshot +snorkel +snout +snow +snow leopard +snow mountain +snowball +snowboarder +snowfield +snowflake +snowman +snowmobile +snowplow +snowshoe +snowy +soap +soap bubble +soap dispenser +soccer goalkeeper +socialite +sock +socket +soda +softball +software +solar battery +soldier +solo +solution +sombrero +song +sound +soup +soup bowl +soupspoon +sour cream +souvenir +soybean milk +spa +space +space shuttle +space station +spacecraft +spaghetti +span +wrench +spark +sparkle +sparkler +sparkling wine +sparrow +spatula +speaker +spectator +speech bubble +speed limit +speed limit sign +speedboat +speedometer +sphere +spice +spice rack +spider +spider web +spike +spin +spinach +spire +splash +sponge +spoon +sport association +sport equipment +sport team +sports ball +sports equipment +sports meet +sportswear +dot +spray +spread +spring +spring roll +sprinkle +sprinkler +sprout +spruce +spruce forest +squad +square +squash +squat +squeeze +squid +squirrel +water gun +stab +stable +stack +stadium +staff +stage +stage light +stagecoach +stain +stainless steel +stair +stairs +stairwell +stall +stallion +stand +standing +staple +stapler +star +stare +starfish +starfruit +starling +state park +state school +station +stationary bicycle +stationery +statue +steak +steak knife +steam +steam engine +steam locomotive +steam train +steamed bread +steel +steering wheel +stem +stencil +step stool +stereo +stethoscope +stew +stick +stick insect +sticker +still life +stilt +stingray +stir +stirrer +stirrup +sew +stock +stocking +stomach +stone building +stone carving +stone house +stone mill +stool +stop +stop at +stop light +stop sign +stop watch +traffic light +storage box +storage room +tank +store +storefront +stork +storm +storm cloud +stormy +stove +poker +straddle +strainer +strait +strap +straw +straw hat +strawberry +stream +street art +street artist +street corner +street dog +street food +street light +street market +street photography +street scene +street sign +street vendor +stretch +stretcher +strike +striker +string +string cheese +strip +stripe +stroll +structure +studio +studio shot +stuff +stuffed animal +stuffed toy +stuffing +stump +stunning +stunt +stupa +style +stylus +submarine +submarine sandwich +submarine water +suburb +subway +subway station +subwoofer +succulent +suede +sugar +sugar bowl +sugar cane +sugar cube +suit +suite +summer +summer evening +summit +sun +sun hat +sunbathe +sunday +sundial +sunflower +sunflower field +sunflower seed +sunglasses +sunny +sunrise +sunset +sunshade +sunshine +super bowl +sports car +superhero +supermarket +supermarket shelf +supermodel +supporter +surf +surface +surfboard +surfer +surgeon +surgery +surround +sushi +sushi bar +suspenders +suspension +suspension bridge +suv +swallow +swallowtail butterfly +swamp +swan +swan boat +sweat pant +sweatband +sweater +sweatshirt +sweet +sweet potato +swim +swim cap +swimmer +swimming hole +swimming pool +swing +swing bridge +swinge +swirl +switch +swivel chair +sword +swordfish +symbol +symmetry +synagogue +syringe +syrup +system +t shirt +t-shirt +tabasco sauce +tabby +table tennis racket +table top +tablecloth +tablet computer +tableware +tachometer +tackle +taco +tae kwon do +tai chi +tail +tailor +take +takeoff +talk +tambourine +tan +tangerine +tape +tapestry +tarmac +taro +tarp +tart +tassel +taste +tatami +tattoo +tattoo artist +tavern +tea +tea bag +tea party +tea plantation +tea pot +tea set +teach +teacher +teacup +teal +team photo +team presentation +tear +technician +technology +teddy +tee +teenager +telegraph pole +zoom lens +telescope +television +television camera +television room +television studio +temperature +temple +tempura +tennis +tennis court +tennis match +tennis net +tennis player +tennis racket +tent +tequila +terminal +terrace +terrain +terrarium +territory +test +test match +test tube +text +text message +textile +texture +thanksgiving +thanksgiving dinner +theater +theatre actor +therapy +thermometer +thermos +thermos bottle +thermostat +thicket +thimble +thing +thinking +thistle +throne +throne room +throw +throw pillow +thunder +thunderstorm +thyme +tiara +tick +ticket +ticket booth +tide pool +tie +tiger +tight +tile +tile flooring +tile roof +tile wall +tin +tinfoil +tinsel +tiramisu +tire +tissue +toast +toaster +tobacco +tobacco pipe +toddler +toe +tofu +toilet bowl +toilet seat +toiletry +tokyo tower +tomato +tomato sauce +tomato soup +tomb +tong +tongs +tool +toolbox +toothbrush +toothpaste +toothpick +topiary garden +topping +torch +tornado +tortilla +tortoise +tote bag +totem pole +totoro +toucan +touch +touchdown +tour +tour bus +tour guide +tourist +tourist attraction +tournament +tow truck +towel +towel bar +tower block +tower bridge +town +town square +toy +toy car +toy gun +toyshop +track +tractor +trade +tradition +traditional +traffic +traffic cone +traffic congestion +traffic jam +traffic sign +trail +trailer +trailer truck +train +train bridge +train car +train interior +train track +train window +trainer +training +training bench +training ground +trolley +trampoline +transformer +transparency +travel +tray +treadmill +treat +tree +tree branch +tree farm +tree frog +tree house +tree root +tree trunk +trial +triangle +triathlon +tribe +tributary +trick +tricycle +trim +trio +tripod +trombone +troop +trophy +trophy cup +tropic +trout +truck +truck driver +tub +tube +tugboat +tulip +tuna +tundra +tunnel +turbine +turkey +turn +turnip +turquoise +turret +turtle +tusk +tv actor +tv cabinet +tv drama +tv genre +tv personality +tv show +tv sitcom +tv tower +twig +twilight +twin +twine +twist +type +type on +typewriter +ukulele +ultraman +umbrella +underclothes +underwater +unicorn +uniform +universe +university +up +urban +urinal +urn +use +utensil +utility room +vacuum +valley +valve +vampire +van +vanilla +vanity +variety +vase +vault +vector cartoon illustration +vector icon +vegetable +vegetable garden +vegetable market +vegetation +vehicle +veil +vein +velvet +vending machine +vendor +vent +vespa +vessel +vest +vet +veteran +veterinarians office +viaduct +video +video camera +video game +videotape +view mirror +vigil +villa +village +vine +vinegar +vineyard +violence +violet +violin +violinist +violist +vision +visor +vodka +volcano +volleyball +volleyball court +volleyball player +volunteer +voyage +vulture +waffle +waffle iron +wagon +wagon wheel +waist +waiter +waiting hall +waiting room +walk +walking +walking cane +wall clock +wallpaper +walnut +walrus +war +warehouse +warm +warning sign +warrior +warship +warthog +wash +washer +washing +washing machine +wasp +waste +waste container +watch +water +water bird +water buffalo +water cooler +water drop +water feature +water heater +water level +water lily +water park +water pipe +water purifier +water ski +water sport +water surface +water tower +watercolor +watercolor illustration +watercolor painting +waterfall +watering can +watermark overlay stamp +watermelon +waterproof jacket +waterway +wave +wax +weapon +wear +weather +vane +web +webcam +wedding +wedding ring +wedding bouquet +wedding cake +wedding couple +wedding invitation +wedding party +wedding photo +wedding photographer +wedding photography +wedding reception +wedge +weed +weight +weight scale +welder +well +western food +western restaurant +wet +wet bar +wet suit +wetland +wetsuit +whale +whale shark +wheat +wheat field +wheel +wheelchair +wheelie +whipped cream +whisk +whisker +whiskey +whistle +white +white house +white wine +whiteboard +wicket +wide +wield +wig +Wii +Wii controller +wild +wildebeest +wildfire +wildflower +wildlife +willow +wind +wind chime +wind farm +wind turbine +windmill +window +window box +window display +window frame +window screen +window seat +window sill +wiper +windshield +windy +wine bottle +wine cooler +wine cabinet +wine cellar +wine glass +wine rack +wine tasting +winery +wing +winter +winter melon +winter morning +winter scene +winter sport +winter storm +wire +wisteria +witch +witch hat +wok +wolf +woman +wood +wood duck +wood floor +wood wall +wood-burning stove +wooden spoon +woodland +woodpecker +woodworking plane +wool +job +work card +workbench +worker +workplace +workshop +world +worm +worship +wound +wrap +wrap dress +wrapping paper +wrestle +wrestler +wrinkle +wristband +write +writer +writing +writing brush +writing desk +yacht +yak +yard +yellow +yoga +yoga mat +yoghurt +yoke +yolk +youth +youth hostel +yurt +zebra +zebra crossing +zen garden +zip +zipper +zombie +zongzi +zoo \ No newline at end of file diff --git a/ram/data/ram_tag_list_chinese.txt b/ram/data/ram_tag_list_chinese.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f61dc0b84ed58e019d7e331555ef438f2ded2de --- /dev/null +++ b/ram/data/ram_tag_list_chinese.txt @@ -0,0 +1,4585 @@ +三维CG渲染 +3d眼镜 +算盘 +鲍鱼 +修道院 +肚子 +学院 +附件 +事故 +手风琴 +橡子 +丙烯颜料 +表演 +行动 +动作电影 +活动 +演员 +改编本 +添加 +胶带 +调整 +成人 +冒险 +广告 +天线 +有氧运动 +喷雾罐 +爆炸头 +农业 +帮助 +空调 +空调系统 +风向标 +飞机客舱 +飞机模型 +机场 +航线 +客机 +飞行员 +飞机 +飞机窗口 +机场 +机场跑道 +航站楼 +飞艇 +航展 +过道 +警报 +闹钟 +信天翁 +唱片 +唱片封面 +酒精 +壁龛 +水藻 +胡同/球道 +杏仁 +芦荟 +高山 +羊驼 +字母表 +德国牧羊犬 +圣坛 +琥珀 +救护车 +秃鹰 +美国短毛猫 +紫水晶 +圆形剧场 +扩音器 +游乐园 +游乐设施 +锚 +古老的 +海葵 +天使 +角 +动物 +动物雕塑 +动物收容所 +动画片 +动画电影 +动画师 +动漫 +脚踝 +短袜 +周年庆 +风衣 +蚂蚁 +羚羊 +古董 +鹿角 +铁砧 +公寓 +猿 +应用程序 +应用图标 +出现 +外观 +开胃菜 +掌声 +苹果 +苹果汁 +苹果派 +苹果树 +苹果酱 +设备 +约定 +通道 +杏子 +围裙 +浅绿色 +水族馆 +观赏鱼 +渡槽 +游乐中心 +商场游戏机 +拱门 +拱桥 +考古现场 +射箭 +群岛 +建筑师 +建筑设计 +档案 +拱门 +地区 +竞技场 +争论 +手臂 +穿山甲 +臂章 +扶手椅 +衣柜 +盔甲 +军队 +军事基地 +坦克 +阵列 +逮捕 +箭头 +艺术 +艺术展 +美术馆 +艺术印刷品 +艺术学校 +艺术工作室 +艺术矢量插图 +洋蓟 +文章 +手工艺品 +艺术家 +艺术阁楼 +灰 +烟灰缸 +亚洲寺庙 +芦笋 +沥青道路 +组装 +集会 +生产流水线 +协会 +宇航员 +天文学家 +运动员 +运动 +地图集 +自助取款机 +大气层 +中庭 +连接 +战斗机 +参加 +吸引力 +全地形车 +茄子 +拍卖 +奥迪汽车 +音频 +礼堂 +极光 +作者 +汽车厂 +汽车修理工 +汽车零件 +车展 +汽车展厅 +汽车电池 +汽车制造 +汽车模型 +汽车 +秋天 +秋天的森林 +秋天的叶子 +秋天的公园 +秋天的树 +阿凡达 +林荫大道 +飞行员太阳镜 +牛油果 +奖品 +颁奖典礼 +获奖者 +棚 +斧头 +杜鹃花 +狒狒 +婴儿 +奶瓶 +婴儿车 +婴儿衣服 +小象 +婴儿食品 +婴儿座椅 +迎婴派对 +背后/后面 +背景 +背光 +背包 +后院 +培根 +徽章 +獾 +荒地 +羽毛球运动 +羽毛球拍 +袋子 +面包圈 +风笛 +法棍 +诱饵 +焙烤食品 +面包师 +面包店 +烘焙 +烤盘 +平衡 +平衡车 +阳台 +球 +球池 +芭蕾舞女演员 +芭蕾舞 +芭蕾舞演员 +芭蕾舞裙 +气球 +气球拱门 +棒球手 +舞厅 +竹子 +竹林 +香蕉 +香蕉面包 +香蕉叶子 +香蕉树 +乐队 +创可贴 +绷带 +头巾 +束发带 +刘海 +手镯 +栏杆 +五弦琴 +银行 +银行卡 +银行金库 +纸币 +横幅/旗帜 +宴会 +宴会厅 +榕树 +包子 +洗礼 +酒吧 +条形码 +高脚凳 +烧烤 +烧烤架 +杠铃 +理发师 +理发店 +芭比娃娃 +驳船 +咖啡师 +树皮 +大麦 +谷仓 +仓鸮 +挡光板 +桶 +路障 +屏障 +手推车 +酒保 +棒球 +棒球基地 +棒球棒 +棒球帽 +棒球场 +棒球比赛 +棒球手套 +棒球投手 +棒球队 +棒球制服 +地下室 +罗勒 +水盆 +篮子 +篮子 +篮球 +篮球篮板 +篮球教练 +篮球场 +篮球比赛 +篮球框 +篮球运动员 +篮球馆 +篮球队 +贝斯 +低音吉他 +低音喇叭 +贝斯手 +球棒/球拍 +浴室 +水浴加热器 +浴垫 +浴巾 +泳装 +浴袍 +浴室 +浴室配件 +浴室柜 +浴室门 +浴室镜子 +浴室水槽 +卫生纸 +浴室窗户 +蝙蝠侠 +棒子 +接连猛打/击球员 +电池 +战斗 +战绳 +战舰 +海湾 +海湾大桥 +凸窗 +杨梅 +集市 +海滩 +沙滩球 +沙滩椅 +海滨别墅 +海滩小屋 +沙滩毛巾 +沙滩排球 +灯塔 +珠子 +比格犬 +鸟嘴 +烧杯 +横梁 +豆子 +豆袋椅 +豆袋 +熊 +幼熊 +胡子 +野兽 +击打/击败 +美丽的 +美丽 +美容院 +海狸 +床 +床单 +床架 +卧室 +床上用品 +便盆 +卧室窗户 +床头灯 +蜜蜂 +山毛榉 +牛肉 +养蜂人 +蜂鸣器 +啤酒 +啤酒瓶 +啤酒罐 +啤酒花园 +啤酒杯 +啤酒馆 +甜菜 +甲虫 +米色 +时钟 +甜椒 +钟楼 +皮带 +皮带扣 +长凳 +弯曲 +孟加拉虎 +盒饭 +贝雷帽 +浆果 +停泊位 +饮料 +围嘴 +拌饭 +圣经 +比熊 +自行车 +自行车头盔 +自行车车轮 +自行车骑士 +坐浴盆 +大本钟 +自行车道 +自行车道 +自行车赛 +骑车 +比基尼 +比基尼上衣 +账单 +台球 +广告牌 +台球台 +垃圾箱 +活页夹 +双筒望远镜 +生物学实验室 +双翼飞机 +桦木 +桦树 +鸟 +鸟池 +喂鸟器 +鸟舍 +鸟巢 +鸟池 +鸟笼 +出生 +生日 +生日蛋糕 +生日蜡烛 +生日贺卡 +生日聚会 +饼干 +主教 +野牛 +钻头 +咬 +黑色 +黑山羊 +黑莓 +乌鸦 +黑板 +铁匠 +叶片/刀片 +毯子/覆盖层 +运动外套 +看台 +搅拌机 +祝福 +窗帘 +眼罩 +闪光 +暴风雪 +块 +博客 +血 +开花 +花 +女装衬衫 +吹 +吹风机 +河豚 +蓝色 +蓝色艺术家 +蓝松鸦 +蓝天 +蓝莓 +蓝知更鸟 +猪 +板子 +板擦 +棋盘游戏 +木板路 +船 +船甲板 +船屋 +桨 +乘船 +浮标 +山猫 +躯干 +身体冲浪板 +健美运动员 +水煮鸡蛋 +锅炉 +饰扣式领带 +门闩 +炸弹 +轰炸机 +披肩榛鸡 +骨骼 +篝火 +阀盖 +盆景 +书 +书籍封面 +书柜 +文件夹 +书签 +书架 +书店 +远程拾音器 +推动 +靴子 +边界 +边境牧羊犬 +植物园 +瓶 +瓶盖 +开瓶器 +螺旋开瓶器 +三角梅 +巨石 +花束 +时装店 +精品酒店 +鞠躬/蝴蝶结 +领结 +弓形窗 +碗 +保龄球运动 +保龄球馆 +保龄球 +保龄球设备 +盒子 +箱形梁桥 +箱龟 +拳击手 +内裤 +拳击 +拳击手套 +拳击台 +男孩 +支撑物 +支架 +辫子 +大脑 +刹车 +刹车灯 +树枝 +商标 +白兰地 +黄铜 +黄铜牌匾 +面包 +面包箱 +休息 +早餐 +防浪堤 +胸部 +啤酒厂 +砖块 +砖建筑物 +墙 +砖块 +婚纱 +新娘 +新郎 +伴娘 +桥 +缰绳 +公文包 +明亮的 +边沿 +钻头 +广播 +西兰花 +青铜 +铜牌 +青铜雕塑 +青铜雕像 +胸针 +小溪 +扫帚 +肉汤 +棕色 +棕熊 +巧克力蛋糕 +早午餐 +浅黑肤色的女人 +刷子 +郊狼 +包菜 +气泡 +泡泡糖 +珍珠奶茶 +斗柜 +盾牌 +芽 +佛 +水牛 +自助餐 +昆虫 +建造 +建造者 +建筑 +积木 +建筑立面 +建筑材料 +灯 +牛 +斗牛犬 +子弹 +动车 +公告栏 +防弹背心 +斗牛 +扩音器 +斗牛场 +大黄蜂 +保险杠 +卷/地形起伏 +捆 +蹦极 +双层床 +地堡/击球 +兔子 +浮标 +书桌 +墓室 +燃烧 +玉米煎饼 +公交车 +公交车司机 +公交车内部 +公交车站 +公交车站 +公交车窗户 +灌木 +商业 +名片 +业务主管 +商务西装 +业务团队 +女商人 +商人 +半身像 +屠夫 +肉铺 +孤峰 +黄油 +奶油 +蝴蝶 +蝴蝶馆 +按钮 +梧桐树 +购买 +出租车 +小屋 +卷心菜 +小屋/机舱 +守车 +储藏柜 +橱柜 +电缆 +缆车 +仙人掌 +咖啡馆 +食堂 +笼子 +蛋糕 +蛋糕台 +计算器 +大锅 +日历 +小腿 +通话 +电话亭 +书法 +平静的 +摄像机 +骆驼 +相机 +相机镜头 +迷彩 +露营 +露营者 +篝火 +露营 +营地 +校园 +罐 +开罐器 +运河 +金丝雀 +癌症 +蜡烛 +烛台 +糖果 +块状糖 +柺杖糖 +糖果店 +拐杖 +罐子 +大炮 +树冠/顶棚 +四柱床 +香瓜 +悬臂桥 +帆布 +峡谷 +帽子 +斗篷 +科德角 +卡布奇诺 +胶囊 +队长 +捕获 +车 +汽车经销商 +车门 +汽车内饰 +车标 +后视镜 +停车场 +汽车座椅 +车展 +洗车 +车窗 +焦糖 +卡片 +纸牌游戏 +纸板 +纸板盒 +羊毛衫 +红衣凤头鸟 +货物 +货运飞机 +货船 +加勒比 +康乃馨 +狂欢节 +食肉动物 +旋转木马 +鲤鱼 +木匠 +地毯 +拖鞋 +红雀 +长途客车 +斑点狗 +航空母舰 +胡萝卜 +胡萝卜蛋糕 +携带 +手推车 +纸箱/纸盒 +卡通 +卡通人物 +卡通插图 +卡通风格 +雕刻 +容器 +现金 +腰果 +赌场 +砂锅 +磁带 +盒式录音机 +石膏绷带 +铸造 +城堡 +猫 +猫窝 +猫粮 +猫器具 +猫架 +地下墓穴 +双体船 +美洲狮 +握着/抓着 +捕手 +毛毛虫 +鲶鱼 +教堂 +牛 +猫步 +走秀 +菜花 +洞穴 +鱼子酱 +光盘 +CD播放器 +雪松 +天花板 +吊扇 +庆祝 +庆典 +名人 +芹菜 +大提琴 +手机 +水泥 +墓地 +中心装饰品 +蜈蚣 +陶瓷 +瓷砖 +麦片 +仪式 +证书 +链条 +链锯 +椅子 +升降椅 +躺椅 +木屋 +圣杯 +粉笔 +房间 +变色龙 +香槟酒 +香槟杯 +冠军 +锦标赛 +吊灯 +婴儿换尿布台 +通道 +皴裂处 +小教堂 +人物雕塑 +木炭 +充电 +充电器 +战车 +慈善机构 +慈善活动 +魅力 +图表 +追逐 +底盘 +检查/支票 +支票簿 +棋盘 +检查表 +欢呼声 +鼓励/啦啦队 +奶酪 +奶酪汉堡 +奶酪蛋糕 +猎豹 +厨师 +化合物 +化学家 +化学 +化学实验室 +旗袍 +樱桃 +樱花 +樱桃番茄 +樱桃树 +国际象棋 +栗子 +鸡 +鸡胸肉 +鸡笼 +鸡肉沙拉 +鸡翅 +鹰嘴豆 +小衣橱 +吉娃娃 +孩子 +童星 +孩子的房间 +红番椒 +辣热狗 +烟囱 +黑猩猩 +瓷器 +白菜 +中国园林 +中国结 +月季 +中国塔 +炸薯条/炸薯条 +花栗鼠 +凿子 +巧克力 +巧克力棒 +巧克力蛋糕 +巧克力碎片 +巧克力饼干 +巧克力牛奶 +巧克力慕斯 +松露 +唱诗班 +厨房刀 +砧板 +筷子 +圣诞节 +圣诞球 +圣诞贺卡 +圣诞装饰 +圣诞晚宴 +平安夜 +圣诞帽 +圣诞灯 +圣诞市场 +圣诞装饰 +圣诞树 +菊花 +教堂 +教堂塔 +苹果酒 +雪茄 +雪茄盒 +香烟 +烟盒 +腰带 +电影院 +摄影师 +肉桂 +圆 +电路 +电路板 +马戏团 +水箱 +柑橘类水果 +城市 +城市公交 +市政厅 +城市夜景 +城市公园 +城市天际线 +城市广场 +城市街道 +城墙 +城市景观 +蛤蜊 +单簧管 +扣子 +班级 +经典 +教室 +锁骨 +爪子 +黏土 +陶器 +清洁 +洁净室 +清洁工人 +清洁用品 +清晰的 +栓 +克莱门氏小柑橘 +客户端 +悬崖 +爬 +爬山 +登山者 +诊所 +夹子 +剪贴画 +剪贴板 +快速帆船 +君子兰 +斗篷 +木底鞋 +特写 +壁橱 +布 +穿衣 +衣服 +晒衣夹 +晒衣绳 +服装店 +云 +云雾森林 +多云 +三叶草 +小丑 +小丑鱼 +俱乐部 +离合器 +手拿包 +煤炭 +海岸 +外套 +衣帽架 +玉米 +公鸡 +凤头鹦鹉 +可卡犬 +驾驶 +蟑螂 +鸡尾酒 +小礼服 +鸡尾酒调制器 +鸡尾酒桌 +可可 +椰子 +椰子树 +咖啡 +咖啡豆 +咖啡杯 +咖啡机 +咖啡店 +咖啡壶 +棺材 +法国白兰地 +螺旋 +硬币 +可口可乐 +滤器 +冷的 +卷心菜沙拉 +合作 +拼贴画 +收藏品 +大学生 +牧羊犬 +碰撞 +颜色 +涂色书 +染色材料 +矮种马 +柱子 +梳子 +密码锁 +喜剧演员 +喜剧 +喜剧电影 +彗星 +舒服 +安慰食物 +漫画书 +漫画人物 +连环画 +指挥官 +评论员 +社区 +通勤 +公司 +指南针 +比赛 +比赛 +竞争者 +作曲家 +作文 +堆肥 +电脑 +电脑机箱 +电脑椅 +电脑桌 +键盘 +计算机显示器 +计算机房 +电脑屏幕 +机箱 +概念车 +音乐会 +音乐厅 +贝壳 +混凝土 +调味品 +避孕套 +独立产权的公寓 +指挥 +锥形物 +会议 +会议中心 +会议厅 +会议室 +五彩纸屑 +冲突 +合流 +连接 +连接器 +温室 +星座 +建筑工地 +建筑工人 +包含 +容器 +集装箱船 +大陆 +轮廓 +合同 +控制 +控制塔 +便利店 +集会 +交谈 +转换器 +可转换的 +输送机 +厨师/烹饪 +烹饪 +烹饪喷雾剂 +炊具 +凉的 +冷却器 +铜 +一本/一册 +珊瑚 +珊瑚礁 +粗绳 +有线电话 +酒 +威尔士矮脚狗 +瓶塞 +软木板 +鸬鹚 +玉米 +玉米田 +玉米面包 +角落 +小号 +飞檐 +燕麦片 +围栏 +走廊 +紧身衣 +化妆品 +化妆刷 +化妆镜 +角色扮演 +服装 +服装电影设计师 +婴儿床 +小屋 +棉花 +棉花糖 +沙发 +倒计时 +柜台 +台面 +最佳乡村歌手 +乡村别墅 +乡村公路 +乡村流行歌手 +农村 +双门小轿车 +夫妇/两人/几个 +情侣写真 +小胡瓜 +课程 +球场 +法院 +院子 +堂兄弟 +工作服 +奶牛 +母牛的颈铃 +牛仔 +牛仔靴 +牛仔帽 +螃蟹 +蟹肉 +裂纹 +摇篮 +工艺 +工匠 +蔓越莓 +起重机 +黑纱 +厕所 +板条箱 +火山口湖 +龙虾 +蜡笔 +奶油乳酪 +奶油罐 +创建 +生物 +信用卡 +新月形 +新月形面包 +山顶 +全体船员 +蟋蟀 +板球用球 +板球队 +板球队员 +钩边 +克罗克电锅 +鳄鱼 +庄稼 +露脐上衣 +交叉 +横木 +十字路口 +相声 +人行横道 +油煎面包块 +乌鸦 +撬棍 +人群 +拥挤的 +皇冠 +阴极射线管屏幕 +耶稣受难像 +巡游 +游轮 +巡洋艇 +面包屑 +压坏 +拐杖 +水晶 +幼兽 +立方体 +黄瓜 +球杆 +袖口 +袖扣 +烹饪 +农田 +杯子 +纸杯蛋糕 +丘比特 +马路牙子 +旋度 +卷发器 +无籽葡萄干 +货币 +咖喱 +窗帘 +曲线 +软垫 +顾客 +切 +餐具 +自行车 +骑自行车 +龙卷风 +汽缸 +铙钹 +柏树 +柏树 +达克斯猎狗 +水仙花 +匕首 +大丽花 +萝卜 +乳制品 +雏菊 +大坝 +损害 +潮湿的 +跳舞 +舞池 +舞蹈室 +舞者 +蒲公英 +黑暗 +黑暗 +飞镖 +圆靶 +指示板 +日期 +女儿 +黎明 +天床上 +日光 +门栓 +死亡 +辩论 +碎片 +玻璃水瓶 +甲板 +双层巴士 +装饰 +装修/装饰 +装饰画 +鹿 +后卫 +神 +熟食 +投递 +拆迁 +怪兽 +演示 +兽窝/休闲室 +牛仔夹克 +牙医 +百货商店 +抑郁症 +德比 +皮肤病 +沙漠 +沙漠公路 +设计 +设计师 +桌子/表格 +台灯 +桌面 +台式电脑 +甜点 +破坏 +侦探 +洗涤剂 +露水 +仪表盘 +钻石 +尿布 +尿布包 +杂志 +死 +饮食 +挖掘机 +数字 +数字时钟 +莳萝 +晚餐 +小船 +餐厅 +晚宴 +餐桌 +恐龙 +浸 +文凭 +指引 +导演 +尘埃 +越野摩托车 +泥土地 +泥土路 +泥路/土路 +灾难 +信徒 +迪斯科舞厅 +迪斯科灯秋 +迪斯科舞厅 +疾病 +盘子 +碟形天线 +洗碗机 +抹布 +菜肴 +洗碗液 +迪斯尼乐园 +自动售货机 +展示 +陈列窗 +壕沟 +潜水 +潜水员 +跳水板 +纸杯 +流行音乐播音员 +杜宾犬 +码头 +医生 +文件 +纪录片 +狗 +狗窝 +犬种 +狗项圈 +狗粮 +狗窝 +洋娃娃 +美元 +玩偶之家 +洋娃娃 +海豚 +穹顶 +住宅 +多米诺骨牌 +驴 +甜甜圈 +涂鸦 +门 +门把手 +受气包 +门牌 +门口 +宿舍 +面团 +市中心 +推土机 +拖 +龙 +蜻蜓 +排水沟 +剧本 +戏剧电影 +画 +抽屉里 +图画/画画 +图钉 +辫子 +连衣裙/特定场合的服装 +礼帽 +正装衬衫 +皮鞋 +大礼服 +梳妆台 +更衣室 +运球 +漂移 +浮木 +钻 +饮品/喝 +饮用水 +开车 +司机 +车道 +无人机 +水滴/下降 +吊灯 +滴管 +干旱 +药物 +药店 +鼓 +鼓手 +鸡腿 +干的 +公爵夫人 +鸭子 +鸭嘴兽 +小鸭子 +布基胶带 +伙计 +二重唱 +粗呢 +独木舟 +哑铃 +饺子 +沙丘 +扣篮 +榴莲 +黄昏 +灰尘 +垃圾车 +簸箕 +羽绒被 +DVD +染料 +鹰 +耳朵 +御寒耳罩 +耳机 +耳塞 +耳环 +地震 +画架 +复活节 +复活节兔子 +复活节彩蛋 +吃 +餐厅 +泡芙 +日食 +生态系统 +编辑 +教育 +教育家 +鳗鱼 +蛋 +蛋卷 +蛋挞 +打蛋器 +白鹭 +埃菲尔铁塔 +橡皮筋 +上级 +电椅 +电钻 +电工 +电 +电子 +电子器件 +大象 +高度图 +电梯 +电梯轿厢 +电梯门 +电梯大堂 +电梯井 +路堤 +大使馆 +装饰 +灰烬 +会徽 +刺绣 +翡翠 +紧急 +紧急服务 +紧急车辆 +情感 +帝国大厦 +搪瓷 +外壳/围墙 +茶几 +能源 +订婚 +订婚戒指 +引擎 +机舱 +工程师 +工程 +英国短毛猫 +乐团 +回车键 +演艺人员 +娱乐 +娱乐中心 +入口 +入口大厅 +信封 +马术 +设备 +橡皮擦 +二胡 +侵蚀 +自动扶梯 +食用蜗牛 +浓缩咖啡 +房地产 +河口 +桉树 +晚上 +晚礼服 +夜光 +傍晚天空 +晚上的太阳 +事件 +常绿的 +母羊 +挖掘 +运动 +排气罩 +展览 +出口 +探险者 +爆炸 +延长线 +灭火器 +排气扇 +挤压 +眼睛 +眼影 +眉 +眼线笔 +布料 +纺织品商店 +外观 +脸 +脸部特写 +蜜粉 +毛巾 +面巾纸架 +设施 +工厂 +工厂车间 +集市 +露天市场 +仙女 +猎鹰 +秋天 +家庭 +家庭轿车 +全家福 +家庭房 +风扇/扇子 +尖牙 +农场 +农民 +农民市场 +农舍 +时尚 +时尚配饰 +时装设计师 +时尚的女孩 +时装插图 +时装大片 +时装模特 +时装表演 +快餐 +西式快餐 +父亲 +水龙头 +故障 +动物 +小鹿 +传真 +宴会 +羽毛 +软呢帽 +饲料 +一餐 +饲养 +喂养的椅子 +猫科 +美洲狮 +栅栏 +芬达 +蕨类植物 +雪貂 +摩天轮 +渡船 +肥料 +节日 +纤维 +小说 +小说书 +田野/场地/野外 +田间道路 +无花果 +打架 +花样滑冰运动员 +小雕像 +文件 +档案照片 +文件柜 +填满 +胶片相机 +电影导演 +电影格式 +电影首映礼 +电影制片人 +拍摄 +过滤器 +鳍 +手 +终点线 +冷杉 +冷杉树 +火 +火灾报警 +消防部门 +消防车 +消防通道 +消防水带 +火坑 +消防站 +爆竹 +消防队员 +壁炉 +烟花 +烟花表演 +急救箱 +鱼 +鱼船 +海鲜市场 +鱼塘 +鱼缸 +渔夫 +钓鱼 +渔船 +渔网 +钓鱼 +渔村 +健身 +健身课程 +五个 +固定装置 +峡湾 +国旗 +旗杆 +小薄片 +火焰 +火烈鸟 +法兰绒 +拍打 +耀斑 +闪光 +烧瓶 +平 +比目鱼 +风味 +跳蚤 +跳蚤市场 +舰队 +飞行 +空中乘务员 +翻转 +触发器 +翻转图 +浮动 +群 +洪水 +地板/地面 +落地扇 +脚垫 +楼层平面图 +落地窗 +插花艺术 +花店 +牙线 +面粉 +流动 +花 +花篮 +花坛 +花箱 +花田 +花童 +花卉市场 +流体 +冲洗 +长笛 +飞 +飞行钓鱼 +传单 +马 +泡沫 +雾 +多雾的 +鹅肝酱 +箔纸 +折椅 +树叶 +民间艺术家 +民间舞蹈 +民间摇滚艺术家 +方旦糖 +火锅 +圣洗池 +食物 +食用色素 +美食广场 +食品加工机 +小吃摊 +快餐车 +桌上足球 +脚 +人行桥 +足球 +足球教练 +大学橄榄球赛 +足球比赛 +足球场 +足球比赛 +橄榄球头盔 +足球运动员 +足球场 +足球队 +小路 +脚印 +脚踏板 +台座 +鞋子 +故宫 +浅滩 +额头 +森林 +森林大火 +森林地面 +森林小路 +森林公路 +锻造 +餐叉 +叉车 +表格 +园林 +队列/形成物 +F1方程式赛车 +堡垒 +碉堡 +追逐 +化石 +粉底 +喷泉 +钢笔 +狐狸 +框架 +雀斑 +高速公路 +卡车 +法国 +法国斗牛犬 +薯条 +法式吐司 +化妆水 +冰箱 +炸鸡 +煎蛋 +炒饭 +友谊 +飞盘 +青蛙 +霜 +结霜 +严寒 +结冰 +水果 +水果蛋糕 +水果盘 +水果市场 +水果沙拉 +水果摊 +果树 +水果商店 +油炸食品 +煎锅 +软糖 +燃料 +吸烟罩 +有趣的 +葬礼 +真菌 +漏斗 +毛皮衣服 +毛皮大衣 +家具 +蒲团 +小工具 +枪口 +星云/星系 +美术馆 +游戏 +游戏棋盘 +游戏手柄 +火腿 +团伙 +车库 +车库门 +手工模型 +垃圾 +花园 +花园芦笋 +橡胶软管 +花园蜘蛛 +园丁 +园艺 +加菲猫 +滴水嘴 +花环 +大蒜 +衣服 +气体 +加油站 +煤气炉 +防毒面具 +收集 +聚集 +测量仪器 +露台 +齿轮 +壁虎 +艺妓 +凝胶 +百货商店 +发电机 +天竺葵 +幽灵 +礼物 +礼品袋 +礼品篮 +礼物盒 +礼品卡 +礼品商店 +礼物包装 +演唱会 +杜松子酒 +姜 +姜饼 +姜饼屋 +银杏树 +长颈鹿 +女孩 +给 +冰川 +角斗士 +玻璃珠 +玻璃瓶 +玻璃碗 +玻璃箱 +玻璃建筑 +玻璃门 +玻璃地板 +玻璃屋 +玻璃罐 +玻璃板 +玻璃桌子 +玻璃花瓶 +玻璃墙 +玻璃窗 +眼镜 +光滑面 +滑翔机 +地球 +手套 +发光 +汤圆 +去 +袭击 +球门 +守门员 +山羊 +羊奶酪 +戈壁 +护目镜/墨镜 +黄金 +金牌 +金门大桥 +金毛猎犬 +金鱼 +高尔夫运动 +高尔夫球帽 +高尔夫球车 +高尔夫球杆 +高尔夫球场 +高尔夫球手 +鹅 +大猩猩 +哥特式 +葫芦 +政府 +政府机构 +礼服 +毕业生 +毕业典礼 +谷物 +逆戟鲸 +大奖赛 +祖父 +祖母 +祖父母 +花岗岩 +格兰诺拉麦片 +葡萄 +西柚 +葡萄酒 +草 +蚱蜢 +草原 +长满草的 +擦菜器 +坟墓 +碎石 +墓碑 +肉汁 +调味汁瓶 +灰色 +吃草 +放牧 +绿色 +绿色植物 +欢迎 +问候 +贺卡 +灰狗 +网格 +筛子 +烧烤架 +格栅 +烤鳗鱼 +磨 +研磨机 +粗燕麦粉 +杂货袋 +洞穴 +地松鼠 +群体 +合影 +小树林 +生长 +牛油果酱 +警卫 +看门狗 +宾馆 +客房 +指南 +豚鼠 +吉他 +吉他手 +海湾 +海鸥 +枪 +高达 +谒师所 +古筝 +健身房 +体操运动员 +栖息地 +黑客 +冰雹 +头发 +头发颜色 +发胶 +毛刷 +发型 +发夹 +发网 +发夹 +发型 +一半 +礼堂 +万圣节 +万圣节服装 +万圣节南瓜 +露背装 +汉堡 +汉堡包 +哈密瓜 +锤子 +吊床 +阻碍 +仓鼠 +烘手机 +放大镜 +擦手巾 +手提包 +手球 +手铐 +手枪 +手帕 +把手 +手锯 +握手 +倒立 +手写 +汉服 +悬挂 +飞机库 +衣架 +幸福 +海港 +斑海豹 +硬摇滚艺术家 +精装书 +建筑工人 +硬件 +五金店 +硬木 +硬木地板 +口琴 +管风琴 +羽管键琴 +收获 +收割机 +坐垫/搁脚凳/草丛 +帽子 +帽盒 +双簧管 +山楂 +干草 +干草地 +榛子 +头 +主教练 +大灯 +床头板 +头饰 +海岬 +总部 +听力 +心脏 +心形 +热能 +加热器 +帚石楠 +树篱 +刺猬 +脚后跟 +直升机 +直升机机场 +头盔 +帮助 +母鸡 +指甲花 +药草 +兽群 +寄居蟹 +英雄 +苍鹭 +芙蓉花 +芙蓉花 +隐藏/隐蔽处 +高杠 +高跟鞋 +高地 +突出 +徒步旅行 +徒步旅行者 +徒步靴 +登山设备 +山丘 +丘陵地 +别墅 +山坡 +印度教寺庙 +铰链 +臀部 +嘻哈艺人 +河马 +历史学家 +历史遗迹 +历史 +曲棍球 +冰球馆 +曲棍球比赛 +曲棍球运动员 +曲棍球棒 +锄头 +洞 +假日 +冬青树 +海参 +家/住宅 +家用电器 +基地 +家居装饰 +室内设计 +内政部 +家庭影院 +家庭作业 +鹰嘴豆泥 +蜂蜜 +蜂窝 +蜜月 +风帽 +连帽衫 +挂钩/勾住 +跳 +地平线 +犀鸟 +长角牛 +大黄蜂 +震惊 +恐怖电影 +马鞍褥 +马车 +马场 +骑马 +马背 +马蹄铁 +软管 +医院 +医院病床 +病房 +主持人 +小旅馆 +热 +热气球 +热狗 +辣椒酱 +温泉 +旅馆 +酒店大堂 +酒店房间 +电炉 +沙漏 +房子 +房子外部 +室内植物 +悬滑板 +吼 +蜷缩 +拥抱 +呼啦圈 +人 +增湿器 +蜂鸟 +座头鲸 +打猎 +狩猎小屋 +障碍 +飓风 +哈士奇 +小屋 +鬣狗 +混合物 +绣球花 +消火栓 +水上飞机 +冰 +冰袋 +北极熊 +冰洞 +冰淇淋 +冰淇淋蛋卷 +冰淇淋商店 +冰块 +浮冰 +冰球运动员 +冰球队 +棒棒糖 +制冰机 +溜冰场 +冰雕 +冰架 +溜冰鞋 +滑冰 +冰山 +冰柱 +糖衣/酥皮 +图标 +身份证照片 +身份证 +冰屋 +光/灯光/光线 +鬣蜥蜴 +照亮 +插图 +形象 +黑斑羚 +熏香 +独立日 +个人 +室内 +划船器 +电磁炉 +工业区 +工业 +步兵 +充气艇 +服务台 +基础设施 +成分 +吸入器 +注射 +受伤 +墨水 +印泥 +小湖湾 +题词 +昆虫 +安装 +乐器/器械 +绝缘杯 +互动 +室内设计 +网站 +十字路口 +面试 +无脊椎动物 +邀请 +平板电脑 +苹果手机 +苹果音乐播放器 +虹膜 +铁 +熨衣板 +灌溉系统 +岛 +小岛 +等足类动物 +象牙 +常青藤 +居酒屋 +千斤顶 +帝王蟹/蟹 +夹克衫 +按摩浴缸 +玉 +美洲虎 +监狱牢房 +果酱 +日式花园 +茉莉花 +下巴 +松鸦 +爵士乐 +爵士乐艺术家 +爵士融合艺术家 +牛仔裤 +吉普车 +果冻 +果冻豆 +水母 +喷气式飞机 +摩托艇 +珠宝 +珠宝 +珠宝店 +拼图游戏 +人力车 +赛马骑师 +赛马帽 +慢跑 +联合的 +记者 +操纵杆 +法官 +水壶 +玩杂耍 +果汁 +榨汁器 +枣子 +跳绳 +连身裤 +丛林 +废品堆放场 +羽衣甘蓝 +万花筒 +袋鼠 +卡拉ok +空手道 +卡丁车运动 +旧城区 +皮船 +烤肉串 +按键/钥匙 +门卡 +卡其色 +踢 +苏格兰裙 +和服 +幼儿园教室 +幼儿园 +国王 +帝王蟹 +亲吻 +工具包 +厨房 +厨房橱柜 +厨房台面 +厨房地板 +厨房抽油烟机 +厨房岛 +厨房水槽 +厨房桌子 +厨房用具 +厨房窗户 +厨房用具 +风筝 +猕猴桃 +护膝 +跪下 +餐刀 +骑手 +编织 +编织针 +球形把手 +门环 +结 +考拉 +锦鲤 +ktv +实验室 +实验室外套 +标签 +拉布拉多 +迷宫 +网眼织物 +蕾丝连衣裙 +梯子 +长柄杓 +瓢虫 +环礁湖 +湖泊 +湖区 +湖边小屋 +湖岸 +羊肉 +羊排 +灯柱 +灯罩 +矛 +土地 +陆地车辆 +废物填埋 +着陆 +降落甲板 +地标 +风景 +山崩 +挂带 +灯笼 +腿/大腿 +笔记本电脑 +笔记本键盘 +幼体 +烤宽面条 +激光 +睫毛 +套索 +门闩 +乳胶 +拿铁咖啡 +笑 +发射 +发布会 +举办会议 +自助洗衣店 +洗衣房 +洗衣篮 +洗衣房 +熔岩 +薰衣草 +草坪 +草坪婚礼 +律师 +躺 +引领 +主唱 +通向 +领袖 +泄漏 +倾斜/倚靠 +学习 +皮带 +皮革 +皮夹克 +皮鞋 +演讲 +演讲厅 +教学室 +窗台 +剩饭 +腿 +传说 +紧身裤/秋裤 +立法院 +乐高 +豆类 +柠檬 +柠檬汁 +柠檬水 +狐猴 +镜头 +眩光 +扁豆 +豹 +紧身连衣裤 +紧身裤袜 +小妖精 +课程 +信函 +信箱 +信的标志 +刻字 +生菜 +水平 +图书馆 +许可证 +车牌 +地衣 +舔 +盖子 +躺着 +安全带 +救生衣 +救生艇 +救生员 +提起 +灯具 +灯光秀 +电灯开关 +照明/照明设备 +闪电 +避雷针 +淡紫色 +百合 +肢体 +石灰 +石灰石 +豪华轿车 +线条 +艺术线条 +排队 +亚麻 +邮轮 +狮子 +润唇膏 +口红 +液体 +酒类商店 +列表 +荔枝 +生活 +家畜 +客厅 +生活空间 +蜥蜴 +负载 +装卸码头 +游手好闲的人 +走廊 +定位 +锁 +闸室 +储物柜 +阁楼 +原木 +小木屋 +标志 +洛基 +长头发 +冲浪板 +隐约显现/织布机 +环状 +遗失 +彩票 +莲花 +爱 +双人沙发 +行李 +木材 +伐木工人 +午餐 +午餐盒 +郁郁葱葱的 +奢侈品 +豪华游艇 +雨衣 +澳洲胡桃 +短尾猿 +通心粉 +金刚鹦鹉 +弯刀 +机器 +机枪 +杂志 +魔法 +魔术师 +磁铁 +放大镜 +木兰花 +喜鹊 +麻将 +象夫 +女仆 +邮件 +邮件槽 +制作 +改造 +化妆师 +化妆工具 +野鸭 +野鸭 +槌棒 +哺乳动物 +猛犸象 +男人 +管理 +经理 +海牛 +曼荼罗 +橘子 +普通话 +鬃毛 +漫画 +食槽 +芒果 +山竹果 +红树林 +曼哈顿 +检修孔 +井盖 +修指甲 +人体模型 +庄园主宅 +大厦 +螳螂 +地幔 +活动房层 +制造业 +手稿 +地图 +枫木 +枫叶 +枫糖浆 +沙球 +马拉松 +大理石 +行进 +行进乐队 +母马 +金盏花 +水兵 +海洋无脊椎动物 +海洋哺乳动物 +木偶 +标志 +集市 +市场广场 +市场摊位 +结婚 +武术 +武术家 +武术馆 +马提尼 +马丁尼酒杯 +睫毛膏 +吉祥物 +土豆泥 +搅碎机 +面具/口罩 +按摩 +桅杆 +地垫 +斗牛士 +比赛 +火柴盒 +衣料 +床垫 +陵墓 +长裙 +一餐 +量杯 +卷尺 +肉类 +肉丸 +机械师 +机械风扇 +奖牌 +媒体 +医疗设备 +医学图像 +医务人员 +医药箱 +中世纪的 +麦地那市 +冥想 +猫鼬 +赛事 +香瓜 +纪念碑 +菜单 +美人鱼 +网 +肮脏 +信使袋 +金属 +金属艺术家 +金属探测器 +计量器 +中层楼 +麦克风 +显微镜 +微波炉 +午夜 +里程碑 +军装 +牛奶 +牛奶罐 +奶茶 +奶昔 +磨坊 +矿井 +矿工 +矿物质 +矿泉水 +迷你 +微缩模型 +面包车 +部长 +小型货车 +薄荷 +薄荷糖 +镜子 +小姐 +投掷物 +任务 +槲寄生 +混合 +搅拌机 +搅拌碗 +混合物 +护城河 +电动踏板车 +模型/模特 +汽车模型 +现代 +现代大厦 +潮湿 +模具 +模具 +鼹鼠 +君主 +钱 +监控器 +和尚 +猴子 +活动扳手 +黑白照片 +独轮脚踏车 +怪物卡车 +月亮 +月饼 +月光 +沼泽 +驼鹿 +拖把 +助力车 +早晨 +晨雾 +晨光 +朝阳 +砂浆 +马赛克 +清真寺 +蚊子 +藓类植物 +汽车旅馆 +蛾 +母亲 +主板 +主题 +动作 +电动机 +摩托车 +摩托车 +摩托车头盔 +摩托车赛车手 +骑摩托车的人 +赛车运动 +土堆 +山 +山地自行车 +山地自行车员 +山地自行车运动 +山地大猩猩 +山湖 +山景观 +山口 +山路 +山脉 +山区河流 +山雪 +山间溪流 +山景城 +山村 +登山者 +登山包 +鼠标/鼠 +鼠标垫 +捕鼠器 +嘴 +漱口水 +移动 +电影海报 +电影票 +割草机 +mp3播放器 +先生 +泥 +松饼 +马克杯 +桑树 +覆盖物 +骡子 +直辖市 +壁画 +肌肉 +肌肉车 +博物馆 +蘑菇 +音乐 +音乐节 +音乐凳子 +音乐工作室 +音乐录影带表演者 +音乐键盘 +音乐家 +贻贝 +芥末 +神话 +烤干酪辣味玉米片 +指甲油 +指甲锉 +保姆 +餐巾 +狭窄的 +国旗 +基督诞生的场景 +自然历史博物馆 +自然 +自然保护区 +导航 +九夜节 +海军 +星云 +脖子 +围颈带/领口 +项链 +领口 +花蜜 +油桃 +针状物 +邻居 +与某处邻近的地区 +霓虹灯 +霓虹灯 +神经 +巢 +新年 +新生的 +纽芬兰 +新婚 +新闻 +记者招待会 +报摊 +晚上 +夜市 +夜空 +夜景 +夜总会 +床头柜 +面条 +鼻子 +鼻羁 +注解 +笔记本 +记事本 +信纸 +公告 +数字图标 +修女 +护士 +托儿所 +养老院 +螺母 +胡桃夹子 +橡木 +橡树 +桨 +绿洲 +烘干室 +燕麦片 +燕麦 +方尖塔 +观察塔 +天文台 +超越障碍训练场 +海洋 +章鱼 +提供 +办公室 +办公大楼 +办公椅 +办公室隔间 +办公桌 +办公用品 +办公室的窗户 +军官 +行政官员 +石油 +油灯 +油画 +石油钻台 +秋葵 +老照片 +橄榄 +橄榄油 +橄榄树 +煎蛋卷 +洋葱 +洋葱圈 +蛋白石 +开阔的/张开 +开始 +开幕式 +歌剧 +歌剧院 +操作 +手术室 +操作 +眼镜店 +猩猩 +橙子/橙色 +橙汁 +橙树 +橘园 +轨道 +果园 +乐池 +兰花 +订单 +组织 +折纸 +点缀 +鱼鹰 +鸵鸟 +水獭 +外面的 +露头 +户外 +厕所 +电源插头 +大纲 +椭圆形 +烤箱 +整体 +大衣 +天桥 +猫头鹰 +牡蛎 +橡皮环 +包裹 +包/包装/包裹 +围场 +警车 +挂锁 +肉菜饭 +宝塔 +疼痛 +油漆刷 +画家 +佩斯利印花大手帕 +宫殿 +调色板 +栅栏 +棺罩 +棕榈树 +平底锅 +煎饼 +熊猫 +面板 +全景 +三色堇 +喘息 +储藏室 +裤子 +连裤袜 +木瓜 +纸 +纸袋 +切纸机 +纸灯笼 +纸盘子 +纸巾 +平装书 +压纸器 +降落伞 +游行 +天堂 +鹦鹉 +护理人员 +长尾小鹦鹉 +滑翔伞 +伞兵 +羊皮纸 +教区 +公园 +公园长椅 +停车 +停车场 +停车费 +停车标志 +议会 +欧芹/香菜 +参与者 +合作伙伴 +帕特里奇 +聚会 +派对帽 +通过 +通道 +存折 +乘客 +客船 +旅客列车 +百香果 +护照 +面食 +粘贴 +糕点 +牧场 +补丁 +病人 +图案/款式 +人行道/硬路面 +大帐篷 +爪子 +支付 +付费电话 +豌豆 +和平 +桃子 +孔雀 +山峰/尖顶 +花生 +花生酱 +梨 +珍珠 +卵石 +山核桃 +行人 +人行天桥 +步行街 +果皮 +削皮器 +小钉板 +木质腿 +鹈鹕 +笔/围栏 +点球 +铅笔 +铅笔盒 +卷笔刀 +铅笔裙 +吊坠 +钟摆 +企鹅 +半岛 +锦标旗 +便士 +储蓄罐 +牡丹 +胡椒/辣椒 +胡椒研磨机 +胡椒子 +意大利辣香肠 +栖息/鲈鱼 +表演 +表演 +表演舞台 +香水 +绿廊 +波斯猫 +柿子 +个人护理 +个人漂浮装置 +害虫 +宠物 +宠物店 +宠物店 +花瓣 +佩妮 +教堂的长椅 +野鸡 +现象 +哲学家 +电话 +电话簿 +留声机 +照片 +照相亭 +相框 +摄影 +物理学家 +物理实验室 +钢琴家 +钢琴 +选择 +捡起 +泡菜 +野餐 +野餐区 +野餐篮 +野餐桌 +图片 +相框 +馅饼 +鸽子 +朝圣者 +药片 +枕头 +飞行员 +领航艇 +别针 +松树 +松果 +松林 +松子 +菠萝 +乒乓球桌 +乒乓球 +粉色 +一品脱的量 +琵琶 +管子 +管碗 +海盗 +海盗旗 +海盗船 +阿月浑子 +滑雪场 +口袋里的面包 +火龙果 +斗牛犬 +球场 +大水罐 +猪笼草 +干草叉 +披萨 +披萨刀 +比萨锅 +披萨店 +招牌 +地方 +餐具垫 +格子 +平原 +示意图 +行星 +行星地球 +厚木板 +植物 +种植园 +种植 +匾额 +石膏 +塑料 +橡皮泥 +高原 +平台 +白金 +大浅盘 +玩/演奏/运动 +打羽毛球 +打棒球 +打篮球 +玩台球 +踢足球 +玩乒乓球 +打网球 +打排球 +选手/运动员 +操场 +剧场 +扑克牌 +下棋 +打高尔夫球 +打麻将 +运动场 +护栏 +游戏室 +广场 +钳子 +故事情节 +犁 +插头 +插头帽 +李子 +水管工 +卫生洁具 +羽毛 +夹板 +口袋 +怀表 +随身小折刀 +圆荚体 +乐队指挥台 +诗歌 +一品红 +指/朝向 +指针 +扑克卡 +筹码 +扑克表 +杆/柱 +臭猫 +警察 +警车 +警犬 +警察局 +政治家 +圆点 +花粉 +污染 +马球 +马球领 +马球衬衫 +石榴 +波美拉尼亚的 +雨披 +池塘 +马尾辫 +贵宾犬 +池 +流行 +流行艺术家 +爆米花 +教皇 +罂粟 +瓷 +玄关 +猪肉 +粥 +便携式电池 +门户网站 +投资组合 +汽门 +肖像 +肖像会话 +摆姿势拍照 +负鼠 +帖子 +邮局 +邮票 +明信片 +海报 +海报页 +锅/罐/陶盆 +土豆 +土豆片 +土豆沙拉 +布垫子 +便壶 +袋 +家禽 +英镑 +倾泻 +粉末 +电源线 +电源插头及插座 +权力看 +电站 +练习 +布拉格城堡 +祈祷 +牧师 +首映 +处方 +显示 +演讲 +总统 +新闻发布室 +高压锅 +椒盐卷饼 +王子 +公主 +打印 +打印页面 +打印机 +印刷 +监狱 +农产品/生产 +产品 +职业 +专业的 +教授 +项目图片 +投影屏幕 +投影仪 +毕业舞会 +散步 +螺旋桨 +先知 +建议 +防护服 +抗议 +抗议者 +出版 +宣传画像 +冰上曲棍球 +布丁 +水坑 +泡芙 +角嘴海雀 +哈巴狗 +拉 +讲坛 +脉冲 +泵 +南瓜 +南瓜饼 +南瓜种子 +拳击吊袋 +拳头猛击/穿孔 +学生 +紫色 +推 +轻轻一击 +谜题 +塔 +金字塔 +大蟒 +二维码 +鹌鹑 +采石场 +季度 +石英 +女王 +油炸玉米粉饼 +队列 +乳蛋饼 +被子 +绗缝 +引用 +兔子 +浣熊 +比赛 +赛道 +水沟/跑道 +赛车 +球拍 +雷达 +散热器 +广播 +木筏/橡皮艇 +布娃娃 +栏杆/铁轨 +轨道车 +铁道 +铁路桥梁 +轨道线 +火车站 +雨 +雨靴 +彩虹 +虹鳟鱼 +雨衣 +热带雨林 +多雨的 +葡萄干 +耙子 +公羊 +斜坡 +油菜籽 +快速 +说唱歌手 +树莓 +老鼠 +棘轮 +乌鸦 +峡谷 +雷 +剃须刀 +锋利的 +阅读 +阅读材料 +钻孔器 +后面 +尾灯 +后视图 +后视镜 +收据 +收到 +接待 +配方 +记录 +唱片制作人 +记录器/竖笛 +录音室 +娱乐室 +休闲车 +矩形 +回收 +回收站 +红色 +红地毯 +红旗 +红熊猫 +红酒 +红木 +芦苇 +礁石 +卷轴 +裁判 +倒影 +倒影 +反射器 +注册 +控制 +驯鹿 +放松 +释放 +救援 +宗教 +宗教的 +享受 +保持 +改造 +遥控器 +移除 +修复 +维修店 +爬行动物 +救援 +救助者 +研究 +研究员 +储层 +住宅 +居民区 +树脂 +度假胜地 +度假小镇 +餐厅的厨房 +餐厅的露台 +厕所 +零售 +寻回犬 +制动火箭 +揭示 +犀牛 +杜鹃 +肋骨 +丝带 +大米 +电饭煲 +稻田 +骑/搭乘 +脊 +骑马 +步枪 +边缘 +环/戒指 +暴乱 +涟漪 +上升 +高层建筑 +河 +河岸 +河船 +河谷 +河床 +路 +路标 +公路旅行 +路边 +烤鸡 +长袍 +罗宾 +机器人 +石头 +岩石拱 +摇滚艺术家 +摇滚乐队 +攀岩者 +攀岩 +摇滚音乐会 +岩石表面 +岩层 +摇滚歌手 +火箭 +摇椅 +岩石 +啮齿动物 +牛仔竞技表演 +竞技舞台 +罗伊 +狍子 +辊 +过山车 +轮式溜冰鞋 +溜冰鞋 +擀面杖 +浪漫 +浪漫的 +屋顶 +屋顶花园 +房间 +房间分频器 +根 +根啤酒 +绳索桥 +念珠 +玫瑰 +迷迭香 +玫瑰色的云 +罗特韦尔犬 +圆桌 +路由器 +行 +罗文 +皇家 +橡皮图章 +废墟 +魔方 +红宝石 +莱夫 +橄榄球 +橄榄球 +橄榄球运动员 +毁坏 +尺 +朗姆酒 +跑 +跑步者 +跑步鞋 +农村的 +锈 +乡村的 +黑麦 +袋 +鞍 +鞍囊 +旅行 +安全 +安全背心 +圣人 +帆 +帆船 +航行 +水手 +松鼠猴 +缘故 +沙拉 +沙拉碗 +火蜥蜴 +意大利蒜味腊肠 +出售 +三文鱼 +沙龙 +萨尔萨舞 +盐 +盐和胡椒瓶 +盐湖 +盐沼 +盐瓶 +敬礼 +萨莫耶德人 +武士 +沙子 +沙洲 +砂箱 +沙堡 +沙雕 +凉鞋 +三明治 +卫生巾 +圣诞老人 +蓝宝石 +沙丁鱼 +莎丽 +生鱼片 +沙爹 +书包 +卫星 +缎 +酱汁 +碟子 +桑拿 +香肠 +稀树大草原 +锯 +锯木架 +萨克斯管 +萨克斯手 +脚手架 +秤/标尺 +比例模型 +扇贝 +疤痕 +稻草人 +围巾 +场景 +风景 +雪纳瑞犬 +学校 +校车 +校服 +校舍 +纵帆船 +科学 +科幻电影 +科学博物馆 +科学家 +剪刀 +壁灯 +司康饼 +勺子 +踏板车/摩托车 +分数 +记分板 +蝎子 +童子军 +炒蛋 +废弃 +刮板 +刮伤 +屏幕 +纱门 +截图 +螺杆 +螺丝刀 +长卷纸/卷轴 +擦洗 +硬毛刷 +雕塑家 +雕塑 +海洞穴 +海冰 +海狮 +海龟 +海胆 +尖吻鲈 +海底 +海鸟 +海鲜 +海马 +海豹 +海景 +海贝 +海滨度假胜地 +季节 +座位 +安全带 +海藻 +秘书 +安全 +小轿车 +看到 +种子 +跷跷板 +赛格威 +自拍 +出售 +研讨会 +感觉 +传感器 +服务器 +服务器机房 +服务 +集 +缝纫机 +影子 +摇 +瓶 +洗发水 +形状 +分享 +鲨鱼 +卷笔刀 +记号笔 +剃须刀 +剃须膏 +披肩/围巾 +剪切 +剪刀 +羊 +床单 +乐谱 +架子 +贝壳 +贝类 +避难所 +搁置 +牧羊人 +果子露 +柴犬 +发光 +航运 +集装箱 +海难 +船厂 +衬衫 +赤膊的 +浅滩 +鞋 +鞋盒 +鞋店 +鞋楦 +射击 +得分篮球后卫 +商店橱窗 +门面 +购物者 +购物 +购物袋 +购物篮 +购物车 +购物中心 +购物街 +海岸 +海岸线 +短的 +短发 +短裤 +小酒杯 +散弹枪 +肩膀 +单肩包 +铲 +陈列柜 +淋浴 +浴帽 +浴帘 +淋浴门 +淋浴头 +碎纸机 +泼妇 +虾 +神社 +灌木 +快门 +暹罗猫 +西伯利亚 +兄弟姐妹 +侧面 +边柜 +配菜 +边车 +边线 +壁板 +标志 +指示牌 +信号 +签名 +丝绸 +丝袜 +筒仓 +银 +银牌 +银器 +唱歌 +烧焦 +歌手 +水槽 +啜 +坐/放置/坐落 +坐着 +滑板公园 +滑板 +滑板者 +溜冰者 +溜冰场 +骨架 +草图 +串串 +滑雪 +滑雪靴 +滑雪设备 +滑雪服 +滑雪缆车 +滑雪杖 +滑雪胜地 +滑雪板 +滑雪 +滑雪鞋 +皮肤 +头骨 +无边便帽 +天空 +天空塔 +天窗 +天际线 +摩天大楼 +激流回旋 +石板 +雪橇 +睡眠 +睡袋 +睡衣 +袖子 +片 +滑动 +滑块 +吊索 +坡 +投币口 +老虎机 +树懒 +慢炖锅 +鼻涕虫 +贫民窟 +气味 +微笑 +烟雾/抽烟 +零食 +蜗牛 +蛇 +鲷鱼 +快照 +通气管 +鼻子 +雪 +雪豹 +雪山 +雪球 +单板滑雪者 +雪原 +雪花 +雪人 +雪地摩托 +雪犁 +雪鞋 +雪 +肥皂 +肥皂泡 +给皂器 +足球守门员 +社会名流 +短袜 +插座 +苏打水 +垒球 +软件 +太阳能电池阵列 +士兵 +独奏 +解决方案 +宽边帽 +歌曲 +声音 +汤 +汤碗 +汤匙 +酸奶油 +纪念品 +豆浆 +水疗中心 +空间 +航天飞机 +空间站 +宇宙飞船 +意大利面 +横跨 +扳手 +火花 +闪耀 +烟火 +起泡葡萄酒 +麻雀 +抹刀 +扬声器 +观众 +会话框 +速度限制 +限速标志 +快艇 +车速表 +球 +香料 +调料架 +蜘蛛 +蜘蛛网 +扣球 +旋转 +菠菜 +尖塔 +飞溅 +海绵 +勺子 +体育协会 +运动器材 +运动团队 +体育球 +体育器材 +运动会 +运动服装 +点 +喷雾 +伸展 +春天 +春卷 +撒 +洒水器 +发芽 +云杉 +云杉森林 +队 +广场 +南瓜 +蹲 +挤 +鱿鱼 +松鼠 +水枪 +刺 +稳定的 +(码放整齐的)一叠 +体育场 +工作人员 +舞台 +舞台灯 +驿马车 +弄脏 +不锈钢 +楼梯 +楼梯 +楼梯间 +摊位/小隔间 +种马 +站/矗立/摊位 +站 +主食 +订书机 +星星 +盯着 +海星 +杨桃 +燕八哥 +州立公园 +公立学校 +车站 +固定自行车 +文具 +雕像 +牛排 +牛排刀 +蒸汽 +蒸汽机 +蒸汽机车 +蒸汽火车 +馒头 +钢 +方向盘 +(花草的)茎 +模版 +梯凳 +立体声 +听诊器 +炖 +戳/条状物 +竹节虫 +贴纸 +静物画 +高跷 +黄貂鱼 +搅拌 +搅拌器 +镫 +缝 +股票 +长筒袜 +腹部 +石头建筑 +石雕 +石屋 +石磨 +凳子 +停止 +停在 +红灯 +停车标志 +秒表 +红绿灯 +存储箱 +储藏室 +罐/蓄水池 +商店 +店面 +鹳 +风暴 +暴风云 +狂风暴雨的 +炉子 +扑克 +跨骑 +过滤器 +海峡 +带 +稻草/吸管 +草帽 +草莓 +溪流 +街头艺术 +街头艺术家 +街角 +流浪狗 +街头食品 +路灯 +街市场 +街头摄影 +街景 +路标 +街头小贩 +拉伸 +担架 +罢工 +前锋 +细绳 +芝士条 +带子 +条纹 +漫步 +结构 +工作室 +影棚拍摄 +材料 +填充玩具动物 +毛绒玩具 +馅 +树桩 +惊人的 +特技 +佛塔 +风格 +手写笔 +潜艇 +潜艇形大三明治 +海底水 +郊区 +地铁 +地铁站 +低音炮 +多肉 +绒面革 +糖 +糖碗 +甘蔗 +方糖 +西装 +套房 +夏天 +夏天傍晚 +峰顶 +太阳 +太阳帽 +日光浴 +周日 +日晷 +向日葵 +向日葵田 +葵花籽 +太阳镜 +晴天 +日出 +日落 +遮阳伞 +阳光 +超级碗 +跑车 +超级英雄 +超市 +超市货架 +超模 +支持者 +冲浪 +表面 +冲浪板 +冲浪者 +外科医生 +外科手术 +环绕 +寿司 +寿司吧 +背带裤 +悬架 +吊桥 +越野车 +燕子 +燕尾蝶 +沼泽 +天鹅 +天鹅游艇 +运动裤 +防汗带 +毛衣 +运动衫 +甜的 +红薯 +游泳 +泳帽 +游泳者 +游泳洞 +游泳池 +摆动 +平转桥 +秋千 +漩涡 +开关 +转椅 +剑 +旗鱼 +象征 +对称 +犹太教堂 +注射器 +糖浆 +系统 +t恤 +t恤 +塔巴斯科辣椒酱 +虎斑 +乒乓球拍 +桌面 +桌布 +平板电脑 +餐具 +转速表 +拦截 +墨西哥煎玉米卷 +跆拳道 +太极 +尾巴 +裁缝 +拍/拿 +起飞 +说话/交谈/演讲 +手鼓 +棕褐色 +橘子 +胶带/磁带/终点线 +挂毯 +沥青碎石路面 +芋头 +篷布 +果馅饼 +流苏 +味道 +榻榻米 +纹身 +纹身艺术家 +酒馆 +茶 +茶包 +茶话会 +茶园 +茶壶 +茶具 +教 +老师 +茶杯 +水鸭 +团队合影 +团队介绍 +眼泪/撕裂/划破 +技术员 +技术 +泰迪熊 +T字形物 +青少年 +电线杆 +变焦镜头 +望远镜 +电视 +电视摄像机 +电视室 +电视演播室 +温度 +寺庙 +天妇罗 +网球 +网球场 +网球比赛 +网球网 +网球运动员 +网球拍 +帐篷 +龙舌兰酒 +终端/航站楼 +阳台 +地形 +玻璃容器 +领土 +测试 +测试赛 +试管 +文本 +短信 +纺织 +纹理 +感恩节 +感恩节晚餐 +剧院 +戏剧演员 +治疗 +温度计 +热水瓶 +暖瓶 +恒温器 +灌木丛 +顶针 +东西 +思考 +蓟 +宝座 +金銮殿 +扔 +抱枕 +雷 +雷雨 +百里香 +皇冠 +记号 +票 +售票亭 +潮池 +领带 +老虎 +紧 +瓦 +瓷砖地板 +瓦屋顶 +瓷砖墙 +锡 +锡纸 +箔 +提拉米苏 +轮胎 +纸巾 +烤面包 +烤面包机 +烟草 +烟斗 +学步的小孩 +脚趾 +豆腐 +马桶 +马桶座圈 +化妆包 +东京铁塔 +番茄 +番茄酱 +番茄汤 +墓 +钳子 +钳子 +工具 +工具箱 +牙刷 +牙膏 +牙签 +修剪成形的花园 +配料 +火炬/光源 +龙卷风 +玉米粉圆饼 +乌龟 +大手提袋 +图腾柱 +龙猫 +巨嘴鸟 +触摸 +触地 +旅行 +旅游巴士 +导游 +游客 +旅游景点 +锦标赛 +拖车 +毛巾 +毛巾杆 +大厦 +塔桥 +小镇 +城镇广场 +玩具 +玩具车 +玩具枪 +玩具店 +跑道 +拖拉机 +贸易 +传统 +传统的 +交通 +锥形交通路标 +交通拥堵 +交通堵塞 +交通标志 +小道 +预告片 +拖车 +火车 +火车桥 +火车车厢 +火车内部 +火车轨道 +火车窗口 +教练 +训练 +训练长椅 +训练场 +电车/手推车 +蹦床 +变形金刚 +透明度 +旅行 +托盘/碟子 +跑步机 +美食 +树 +树枝 +林场 +树蛙 +树屋 +树根 +树干 +试验 +三角形 +铁人三项 +部落 +支流 +戏法/特技 +三轮车 +修剪 +三人组 +三脚架 +长号 +部队 +奖杯 +奖杯 +热带 +鳟鱼 +卡车 +卡车司机 +浴缸 +管子 +拖船 +郁金香 +金枪鱼 +苔原 +隧道 +涡轮 +火鸡 +转动 +芜菁 +绿松石 +炮塔 +乌龟 +獠牙 +电视演员 +电视柜 +电视剧 +电视节目类型 +电视名人 +电视节目 +情景喜剧 +电视塔 +枝条 +黄昏 +双胞胎 +麻线 +扭 +类型 +键入 +打字机 +尤克里里 +奥特曼 +伞 +内衣 +水下 +独角兽 +制服 +宇宙 +大学 +向上 +城市 +尿壶 +瓮 +使用 +用具 +杂物间 +吸尘器/真空 +谷 +阀门 +吸血鬼 +货车 +香草 +虚荣 +种类 +花瓶/瓶 +金库 +矢量卡通插图 +矢量图标 +蔬菜 +菜园 +蔬菜市场 +植被 +车辆 +面纱 +静脉 +天鹅绒 +自动售货机 +小贩 +通风孔 +胡蜂属 +船 +背心 +兽医 +经验丰富的 +兽医办公室 +高架桥 +视频 +摄像机 +电子游戏 +录像带 +视镜 +守夜 +别墅 +村庄 +藤蔓 +醋 +葡萄园 +暴力 +紫罗兰色 +小提琴 +小提琴家 +中提琴演奏者 +愿景 +遮阳板 +伏特加 +火山 +排球 +排球场 +排球运动员 +志愿者 +航行 +秃鹰 +华夫饼干 +华夫饼机 +货车 +马车车轮 +腰 +服务员 +候机室 +等候室 +走 +步行 +手杖 +挂钟 +壁纸 +核桃 +海象 +战争 +仓库 +温暖的 +警告标志 +战士 +军舰 +疣猪 +洗 +洗衣机/垫圈 +洗 +洗衣机 +黄蜂 +浪费 +废物容器 +手表 +水 +水鸟 +水牛 +水冷却器 +水滴 +水景 +热水器 +水位 +荷花 +水上乐园 +水管 +净水器 +滑水板 +水上运动 +水面 +水塔 +水彩 +水彩插图 +水彩画 +瀑布 +喷壶 +水印叠加图章 +西瓜 +防水外套 +水路 +波浪 +蜡 +武器 +穿着 +天气 +叶片 +网 +摄像头 +婚礼 +结婚戒指 +婚礼花束 +结婚蛋糕 +新婚夫妇 +婚礼请柬 +婚礼派对 +婚纱照 +婚礼摄影师 +婚纱摄影 +婚宴 +楔 +杂草 +重量 +体重秤 +焊接工 +井 +西餐 +西餐厅 +湿 +吧台 +潜水衣 +湿地 +潜水服 +鲸鱼 +鲸鲨 +小麦 +麦田 +车轮 +轮椅 +后轮支撑车技 +生奶油 +搅拌器 +胡须 +威士忌 +哨子 +白色 +白宫 +白葡萄酒 +白板 +便门 +宽的 +挥动 +假发 +Wii +Wii手柄 +荒野 +角马 +野火 +野花 +野生动物 +柳树 +风 +风铃 +风电场 +风力涡轮机 +风车 +窗户 +窗台花盆箱 +橱窗展示 +窗框 +纱窗 +靠窗的座位 +窗台 +雨刮器 +挡风玻璃 +有风的 +酒瓶 +冷酒器 +酒柜 +酒窖 +酒杯 +酒架 +品酒 +酒庄 +翅膀 +冬天 +冬瓜 +冬天的早晨 +冬季场景 +冬季运动 +冬季风暴 +电线 +紫藤 +巫婆 +女巫帽子 +炒锅 +狼 +女人 +木头 +林鸳鸯 +木地板 +木墙 +烧木炉 +木匙 +林地 +啄木鸟 +木工刨 +羊毛 +工作 +练习卡 +工作台 +工人 +工作场所 +车间 +世界 +蠕虫 +敬拜 +伤口 +包 +裹身裙 +包装纸 +搏斗 +摔跤手 +皱纹 +腕带 +写 +作家 +手写/字迹 +毛笔 +写字桌 +游艇 +牦牛 +院子 +黄色 +瑜伽 +瑜伽垫 +酸奶 +轭 +蛋黄 +青年 +青年旅馆 +蒙古包 +斑马 +斑马线 +禅意花园 +拉链 +拉链 +僵尸 +粽子 +动物园 diff --git a/ram/data/ram_tag_list_threshold.txt b/ram/data/ram_tag_list_threshold.txt new file mode 100644 index 0000000000000000000000000000000000000000..0472b23c25903900c0dde68fffc9a6a6755f5117 --- /dev/null +++ b/ram/data/ram_tag_list_threshold.txt @@ -0,0 +1,4585 @@ +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.71 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.61 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.89 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.78 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.9 +0.65 +0.83 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.79 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.79 +0.65 +0.63 +0.65 +0.87 +0.8 +0.46 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.84 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.81 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.87 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.68 +0.65 +0.8 +0.65 +0.65 +0.75 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.79 +0.65 +0.85 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.8 +0.65 +0.65 +0.65 +0.76 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.9 +0.65 +0.85 +0.8 +0.8 +0.8 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.71 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.71 +0.65 +0.8 +0.76 +0.85 +0.8 +0.65 +0.65 +0.8 +0.65 +0.79 +0.65 +0.75 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.65 +0.8 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.6 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.65 +0.67 +0.65 +0.65 +0.8 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.84 +0.8 +0.8 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.89 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.65 +0.65 +0.6 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.65 +0.66 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.8 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.75 +0.65 +0.7 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.82 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.87 +0.65 +0.66 +0.65 +0.84 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.5 +0.65 +0.64 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.85 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.82 +0.8 +0.65 +0.65 +0.65 +0.84 +0.9 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.64 +0.65 +0.65 +0.65 +0.8 +0.8 +0.87 +0.65 +0.65 +0.78 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.74 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.89 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.84 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.81 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.73 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.89 +0.8 +0.65 +0.9 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.89 +0.89 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.88 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.9 +0.57 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.8 +0.79 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.89 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.81 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.84 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.8 +0.83 +0.65 +0.65 +0.8 +0.65 +0.65 +0.72 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.69 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.9 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.85 +0.65 +0.65 +0.8 +0.65 +0.89 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.75 +0.8 +0.65 +0.8 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.92 +0.89 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.83 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.8 +0.82 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.64 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.76 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.87 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.89 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.9 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.69 +0.65 +0.65 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.72 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.9 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.45 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.51 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.66 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.66 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.65 +0.85 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.79 +0.75 +0.65 +0.65 +0.8 +0.65 +0.67 +0.8 +0.8 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.8 +0.65 +0.65 +0.9 +0.65 +0.79 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.74 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.6 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.89 +0.8 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.9 +0.75 +0.65 +0.65 +0.65 +0.8 +0.6 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.85 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.63 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.8 +0.65 +0.81 +0.8 +0.8 +0.8 +0.82 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.88 +0.65 +0.8 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +1 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.74 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.86 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.64 +0.65 +0.65 +0.8 +0.8 +0.65 +0.87 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.8 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.71 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.73 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.75 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.88 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.8 +0.65 +0.65 +0.9 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.7 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.57 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.76 +1 +0.8 +0.65 +0.65 +0.58 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +1 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.9 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.68 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.99 +0.8 +0.77 +0.65 +0.9 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.8 +0.65 +0.7 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.77 +0.65 +0.65 +0.65 +0.65 +0.79 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.52 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.86 +0.65 +0.65 +0.8 +0.56 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.72 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.8 +0.6 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.89 +0.85 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.75 +0.65 +0.65 +0.65 +0.65 +0.54 +1 +0.65 +0.65 +0.75 +0.65 +0.75 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.9 +0.62 +0.65 +0.65 +0.65 +0.65 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.82 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.74 +0.8 +0.65 +0.8 +0.8 +0.7 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.8 +0.8 +0.84 +0.8 +0.65 +0.65 +0.8 +0.75 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.84 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.8 +0.65 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.74 +0.65 +0.8 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.9 +0.9 +0.65 +0.65 +0.65 +0.63 +0.82 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.65 +0.74 +0.9 +0.65 +0.8 +0.65 +0.65 +0.58 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.75 +0.65 +0.65 +0.8 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.87 +0.65 +0.65 +0.65 +0.8 +0.65 +0.64 +0.65 +0.65 +0.65 +0.8 +0.87 +0.65 +0.65 +0.8 +0.9 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.83 +0.65 +0.65 +0.8 +0.65 +0.9 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.78 +0.65 +0.8 +0.65 +0.9 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.9 +0.65 +0.88 +0.8 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.77 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +1 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.85 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.88 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.65 +0.65 +0.65 +0.65 +0.68 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.9 +0.65 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.81 +0.65 +0.65 +0.65 +0.8 +0.85 +0.65 +0.77 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.8 +0.8 +0.9 +0.65 +0.65 +0.89 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.8 +0.65 +0.65 +0.65 +0.88 +0.8 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.82 +0.65 +0.8 +0.74 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.85 +0.65 +0.65 +0.85 +0.65 +0.65 +0.65 +0.65 +0.7 +0.7 +0.8 +0.65 +0.65 +0.65 +0.65 +0.87 +0.8 +0.65 +0.65 +0.65 +0.89 +0.85 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.7 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.9 +0.8 +0.8 +0.65 +0.66 +0.57 +0.65 +0.65 +0.65 +0.49 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.65 +0.65 +0.65 +0.8 +0.65 +0.8 +0.8 +0.86 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.89 +0.65 +0.65 +0.65 +0.65 +0.65 +0.65 +0.76 diff --git a/ram/data/tag_list.txt b/ram/data/tag_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..11a61b68fb9a22eec9cc52a3a2d32474323aafdb --- /dev/null +++ b/ram/data/tag_list.txt @@ -0,0 +1,3429 @@ +tennis +bear cub +observatory +bicycle +hillside +judge +watercolor illustration +granite +lobster +livery +stone +ceramic +ranch +cloth +smile +building +tattoo +cricketer +cheek +pear +source +winter +surface +spray +ceremony +magic +curve +container +fair +medicine +baby +tennis racquet +ornament +bamboo +duckling +song +safari +team presentation +daffodil +cross +toothpaste +shield +fashion model +capsule +map +creek +glass house +glass plate +siding +corner +water buffalo +bison +figure skater +diploma +tire +race +cable car +brain +gas stove +soap bubble +palette +snowboard +school child +trench coat +monk +fiber +kitchen window +sunglass +coffee +security +strawberry +penguin +tree root +loaf +engagement ring +lamb +vector cartoon illustration +sandwich +mountain village +shape +charm +fiction +knot +greenhouse +sushi +text +disaster +trophy +gang +strap +soccer game +cardinal +tee +turtle +water surface +grassland +dolphin +store +dirt +iceberg +pergola +farmer market +publicity portrait +tote bag +teenage girl +view mirror +session +commuter +dressing room +tricycle +christmas ball +headlight +police +armchair +chart +yacht +saw +printer +rock band +gingerbread house +tag +table lamp +hockey game +slope +font +wicker basket +jewelry +quarter +software +weapon +pin +worship +painter +goal +morning light +bike +baseball bat +elevator +cuisine +sausage +stunt +wrestler +statue +landing +pillar +willow tree +sea wave +chicken +peanut +muscle +bob +tv genre +bathroom window +radish +textile +pelican +marketplace +crest +elevation map +gift +parish +traffic light +campfire +fog +award winner +beach ball +mat +white house +plaster +moped +football team +solution +bicyclist +bit +playground +darkness +cake +maple leave +mold +cracker +blueberry +rubble +container ship +pedestrian bridge +snail +parrot +form +circuit +highlight +pickup truck +koala +rain +system +weather +raincoat +soccer team +windshield +thunderstorm +mike +bird house +bridge +grandfather +restroom +animation +wilderness +clown +banana +brown +braid +dining room +kindergarten +launch event +purple +school +stairwell +brooch +movie poster image +mountain river +shelf +wicket +headboard +buddha +flower field +dugout +cd +bald eagle +lagoon +seaweed +agriculture +emergency service +maple tree +parachute +continent +amusement park +remote +bun +tackle +hospital +garage door +birthday party +friendship +go +mausoleum +jeep +raccoon +step +ice hockey team +cigarette +lace dress +forest floor +mall +captain +milk +golf course +meal +picnic table +sail +volleyball +canal +terrace +computer desk +caravan +hotel +cheerleader +nurse +museum +marsh +fox +plateau +night +twin +letter logo +autumn tree +powder +convention +creature +lighthouse +shop window +jacket +stork +taxi +trade +blackboard +olive +road sign +resort +snowflake +cemetery +travel +evening dress +picnic +drink +winter morning +football player +snack +boxing glove +dinner party +airline +swing +port +wheelbarrow +bathroom sink +sweater +ambulance +gear +oil +wii controller +array +home office +car show +mixture +profession +tree frog +square +facility +coral reef +sea wall +pizza +exhibit +demolition +trout +ring +coffee shop +bracelet +bean +lip +fencing +landscape +sitting +package +metal +bust +king +hair +window seat +wildlife +trunk +greenery +stencil +fire hydrant +bridesmaid +plaza +alps +tower bridge +crop top +crossing +cinema +pedestrian crossing +family +shopping cart +stomach +church building +screen door +skater +soccer field +kettle +mussel +raindrop +candy cane +water lily +flower girl +desert +enclosure +christmas light +kitchen +caterpillar +plaid +bath +bush +mud +ballet +knee +adult +raft +sea view +cactus +office chair +overall +rim +scaffolding +pig +cover +poster page +sprinkle +chandelier +algae +traffic +surfboard +book +filming +flash +mansion +camouflage +trouser +ticket +weed +cab +trench +elephant +huddle +sphere +christmas decoration +city +launch +doll +christmas ornament +fabric +bikini +biplane +breakfast +neighbourhood +race track +foliage +avocado +school bus +footwear +highway +ocean view +art vector illustration +wall clock +curtain +teenager +kitchen area +robot +tusk +lounge chair +beam +paddle +camel +lid +world map +city view +newlywed +cargo ship +yellow +exhibition +bend +novel +wool +ontario +bread +campus +coastline +cutting board +booth +table top +carpet +beach chair +workout +street food +fun +costumer film designer +gadget +artist +fishing village +builder +violinist +iphone +spider web +traffic sign +ruin +rescue +clipboard +seal +film director +paw +nursery +intersection +tomato sauce +taste +paddy field +christmas tree +wave +stool +watering can +rug +daytime +subway station +craft +pine forest +black +planet +motif +christmas market +glass window +college +wheat +damage +rectangle +picture frame +chess +guest room +street corner +religion +seed +puzzle +freeway +beauty +ocean +watch +mother +garage +quote +dj +supporter +hip hop artist +muffin +eiffel tower +cash +firefighter +cauliflower +bunker +sled +manicure +shark +stall +jungle +family home +tour bus +chimney +touchdown +roundabout +coyote +street scene +tank +wedding dress +mantle +bedroom window +coconut +chapel +goat +living space +rock wall +polka dot +railway +mandala +mango +lesson +mountain landscape +team photo +bookshelf +meter +bulldog +evening sun +stick +card +pink +fish pond +paint +pill +cart +pea +van +album +football college game +mountain pass +doughnut +ski slope +match +official +shadow +organ +celebration +coin +log cabin +firework display +present +twig +chef +confetti +footpath +tour +ponytail +artwork +race car +club +season +hose +pencil +aircraft +rock formation +wardrobe +participant +politician +engineer +peace +filter +sailing boat +water bottle +service dog +poodle +loki +statesman +sleeping bag +outskirt +clock +factory +oak tree +physician +color +room +stairway +company +lady +graph +faucet +tablecloth +subway train +chocolate chip cookie +headquarters +screw +goggle +halloween +city street +swirl +cord +forward +bone +bedding +archway +wig +lobby +mask +attic +kitchen table +skylight +fire +exit +oil painting +passenger +meditation +salmon +fedora +rubber stamp +orange juice +arch +scientist +stroll +manhattan +float +baseball uniform +circle +church +decker bus +competitor +zoo +basketball team +tourist +daughter +silverware +ceiling fan +birth +vase +jack +mushroom +spiral +cage +limb +salad +ad +control +earth +party +bolt +tractor +barley +wedding photo +hawk +warehouse +vegetable garden +chocolate cake +cabbage +floor window +baby shower +magnifying glass +table +stethoscope +reading +mission +croissant +gift box +rocket +forest road +cooking +suite +hill country +motorcycle +baseball player +angle +drug +sport association +championship +family portrait +florist +softball +egret +office +plywood +jockey +mosque +brunch +beanie +office building +pattern +calendar +indoor +pepper +ledge +trail +fuel +laptop computer +tennis shoe +deck chair +guitarist +barn +surgery +cartoon illustration +nebula +railroad +mountain goat +goose +car door +cheer +liquid +hardwood floor +pathway +acorn +gull +airliner +couch +lake house +spaghetti +promenade +collection +garden +bank +robin +tennis ball +peony +gymnast +lavender +deck +test +riverside +rapper +domino +bride +mouse +basil +wedding couple +ocean wave +arm +kitchen floor +grove +family member +backyard +raspberry +forest fire +officer +hibiscus +canyon +composer +signature +olive oil +hibiscus flower +rose +vector icon +sunrise +horseback +motor scooter +office worker +tradition +ingredient +washing machine +lighting +bagel +sailboat +policeman +mare +graphic +halloween pumpkin +stock +pilot +education +team +body +horse +kimono +bazaar +bag +recording studio +parsley +entrance +denim +vet +horse farm +charcoal +architecture +glass vase +puppy +estuary +television show host +city bus +shoulder +beast +balance +golfer +roadside +denim jacket +stone wall +counter top +app icon +toast +head coach +ham +warrior +gem +refrigerator +snowman +construction worker +coal +website +morning fog +mustard +human +owl +puppy dog +piggy bank +vegetation +pirate +action film +marshmallow +thanksgiving +business +disease +signage +greeting +skate park +tile +mouth +spinach +vacation +leader +shrine +walker +science fiction film +bill +rabbit +motor boat +bar +radio +barge +tail +chainsaw +gallery +rainbow +pasta +padlock +web +pastry +ink +reef +school uniform +shawl +treasure +peach +dinner table +injury +harbor +witch +car dealership +litter +gesture +documentary +marriage +sea shell +priest +dome +kit +icon +seaside +bucket +entertainment +stable +hat +puddle +sock +shopper +technology +harbour +orbit +antler +tube +flag waving +cook +tight +commander +farmland +switch +hiker +wedding ceremony +award ceremony +champion +chopstick +farmhouse +performer +spike +accident +cruise ship +passenger train +attraction +entertainer +rear view +sidewalk +parade +racing +plane +ritual +peacock +pocket +plum +drop +carrot +floor +sunset +troop +architect +coffee table +dust +outline +leather +charity event +heat +whale +laundry +coconut tree +crosswalk +pony +ant +pipe +string +coat +angel +beef +church tower +dish +pitch +cupboard +thermometer +dirt field +fireworks +minute +cane +pajama +flower garden +autumn +trash can +dachshund +banana tree +tray +moose +roadway +carnival +antenna +pole +castle wall +ram +cattle +hay +cookie +swimmer +baseball team +strait +hedge +jet +fire pit +octopus +calf +cube +opera +cardboard box +tiara +kitchen sink +prairie +bowl +galaxy +straw hat +linen +ski resort +stitch +street lamp +motorist +icicle +stain +flora +drain +kitchen cabinet +decor +bouquet +pound +interior design +nail polish +figurine +tomb +disc +twist +blouse +ribbon +figure +burger +cork +soccer goalkeeper +train bridge +drinking water +dew +baker +storm cloud +tarmac +tv drama +sponge +magnet +sailor +entry +swan +exercise +sloth +jewel +scuba diver +bite +cat tree +tent +can +tennis match +ecosystem +picket fence +palm +train car +frying pan +rally +tablet pc +reindeer +image +wolf +chin +conservatory +flood water +cityscape +beach sand +car park +pavement +farm field +swimming +winter storm +stem +pillow +inning +gorilla +desk +avenue +fern +money +pearl +train station +skillet +nap +barber +library +freezer +label +rainforest +parking sign +mirror +wing +noodle +press room +sculpture +tablet +viewer +prayer +mini +mechanic +laugh +rice field +hand +mustache +mountain road +catwalk +conference +cape +installation +musician +stream +machine +speech +crocodile +soccer match +town square +passport +post box +point +stone building +motorway +mix +dentist +businessperson +happiness +boat +vineyard +treadmill +glass wall +water droplet +coffee mug +graduate +sunflower +parliament +shepherd +movie +wine +orchard +tulip +motherboard +cup +broom +spot +drawing +polo shirt +graduation +film producer +moonlight +glow +film format +t shirt +rock face +sword +clinic +festival day +meadow +staple +pupil +training ground +rider +flower +foal +wharf +foot bridge +shooting +top +mast +police car +robe +wedding bouquet +stop sign +birthday cake +glitter +butter +scooter +tundra +superhero +pocket watch +inscription +youngster +fruit tree +movie poster +engine +foundation +motorcyclist +take +woman +antelope +country artist +road trip +typewriter +tuxedo +brand +pine +bathroom +paradise +texture +balloon +dining table +home +computer screen +actor +clip +tv tower +panorama +summit +cat +plot +eagle +dancer +pup +studio shot +tear +bird bath +classroom +bookstore +city wall +tv programme +blade +easel +buttercream +sweet +designer +diamond +handshake +herb +corn field +seafront +concrete +street artist +gas +stamp +window display +paper +note +pint +quarry +research +fixture +manager +soil +leopard +board game +ladder +stop light +island +ramp +football match +icing +drill +currency +summer evening +topping +pyramid +pomegranate +cell +ivy +squad +scenery +computer +locomotive +surf +mascot +dune +path +duck +twilight +wire +bow tie +strike +cormorant +car wash +crane +market +philosopher +alarm clock +camera +birch +greeting card +plain +clay +donut +lock +moth +laboratory +fan +violin +jazz fusion artist +mountain biker +terrain +magazine +pickup +comedy film +smartphone +film +bed +microwave oven +tournament +lawn +car window +alligator +screen +jetty +shopping bag +landscape view +cabinetry +friendly match +thing +petal +shopping center +transport +ballet dancer +shoreline +princess +car seat +parking meter +green +vodka +band +rock +costume +warning sign +strip +plaque +wheelchair +headband +ginger +dice +media +hairdresser +press +living room +stove +player +cherry +workshop +carving +embroidery +doodle +adventure +rugby player +monument +brush +marker +loft +postcard +collage +ball +professor +dresser +gig +festival +blackbird +makeup artist +video camera +sticker +peak +wildflower +santa hat +rodeo +wedding photographer +guy +staff +waterfall +operation +defender +falcon +haze +individual +gentleman +greyhound +rocking chair +rice +garbage +platter +chocolate +splash +business suit +cheetah +valley +maze +trampoline +garland +slalom +unicorn +tree stump +painting +romance +fight +alcohol +ghost +fondant +spa +shutter +death +demonstration +cotton +pier +flea market +history +savannah +fist +aisle +crew +jug +pose +anchor +teapot +boat house +business team +tripod +bee +pebble +mattress +canvas +hallway +campaign +pod +lake district +article +white +sofa +honey +marathon +pancake +tourist attraction +wedding gown +battle +shelving +sea +sheet music +pie +yarn +construction site +flyer +tie +star +lettuce +martial artist +dart +straw +reflection +conference room +temperature +rugby +mosquito +physicist +rock climber +crash +backdrop +toilet seat +sand castle +water park +toy car +waste +luxury +hangar +rv +tree trunk +board +gold +project picture +cap +cottage +relief +attire +microscope +battery +roll +line +parking garage +crystal +broadcasting +brick wall +lab +flooring +meeting +3d cg rendering +desktop computer +cowboy +sailing ship +junction +hairstyle +homework +profile +model +flower pot +street light +salt lake +maple +space +blizzard +throw +zebras +brochure +constellation +beak +kilt +pond +blue sky +sneaker +sand dune +morning sun +almond +grill +curl +basketball girl game +chameleon +toilet bowl +prince +keyboard +queen +computer monitor +writing +crown +basilica +kiss +house +parking +football competition +shell +sport equipment +comedy +baboon +vendor +rise building +wrap +food truck +cat bed +rickshaw +flare +teal +nectar +eclipse +vehicle +steam locomotive +gorge +cow +christmas card +demonstrator +memorial +towel +jewellery +train +frisbee +baseball game +fur +afternoon sun +community +sparkler +bandage +firework +dollar +pasture +video +bus +tree house +seashore +field +hamburger +souvenir +hedgehog +worm +pine cone +osprey +dinosaur +vegetable +junk +poster +army +winger +bundle +stage +growth +wedding party +service +blanket +ruler +eye +credit card +castle +diner +hut +elk +hard rock artist +nun +dog breed +nest +drama film +number icon +water tank +giraffe +altar +pavilion +tv personality +suv +street vendor +street sign +ditch +debris +foam +takeoff +spice +mountain lake +tea +orchestra +spacecraft +counter +abbey +mountain +hydrangea +racer +orange tree +tide +cowboy hat +rapid +town +wild +herd +vein +driveway +jar +bark +illustration +horror film +corn +stroller +industry +mountain stream +gym +neckline +pan +client +spectator +eggplant +camper +fawn +hoodie +meat +lemonade +food market +slum +comic book character +flower market +love +palace +gun +heel +shopping street +shooting basketball guard +family photo +rooftop +laundry basket +airport runway +horn +face mask +flight +appetizer +violet +country lane +cement +instrument +tv actor +spark +celebrity +award +country house +standing +auction +date +engagement +puck +advertisement +chair +zebra +driftwood +bumblebee +maple leaf +bonnet +orange +water tower +door +singer +floor plan +discussion +theatre +pilgrim +mug +branch +window sill +baseball pitcher +bakery +lollipop +basketball player +toilet paper +chalkboard +cabin +sign +night sky +cannon +fishing net +submarine +suit +fur coat +wine bottle +folder +street art +suspension bridge +evening sky +billboard +postage stamp +newspaper +transportation +surgeon +light +park +horizon +road +sand bar +trumpet +lounge +cloud forest +birthday celebration +balcony +anime +beehive +umbrella +goldfish +baseball cap +waterhole +ceiling +carousel +backpack +plant pot +atmosphere +sunflower field +spire +vision +woodpecker +chip +pool table +lotus flower +cone +humpback whale +reservoir +hunt +piano +plate +dining area +luggage +skier +dance floor +crow +stair +overpass +opera house +bear +jazz artist +water +vessel +cast +yard +cathedral +basketball hoop +graveyard +sound +berry +onlooker +fauna +birch tree +retail +hill +skeleton +journalist +frost +basket +nail +dusk +trash +dawn +clover +hen +volcano +basketball coach +home decor +charge +haircut +sense +university +lizard +daisy +tablet computer +grass field +prison +metal artist +bathroom mirror +window frame +chest +flavor +pop country artist +market square +monkey +blog +deer +speech bubble +dog +independence day +girl +boy +tartan +furniture +appliance +office window +fish boat +sand box +tv sitcom +drama +sleigh +depression +paper towel +baseball +protestor +grape +wedding cake +invitation +accessory +pick +grandparent +racket +tea plantation +outdoors +egg +glass bowl +sun +organization +lion +panel +station +wallpaper +helicopter +salt +vanity +patio +lunch +street performer +mountain range +soup +bacon +power station +cantilever bridge +hummingbird +shirt +rope +hip +chalk +pendant +choir +tv +lichen +railway bridge +art gallery +bartender +wagon +baby elephant +accordion +horseshoe +building site +clutch +harvest +savanna +geranium +business woman +paddock +patch +beech tree +war +suburbs +hospital bed +motorcycle racer +moss +gravel +government agency +dollar bill +father +fjord +concert +nut +wedding photography +finish line +home plate +food +nose +thumb +village +dining room table +bumper +monster +blackberry +lime +conflict +gala +wallet +wrist +hug +mermaid +lava +lawyer +folk rock artist +arena +onion +toothbrush +fashion +perfume +flip +triangle +woodland +mail +grasshopper +studio +wood floor +den +racquet +cello +lemur +astronaut +glass table +blood +dvd +planter +silver +leash +master bedroom +forest +batter +shoe +engraving +opening +product +toe +cocktail +mallard duck +bike ride +oasis +wedding ring +cinematographer +holly +autograph +fence +ice cube +cove +pineapple +aurora +glass bead +produce +apartment building +cob +miniature +cockpit +flashlight +frog +sheep +groom +steel +watermelon +clip art +paper plate +ostrich +contour +mural +cub +paisley bandanna +winery +turn +handle +satellite +post +pork +child +asphalt +grocery store +vulture +trolley +nightclub +brick +trailer +compass +cereal +cafe +cartoon character +sugar +fiction book +glass floor +umpire +guitar +hamster +protester +airplane +garment +blazer +railway line +wedding +shoe box +parking lot +construction +graduation ceremony +tram +telescope +copper +pain +autumn forest +guest house +partner +crayon +dip +boot +corridor +computer keyboard +hockey player +chicken coop +bus station +gathering +ankle +bunk bed +wood table +football coach +monarch +pharmacy +legging +mannequin +female +train track +stack +canopy +design element +grandmother +symbol +beach hut +zucchini +bomb +businessman +skyscraper +tongue +case +sparkle +highland +ballroom +prom +estate +customer +archipelago +cheese +debate +carriage +bulldozer +pumpkin +sitting room +gas station +wedding reception +camp +dog bed +tower +property +river bed +pop latin artist +fridge +wine glass +coast +beer +tow truck +fire truck +mountain bike +thigh +heron +boat ride +gondola +turquoise +lake +llama +kitty +tin +waiting room +coffee cup +socialite +guard +tap +waterway +forehead +list +erosion +box +sea lion +pollen +dam +wasp +salon +tennis tournament +flower box +aquarium +rain cloud +clothing store +lead singer +cupcake +tortoise +lettering +sport facility +dance +dog house +nature +football +rooster +footballer +railway track +crowd +fishing rod +silhouette +wind turbine +sari +bus window +cloud +charity +medal +yoga +event +veil +fashion menswear milan week +news +knife +print +screen tv +walnut +fungus +ice cream +computer mouse +play +tribe +picture +video game +business card +music festival +rack +envelope +shower +dirt road +mine +oyster +monarch butterfly +dude +fruit salad +podium +fork +lace +test match +boulder +cricket player +staircase +peninsula +shopping +popcorn +oak +market stall +pine tree +mountaineer +student +closet +hood +handstand +centerpiece +insect +patient +makeover +tennis player +sheet +park bench +apple +organism +hook +turkey +tangerine +sibling +shopping mall +bird +scarf +smoothie +net +grass +napkin +ray +eyebrow +laptop keyboard +motorbike +woman hand +oven +book cover +easter egg +microwave +sand +snapshot +soccer ball +makeup +knight +bowling ball +shower curtain +flame +lightning +running +power plant +crib +cartoon +moat +fashion girl +wedding invitation +bottle +cliff +monastery +file photo +apartment +casino +cream +sweatshirt +storm +cruise +teddy bear +shovel +wind farm +writer +dock +professional +hotel room +job +monitor +donkey +pass +interview +duchess +mark +plank +beard +zombie +trio +channel +cricket team +windmill +vest +diagram +cable +winter scene +golden gate bridge +buffalo +studio portrait +pagoda +whiskey +freight train +kite +future +steam train +phone box +headset +wood +snowboarder +paper bag +slide +grapefruit +seating +morning +bronze sculpture +theatre actor +stump +jean +landmark +jam +waist +watercolor +hammock +light fixture +ice +basin +beverage +shelter +premiere +mound +ear +bronze +sunlight +street +energy +barn door +hike +fleet +claw +beach +pepperoni +bin +trainer +buffet +archive +toddler +referee +bay window +dove +production company +evening light +gate +farm +reed +fruit stand +explorer +snow storm +throw pillow +button +display case +bookcase +lead +lipstick +basketball court +cargo +ensemble +pope +clock tower +teen +speaker +rat +laptop +ski +mess +stadium +ferry boat +bunny +waterfront +downtown +sink +press conference +dinner +condiment +thread +audience +grid +car +plastic +people +barbecue +pigeon +urinal +seagull +volunteer +hockey +fir tree +pollution +trial +collar +area +meeting room +circus +yogurt +orangutan +viaduct +comedian +drone +scissor +pop rock artist +biscuit +panda +water feature +air balloon +remote control +watercolor painting +show +walk +post office +bike path +rap gangsta artist +microphone +crack +sunset sky +glass +tv show +cartoon style +stripe +foyer +signal +calligraphy +bulb +gardener +coffee bean +spider +tapestry +city skyline +necklace +kitten +traveler +veteran +frosting +fry +tennis court +tank top +butterfly house +mist +drummer +water level +scale +baseball glove +music video performer +champagne +camping +clothing +water drop +telephone box +pen +morning mist +fire engine +porch +opening ceremony +style +palm tree +fashion show +universe +scratch +axe +ottoman +explosion +rib +boutique +game +cucumber +fruit +stone bridge +nature reserve +track +train window +punch +telephone pole +velvet +sauce +moon +contrast +flamingo +bat +vending machine +ship +equestrian +shade +comforter +pallet +sparrow +wii +glaze +grocery +steeple +soccer player +contract +advertising +runner +chimpanzee +world +seat +project +chihuahua +bubble +willow +pedestal +soul hip hop artist +curb +drawer +leaf +banner +launch party +coach +government +snowball +toy +portrait +doctor +whiteboard +electronic +tiger +graffiti +column +nightstand +whistle +maxi dress +bench +wetsuit +bird feeder +football game +basketball +class +bathroom door +store window +text message +wreath +street view +binocular +pet +facade +drought +lemon +new year +night view +airplane window +specie +rule +jaw +wheat field +diet +pop artist +habitat +screenshot +scoreboard +shore +mane +quilt +ski lift +orchid +turban +christmas +airport +marina +glass door +glass bottle +restaurant +conductor +logo +sleep +tape +tomato +river bank +lilac +tooth +training +pottery +shop +steam engine +mason jar +base +procession +border +shoot +footprint +hotdog +bull +stocking +recreation +automobile model +design +country pop artist +river +retriever +department store +auditorium +sport car +supermarket +belt +cricket +window box +dress shirt +letter +residence +megaphone +pant +wildfire +bird nest +crab +swimsuit +candle +funeral +mill +national park +plant +cop +power line +perch +blue +finger +ferris wheel +globe +skateboard +helmet +movie theater +uniform +hammer +material +kid +well +butterfly +sideline +fashion fall show +planet earth +lift +male +sauna +gray +flour +sand sculpture +program +cabinet +infant +wheel +aircraft model +dough +garlic +skate +arrow +wrapping paper +ripple +lamp +iron +banknote +beaver +ferry +courtyard +bassist +countryside +steak +comfort +boxer +laundry room +campsite +brick building +golf +subway +headphone +fort +handbag +drum +flood +saddle +bass +labyrinth +needle +sun ray +app +menu +president +cardigan +dandelion +wetland +ice hockey player +number +city hall +fishing +portrait session +pug +key +art print +minister +hurdle +emergency +painting artist +flag pole +evening +purse +recipe +golf ball +coloring book +mountain peak +senior +holiday +bud +cousin +pantry +lap +skin +flag +tissue paper +ridge +wire fence +surfer +climber +photograph +sewing machine +cooler +actress +apple tree +cancer +starfish +automobile make +dumbbell +brace +tunnel +window +paint artist +composition +school student +condo +convertible +cushion +selfie +territory +guide +tree +court +shrimp +stone house +dress +eyelash +juice +broccoli +chain +tourism +mountain top +concept car +film premiere +light bulb +cafeteria +badge +flower bed +theater +root +racecar driver +basketball boy game +glove +skyline +wall +glacier +airport terminal +bug +trim +railway station +briefcase +flat +fountain +person +lane +asparagus +art +lantern +dishwasher +director +snake +lecture +game controller +tree branch +pub +bathing suit +queue +belly +poppy +bow +pitcher +ice cream cone +cave +candy +road bridge +host +traffic jam +earring +file +foot +watermark overlay stamp +mailbox +supercar +railing +bedroom +seafood +waffle +bronze statue +plan +flow +marble +basketball game +automobile +scene +cypress tree +soldier +skateboarder +glass building +cherry tree +pump +grain +wildebeest +loop +frame +bathtub +saxophone +diver +stalk +lily +bead +alley +flock +family room +manufacturing +pointer +worker +navy +potato +teacher +photography +dolly +boardwalk +water fountain +athlete +side dish +bay +ice hockey +phone +hero +face +gold medal +blind +swamp +researcher +swim +meatball +iguana +leather jacket +jellyfish +site +smoke +traffic signal +melon +beetle +calculator +skirt +plantation +sculptor +barrier +catcher +security guard +sketch +awning +steering wheel +mountain view +bus stop +pool +leg +spotlight +apron +mineral +inlet +sleeve +torch +emotion +march +police officer +performance +lamp post +fishing boat +summer +presentation +saucer +suitcase +supermodel +goalkeeper +shrub +rock artist +document +beach house +man +blue artist +cigar +railroad track +gown +mosaic +bungalow +alphabet +baseball field +shed +pedestrian +rail +soap +kitchen counter +dessert +dunk +blossom +conversation +fruit market +glass jar +military +beer bottle +photographer +tennis racket +competition +escalator +bell tower +stilt +ballerina +television +feather +fence post +rear +dahlia +red carpet +tub +hole +fortress +pack +telephone +cardboard +city park +platform +college student +arch bridge +wind +blender +bloom +ice rink +birthday +raven +fairy +embankment +hall +flower shop +suburb +barrel +biker +steam +dragonfly +formation +electricity +business people +symmetry +walkway +fisherman +gas mask +loch +youth +hanger +dot +fish +street market +animation film +crime fiction film +boar +emblem +halloween costume +kangaroo +couple +spoon +squirrel +neon sign +sky +office desk +beauty salon +breakwater +fashion look +toaster +author +news conference +outdoor +canoe +dragon +tool +shopping centre +ladybug +swimming pool +landscaping +ski pole +red +truck +fly +temple +level +sunday +railroad bridge +car mirror +lawn mower +flute +aircraft carrier +fashion menswear london week +sunshine +tile floor +skull +fossil +flower arrangement +diaper +sea turtle +cherry blossom +fireman +shack +lens +waiter +animal +basement +snow +autumn park +glass box +kick +head +anniversary +vine +back +paper lantern +fish tank +cellphone +silk +coral +notebook +photo +gazebo +ketchup +driver +farmer +bonfire +chestnut +photoshoot +football field +olive tree +pheasant +sandal +toilet +fireplace +music +deity +fish market +fig +bell +neck +grave +villa +cyclist +crate +grey +asphalt road +soccer +hostel +municipality +courthouse +roof +end table +pot +sedan +structure +folk artist +sport +sport team +protest +syringe +fashion designer +jersey +heart shape +kayak +stare +sit with +direct +read +photograph +spin +teach +laugh +carve +grow on +warm +watch +stretch +smell +decorate +shine +light +dance +send +park +chase +collect +lead +kiss +lead to +lick +smile +cheer +sit +point +block +rock +drop +cut +ski +wrap +lose +serve +provide +sleep +dress +embrace +burn +pack +stir +create +touch +wash +stick +reveal +shop +train +paint +groom +hunt +bloom +play +pay +brush +shoot +hold +picture +carry +sip +contain +turn +pour +pitch +give +add +blow +look in +show +walk +illuminate +kneel +cover +drag +post +present +fit +operate +fish +race +write +deliver +peel +push +run +sit around +buy +jump +walk on +attend +clean +sell +ride on +mount +host +dry +plant +sing +row +shake +perch +ride +fight +skateboard +live +call +surround +practice +play on +work on +step +relax +hit +fall in +flow +greet +launch +wear +hang on +drive +sit in +break +learn +fly +connect +display +locate +compete +go for +sail +lift +toast +help +run on +reflect +pose +scratch +frame +dribble +herd +enter +exit +place +inspect +build +pick +fill +grind +skate +offer +float +sit by +stand +release +rest +singe +climb +tie +mark +lay +stand around +capture +set +land +swinge +run in +kick +lean +head +sign +approach +swim +close +crash +control +fall +remove +repair +open +appear +travel +load +miss +check +surf +moor +smoke +drink +board +seat +feed +rise +sit on +swing +grow +strike +date +slide +share +graze +jump in +lie +extrude +roll +move +gather +eat +pull +run through +squeeze +lay on +draw +play with +wave +assemble +perform +march +score +attach +adjust +hang +hug +sleep on +throw +live in +talk +pet +work +run with +see +flip +catch +cook +receive +celebrate +look +classic +bridal +indoor +industrial +teenage +mini +grassy +aged +long +warm +light +handsome +happy +three +pregnant +circular +urban +silver +ceramic +3d +green +blonde +golden +dark +tropical +ripe +deep +fat +musical +giant +medical +medieval +bare +stunning +bold +geographical +huge +plastic +foggy +stormy +gothic +biological +empty +clear +antique +pink +steep +brown +striped +aerial +rainy +cool +flying +commercial +purple +trendy +blank +haired +dead +wooden +flat +high +beige +panoramic +angry +dozen +rural +solar +big +small +stained +thick +many +fresh +clean +strong +abstract +crowded +retro +dry +gorgeous +martial +modern +blue +cloudy +low +four +outdoor +single +much +beautiful +snowy +pretty +new +short +sunny +closed +rocky +red +two +double +male +gray +five +colorful +automotive +various +one +old +rusty +tall +wild +narrow +natural +several +frozen +textured +lush +young +hot +mixed +white +float +quiet +round +bright +religious +female +historical +shiny +traditional +tourist +yellow +bald +coastal +lovely +little +broken +romantic +wide +royal +rich +open +cute +ancient +cold +political +elderly +gold +full +rustic +metallic +floral +sad +wet +fancy +senior +tiny +stylish +large +frosty +orange +transparent +electronic +shallow +scared +armed +dirty +historic +black +few +windy +some +square +ornamental +sandy +thin \ No newline at end of file diff --git a/ram/inference.py b/ram/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..182efc55098ce201cdc776236aa8c9468845cb41 --- /dev/null +++ b/ram/inference.py @@ -0,0 +1,46 @@ +''' + * The Inference of RAM and Tag2Text Models + * Written by Xinyu Huang +''' +import torch + + +def inference_tag2text(image, model, input_tag="None"): + + with torch.no_grad(): + caption, tag_predict = model.generate(image, + tag_input=None, + max_length=50, + return_tag_predict=True) + + if input_tag == '' or input_tag == 'none' or input_tag == 'None': + return tag_predict[0], None, caption[0] + + # If user input specified tags: + else: + input_tag_list = [] + input_tag_list.append(input_tag.replace(',', ' | ')) + + with torch.no_grad(): + caption, input_tag = model.generate(image, + tag_input=input_tag_list, + max_length=50, + return_tag_predict=True) + + return tag_predict[0], input_tag[0], caption[0] + + +def inference_ram(image, model): + + with torch.no_grad(): + tags, tags_chinese = model.generate_tag(image) + + return tags[0],tags_chinese[0] + + +def inference_ram_openset(image, model): + + with torch.no_grad(): + tags = model.generate_tag_openset(image) + + return tags[0] diff --git a/ram/models/__init__.py b/ram/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69bdb22f2dba166bac07ab9d63fe8d0562dc88a6 --- /dev/null +++ b/ram/models/__init__.py @@ -0,0 +1,2 @@ +from .ram import ram +from .tag2text import tag2text_caption diff --git a/ram/models/bert.py b/ram/models/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..cb90b794284d2262d171aa0f93fdf20854a9059b --- /dev/null +++ b/ram/models/bert.py @@ -0,0 +1,1035 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + # self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # if position_ids is None: + # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + diff --git a/ram/models/ram.py b/ram/models/ram.py new file mode 100644 index 0000000000000000000000000000000000000000..615c2a474e60550de2de31edca27160eefb597cb --- /dev/null +++ b/ram/models/ram.py @@ -0,0 +1,273 @@ +''' + * The Recognize Anything Model (RAM) + * Written by Xinyu Huang +''' +import json +import warnings + +import numpy as np +import torch +from torch import nn + +from .bert import BertConfig, BertLMHeadModel, BertModel +from .swin_transformer import SwinTransformer +from .utils import * + +warnings.filterwarnings("ignore") + + + +class RAM(nn.Module): + def __init__(self, + med_config=f'{CONFIG_PATH}/configs/med_config.json', + image_size=384, + vit='base', + vit_grad_ckpt=False, + vit_ckpt_layer=0, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', + tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'): + r""" The Recognize Anything Model (RAM) inference module. + RAM is a strong image tagging model, which can recognize any common category with high accuracy. + Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/ + + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + threshold (int): tagging threshold + delete_tag_index (list): delete some tags that may disturb captioning + """ + super().__init__() + + # create image encoder + if vit == 'swin_b': + if image_size == 224: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' + elif image_size == 384: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' + vision_config = read_json(vision_config_path) + assert image_size == vision_config['image_res'] + # assert config['patch_size'] == 32 + vision_width = vision_config['vision_width'] + + self.visual_encoder = SwinTransformer( + img_size=vision_config['image_res'], + patch_size=4, + in_chans=3, + embed_dim=vision_config['embed_dim'], + depths=vision_config['depths'], + num_heads=vision_config['num_heads'], + window_size=vision_config['window_size'], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + elif vit == 'swin_l': + if image_size == 224: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' + elif image_size == 384: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' + vision_config = read_json(vision_config_path) + assert image_size == vision_config['image_res'] + # assert config['patch_size'] == 32 + vision_width = vision_config['vision_width'] + + self.visual_encoder = SwinTransformer( + img_size=vision_config['image_res'], + patch_size=4, + in_chans=3, + embed_dim=vision_config['embed_dim'], + depths=vision_config['depths'], + num_heads=vision_config['num_heads'], + window_size=vision_config['window_size'], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + else: + self.visual_encoder, vision_width = create_vit( + vit, image_size, vit_grad_ckpt, vit_ckpt_layer) + + # create tokenzier + self.tokenizer = init_tokenizer() + + # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder + # create image-tag interaction encoder + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = 512 + self.tag_encoder = BertModel(config=encoder_config, + add_pooling_layer=False) + + # create image-tag-text decoder + decoder_config = BertConfig.from_json_file(med_config) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(tag_list) + self.tag_list_chinese = self.load_tag_list(tag_list_chinese) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + q2l_config.encoder_width = 512 + self.tagging_head = BertModel(config=q2l_config, + add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + # self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) + self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width)) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + self.image_proj = nn.Linear(vision_width, 512) + # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float()) + + # adjust thresholds for some tags + self.class_threshold = torch.ones(self.num_class) * self.threshold + ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' + with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: + ram_class_threshold = [float(s.strip()) for s in f] + for key,value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'r', encoding="utf-8") as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def generate_tag(self, + image, + threshold=0.68, + tag_input=None, + ): + + label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(image.device) + + # recognized image tags using image-tag recogntiion decoder + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(image.device), + torch.tensor(1.0).to(image.device), + torch.zeros(self.num_class).to(image.device)) + + tag = targets.cpu().numpy() + tag[:,self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_output.append(' | '.join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(' | '.join(token_chinese)) + + + return tag_output, tag_output_chinese + + def generate_tag_openset(self, + image, + threshold=0.68, + tag_input=None, + ): + + label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + image_embeds = self.image_proj(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(image.device) + + # recognized image tags using image-tag recogntiion decoder + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(image.device), + torch.tensor(1.0).to(image.device), + torch.zeros(self.num_class).to(image.device)) + + tag = targets.cpu().numpy() + tag[:,self.delete_tag_index] = 0 + tag_output = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_output.append(' | '.join(token)) + + return tag_output + + +# load RAM pretrained model parameters +def ram(pretrained='', **kwargs): + model = RAM(**kwargs) + if pretrained: + if kwargs['vit'] == 'swin_b': + model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) + elif kwargs['vit'] == 'swin_l': + model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) + else: + model, msg = load_checkpoint(model, pretrained) + print('vit:', kwargs['vit']) +# print('msg', msg) + return model diff --git a/ram/models/swin_transformer.py b/ram/models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1affc9a8695474e831ad060343c1988d750dc5f --- /dev/null +++ b/ram/models/swin_transformer.py @@ -0,0 +1,654 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import numpy as np +from scipy import interpolate + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + + x_cls = self.avgpool(x.transpose(1, 2)) # B C 1 + + if idx_to_group_img is None: + return torch.cat([x_cls.transpose(1, 2), x], dim=1) + else: + x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) + weights = image_atts[:, 1:].unsqueeze(2) # B L 1 + x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1 + x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool + + return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \ + torch.cat([x_cls.transpose(1, 2), x], dim=1) + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''): + # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348 + + # rel_pos_bias: relative_position_bias_table + src_num_pos, num_attn_heads = rel_pos_bias.size() + + num_extra_tokens = 0 + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size)) + + # extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + # print("Original positions = %s" % str(x)) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + return rel_pos_bias \ No newline at end of file diff --git a/ram/models/tag2text.py b/ram/models/tag2text.py new file mode 100644 index 0000000000000000000000000000000000000000..39ab066890119bdbc4b034bfd8b09b4d6eb1b91b --- /dev/null +++ b/ram/models/tag2text.py @@ -0,0 +1,277 @@ +''' + * The Tag2Text Model + * Written by Xinyu Huang +''' +import numpy as np +import json +import torch +import warnings + +from torch import nn +from .bert import BertConfig, BertModel, BertLMHeadModel +from .swin_transformer import SwinTransformer + +from .utils import * + +warnings.filterwarnings("ignore") + + +class Tag2Text_Caption(nn.Module): + + def __init__(self, + med_config=f'{CONFIG_PATH}/configs/med_config.json', + image_size=384, + vit='base', + vit_grad_ckpt=False, + vit_ckpt_layer=0, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359], + tag_list=f'{CONFIG_PATH}/data/tag_list.txt'): + r""" Tag2Text inference module, both captioning and tagging are included. + Tag2Text is an efficient and controllable vision-language pre-training framework. + Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657 + + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + threshold (int): tagging threshold + delete_tag_index (list): delete some tags that may disturb captioning + """ + super().__init__() + + # create image encoder + if vit == 'swin_b': + if image_size == 224: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' + elif image_size == 384: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' + vision_config = read_json(vision_config_path) + assert image_size == vision_config['image_res'] + # assert config['patch_size'] == 32 + vision_width = vision_config['vision_width'] + + self.visual_encoder = SwinTransformer( + img_size=vision_config['image_res'], + patch_size=4, + in_chans=3, + embed_dim=vision_config['embed_dim'], + depths=vision_config['depths'], + num_heads=vision_config['num_heads'], + window_size=vision_config['window_size'], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + else: + self.visual_encoder, vision_width = create_vit( + vit, image_size, vit_grad_ckpt, vit_ckpt_layer) + + # create tokenzier + self.tokenizer = init_tokenizer() + + # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder + # create image-tag interaction encoder + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.tag_encoder = BertModel(config=encoder_config, + add_pooling_layer=False) + + # create image-tag-text decoder + decoder_config = BertConfig.from_json_file(med_config) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + # delete some tags that may disturb captioning + # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(tag_list) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + q2l_config.encoder_width = vision_width + self.tagging_head = BertModel(config=q2l_config, + add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) + self.fc = GroupWiseLinear(self.num_class, + q2l_config.hidden_size, + bias=True) + self.del_selfattention() + + # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + + # adjust thresholds for some tags + # default threshold: 0.68 + # 2701: "person"; 2828: "man"; 1167: "woman"; + tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7} + self.class_threshold = torch.ones(self.num_class) * self.threshold + for key,value in tag_thrshold.items(): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'r') as f: + tag_list = f.read().splitlines() + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def generate(self, + image, + sample=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + tag_input=None, + return_tag_predict=False): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(image.device) + + # if not user specified tags, recognized image tags using image-tag recogntiion decoder + if tag_input == None: + image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]) + + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(image.device), + torch.tensor(1.0).to(image.device), + torch.zeros(self.num_class).to(image.device)) + + tag = targets.cpu().numpy() + + # delete some tags that may disturb captioning + tag[:, self.delete_tag_index] = 0 + + tag_input = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + tag_input.append(' | '.join(token)) + + tag_output = tag_input + + # beam search for text generation(default) + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) + tag_input_temp = [] + for tag in tag_input: + for i in range(num_beams): + tag_input_temp.append(tag) + tag_input = tag_input_temp + + image_atts = torch.ones(image_embeds.size()[:-1], + dtype=torch.long).to(image.device) + + # tokenizer input tags + tag_input_tokenzier = self.tokenizer(tag_input, + padding='max_length', + truncation=True, + max_length=40, + return_tensors="pt").to( + image.device) + encoder_input_ids = tag_input_tokenzier.input_ids + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + # put input tag into image-tag interaction encoder to interact with image embeddings + output_tagembedding = self.tag_encoder( + encoder_input_ids, + attention_mask=tag_input_tokenzier.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # prompt trick for better captioning, followed BLIP + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( + image.device) + input_ids[:, 0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + # nucleus sampling + model_kwargs = { + "encoder_hidden_states": output_tagembedding.last_hidden_state, + "encoder_attention_mask": None + } + outputs = self.text_decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search (default) + model_kwargs = { + "encoder_hidden_states": output_tagembedding.last_hidden_state, + "encoder_attention_mask": None + } + outputs = self.text_decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = self.tokenizer.decode(output, skip_special_tokens=True) + captions.append(caption[len(self.prompt):]) + if return_tag_predict == True: + return captions, tag_output + return captions + + +# load Tag2Text pretrained model parameters +def tag2text_caption(pretrained='', **kwargs): + model = Tag2Text_Caption(**kwargs) + if pretrained: + if kwargs['vit'] == 'swin_b': + model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) + else: + model, msg = load_checkpoint(model, pretrained) + print('vit:', kwargs['vit']) +# print('msg', msg) + return model + diff --git a/ram/models/utils.py b/ram/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a405577f16dfa4881f90ea7ebe185f3b00f4637c --- /dev/null +++ b/ram/models/utils.py @@ -0,0 +1,278 @@ +import os +import json +import torch +import math + +from torch import nn +from typing import List +from transformers import BertTokenizer +from urllib.parse import urlparse +from timm.models.hub import download_cached_file +from .vit import interpolate_pos_embed +from .swin_transformer import interpolate_relative_pos_embed +from pathlib import Path +CONFIG_PATH=(Path(__file__).resolve().parents[1]) + +def read_json(rpath): + with open(rpath, 'r') as f: + return json.load(f) + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" + if hasattr(decoder_pointer, "weight") and skip_key not in module_name: + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + print(module_name + ' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = set([ + module_name + "/" + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and len( + encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) + + +class GroupWiseLinear(nn.Module): + # could be changed to: + # output = torch.einsum('ijk,zjk->ij', x, self.W) + # or output = torch.einsum('ijk,jk->ij', x, self.W[0]) + def __init__(self, num_class, hidden_dim, bias=True): + super().__init__() + self.num_class = num_class + self.hidden_dim = hidden_dim + self.bias = bias + + self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) + if bias: + self.b = nn.Parameter(torch.Tensor(1, num_class)) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.W.size(2)) + for i in range(self.num_class): + self.W[0][i].data.uniform_(-stdv, stdv) + if self.bias: + for i in range(self.num_class): + self.b[0][i].data.uniform_(-stdv, stdv) + + def forward(self, x): + # x: B,K,d + x = (self.W * x).sum(-1) + if self.bias: + x = x + self.b + return x + + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, + image_size, + use_grad_checkpointing=False, + ckpt_layer=0, + drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit == 'base': + vision_width = 768 + visual_encoder = VisionTransformer( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=12, + num_heads=12, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate) + elif vit == 'large': + vision_width = 1024 + visual_encoder = VisionTransformer( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=24, + num_heads=16, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate) + return visual_encoder, vision_width + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def load_checkpoint(model, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, + check_hash=False, + progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed( + state_dict['visual_encoder.pos_embed'], model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed( + state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape != model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % url_or_filename) + return model, msg + + +def load_checkpoint_swinbase(model, url_or_filename, kwargs): + if kwargs['image_size'] == 224: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' + elif kwargs['image_size'] == 384: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' + window_size = read_json(vision_config_path)['window_size'] + print('--------------') + print(url_or_filename) + print('--------------') + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, + check_hash=False, + progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + for k in list(state_dict.keys()): + if 'relative_position_bias_table' in k: + dst_num_pos = (2 * window_size - 1)**2 + state_dict[k] = interpolate_relative_pos_embed(state_dict[k], + dst_num_pos, + param_name=k) + elif ('relative_position_index' in k) or ('attn_mask' in k): + del state_dict[k] + elif "vision_multi" in k: + state_dict[k.replace("vision_multi", + "tagging_head")] = state_dict.pop(k) + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % url_or_filename) + return model, msg + + +def load_checkpoint_swinlarge(model, url_or_filename, kwargs): + if kwargs['image_size'] == 224: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' + elif kwargs['image_size'] == 384: + vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' + window_size = read_json(vision_config_path)['window_size'] + print('--------------') + print(url_or_filename) + print('--------------') + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, + check_hash=False, + progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + for k in list(state_dict.keys()): + if 'relative_position_bias_table' in k: + dst_num_pos = (2 * window_size - 1)**2 + state_dict[k] = interpolate_relative_pos_embed(state_dict[k], + dst_num_pos, + param_name=k) + elif ('relative_position_index' in k) or ('attn_mask' in k): + del state_dict[k] + elif "vision_multi" in k: + state_dict[k.replace("vision_multi", + "tagging_head")] = state_dict.pop(k) + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % url_or_filename) + return model, msg + + diff --git a/ram/models/vit.py b/ram/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899 --- /dev/null +++ b/ram/models/vit.py @@ -0,0 +1,305 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/ram/utils/__init__.py b/ram/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f58b7cb0ab2b96e9ae1f8eaca696e3c7f8281b89 --- /dev/null +++ b/ram/utils/__init__.py @@ -0,0 +1 @@ +from .openset_utils import build_openset_label_embedding \ No newline at end of file diff --git a/ram/utils/openset_utils.py b/ram/utils/openset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c39e84fe87b200402e17d4dedb06681dff7d7a54 --- /dev/null +++ b/ram/utils/openset_utils.py @@ -0,0 +1,331 @@ + + + +import torch +import torch.nn as nn +from clip import clip + + +def article(name): + return "an" if name[0] in "aeiou" else "a" + + +def processed_name(name, rm_dot=False): + # _ for lvis + # / for obj365 + res = name.replace("_", " ").replace("/", " or ").lower() + if rm_dot: + res = res.rstrip(".") + return res + + +single_template = ["a photo of a {}."] + +multiple_templates = [ + "There is {article} {} in the scene.", + "There is the {} in the scene.", + "a photo of {article} {} in the scene.", + "a photo of the {} in the scene.", + "a photo of one {} in the scene.", + "itap of {article} {}.", + "itap of my {}.", # itap: I took a picture of + "itap of the {}.", + "a photo of {article} {}.", + "a photo of my {}.", + "a photo of the {}.", + "a photo of one {}.", + "a photo of many {}.", + "a good photo of {article} {}.", + "a good photo of the {}.", + "a bad photo of {article} {}.", + "a bad photo of the {}.", + "a photo of a nice {}.", + "a photo of the nice {}.", + "a photo of a cool {}.", + "a photo of the cool {}.", + "a photo of a weird {}.", + "a photo of the weird {}.", + "a photo of a small {}.", + "a photo of the small {}.", + "a photo of a large {}.", + "a photo of the large {}.", + "a photo of a clean {}.", + "a photo of the clean {}.", + "a photo of a dirty {}.", + "a photo of the dirty {}.", + "a bright photo of {article} {}.", + "a bright photo of the {}.", + "a dark photo of {article} {}.", + "a dark photo of the {}.", + "a photo of a hard to see {}.", + "a photo of the hard to see {}.", + "a low resolution photo of {article} {}.", + "a low resolution photo of the {}.", + "a cropped photo of {article} {}.", + "a cropped photo of the {}.", + "a close-up photo of {article} {}.", + "a close-up photo of the {}.", + "a jpeg corrupted photo of {article} {}.", + "a jpeg corrupted photo of the {}.", + "a blurry photo of {article} {}.", + "a blurry photo of the {}.", + "a pixelated photo of {article} {}.", + "a pixelated photo of the {}.", + "a black and white photo of the {}.", + "a black and white photo of {article} {}.", + "a plastic {}.", + "the plastic {}.", + "a toy {}.", + "the toy {}.", + "a plushie {}.", + "the plushie {}.", + "a cartoon {}.", + "the cartoon {}.", + "an embroidered {}.", + "the embroidered {}.", + "a painting of the {}.", + "a painting of a {}.", +] + + +openimages_rare_unseen = ['Aerial photography', +'Aircraft engine', +'Ale', +'Aloe', +'Amphibian', +'Angling', +'Anole', +'Antique car', +'Arcade game', +'Arthropod', +'Assault rifle', +'Athletic shoe', +'Auto racing', +'Backlighting', +'Bagpipes', +'Ball game', +'Barbecue chicken', +'Barechested', +'Barquentine', +'Beef tenderloin', +'Billiard room', +'Billiards', +'Bird of prey', +'Black swan', +'Black-and-white', +'Blond', +'Boating', +'Bonbon', +'Bottled water', +'Bouldering', +'Bovine', +'Bratwurst', +'Breadboard', +'Briefs', +'Brisket', +'Brochette', +'Calabaza', +'Camera operator', +'Canola', +'Childbirth', +'Chordophone', +'Church bell', +'Classical sculpture', +'Close-up', +'Cobblestone', +'Coca-cola', +'Combat sport', +'Comics', +'Compact car', +'Computer speaker', +'Cookies and crackers', +'Coral reef fish', +'Corn on the cob', +'Cosmetics', +'Crocodilia', +'Digital camera', +'Dishware', +'Divemaster', +'Dobermann', +'Dog walking', +'Domestic rabbit', +'Domestic short-haired cat', +'Double-decker bus', +'Drums', +'Electric guitar', +'Electric piano', +'Electronic instrument', +'Equestrianism', +'Equitation', +'Erinaceidae', +'Extreme sport', +'Falafel', +'Figure skating', +'Filling station', +'Fire apparatus', +'Firearm', +'Flatbread', +'Floristry', +'Forklift truck', +'Freight transport', +'Fried food', +'Fried noodles', +'Frigate', +'Frozen yogurt', +'Frying', +'Full moon', +'Galleon', +'Glacial landform', +'Gliding', +'Go-kart', +'Goats', +'Grappling', +'Great white shark', +'Gumbo', +'Gun turret', +'Hair coloring', +'Halter', +'Headphones', +'Heavy cruiser', +'Herding', +'High-speed rail', +'Holding hands', +'Horse and buggy', +'Horse racing', +'Hound', +'Hunting knife', +'Hurdling', +'Inflatable', +'Jackfruit', +'Jeans', +'Jiaozi', +'Junk food', +'Khinkali', +'Kitesurfing', +'Lawn game', +'Leaf vegetable', +'Lechon', +'Lifebuoy', +'Locust', +'Lumpia', +'Luxury vehicle', +'Machine tool', +'Medical imaging', +'Melee weapon', +'Microcontroller', +'Middle ages', +'Military person', +'Military vehicle', +'Milky way', +'Miniature Poodle', +'Modern dance', +'Molluscs', +'Monoplane', +'Motorcycling', +'Musical theatre', +'Narcissus', +'Nest box', +'Newsagent\'s shop', +'Nile crocodile', +'Nordic skiing', +'Nuclear power plant', +'Orator', +'Outdoor shoe', +'Parachuting', +'Pasta salad', +'Peafowl', +'Pelmeni', +'Perching bird', +'Performance car', +'Personal water craft', +'Pit bull', +'Plant stem', +'Pork chop', +'Portrait photography', +'Primate', +'Procyonidae', +'Prosciutto', +'Public speaking', +'Racewalking', +'Ramen', +'Rear-view mirror', +'Residential area', +'Ribs', +'Rice ball', +'Road cycling', +'Roller skating', +'Roman temple', +'Rowing', +'Rural area', +'Sailboat racing', +'Scaled reptile', +'Scuba diving', +'Senior citizen', +'Shallot', +'Shinto shrine', +'Shooting range', +'Siberian husky', +'Sledding', +'Soba', +'Solar energy', +'Sport climbing', +'Sport utility vehicle', +'Steamed rice', +'Stemware', +'Sumo', +'Surfing Equipment', +'Team sport', +'Touring car', +'Toy block', +'Trampolining', +'Underwater diving', +'Vegetarian food', +'Wallaby', +'Water polo', +'Watercolor paint', +'Whiskers', +'Wind wave', +'Woodwind instrument', +'Yakitori', +'Zeppelin'] + + +def build_openset_label_embedding(): + categories = openimages_rare_unseen + model, _ = clip.load("ViT-B/16") + templates = multiple_templates + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + openset_label_embedding = [] + for category in categories: + texts = [ + template.format( + processed_name(category, rm_dot=True), article=article(category) + ) + for template in templates + ] + texts = [ + "This is " + text if text.startswith("a") or text.startswith("the") else text + for text in texts + ] + texts = clip.tokenize(texts) # tokenize + if run_on_gpu: + texts = texts.cuda() + model = model.cuda() + text_embeddings = model.encode_text(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + openset_label_embedding.append(text_embedding) + openset_label_embedding = torch.stack(openset_label_embedding, dim=1) + if run_on_gpu: + openset_label_embedding = openset_label_embedding.cuda() + + openset_label_embedding = openset_label_embedding.t() + return openset_label_embedding, categories + + + + diff --git a/tagging_model.py b/tagging_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a3aa568dab62e2972c6dbc51fe1899c345344e98 --- /dev/null +++ b/tagging_model.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from torchvision.transforms import transforms + +from ram.models import ram + + +class TaggingModule(nn.Module): + def __init__(self, device='cpu'): + super().__init__() + self.device = device + image_size = 384 + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # load RAM Model + self.ram = ram( + pretrained='checkpoints/ram_swin_large_14m.pth', + image_size=image_size, + vit='swin_l' + ).eval().to(device) + print('==> Tagging Module Loaded.') + + @torch.no_grad() + def forward(self, original_image): + print('==> Tagging...') + img = self.transform(original_image).unsqueeze(0).to(self.device) + tags, tags_chinese = self.ram.generate_tag(img) + print('==> Tagging results: {}'.format(tags[0])) + return [tag for tag in tags[0].split(' | ')]