diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f45524f85ad4938600da4d5ee5c720b15923a793
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pth
+*.bin
\ No newline at end of file
diff --git a/README.md b/README.md
index 3525344f2d93b9d5e94b509c915278b141cb767e..ab42a4fafcc332714ac78ffde942de03e091f4e8 100644
--- a/README.md
+++ b/README.md
@@ -9,4 +9,33 @@ app_file: app.py
pinned: false
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+
+
+## Acknowledgement 💌
+
+- [Osprey](https://github.com/CircleRadon/Osprey) and [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA): This repository is built on top of these two projects.
+- [RAISE](http://loki.disi.unitn.it/RAISE/): The Dist. (distorted) images in SEAGULL-100w are constructed based on this dataset.
+- [SAM](https://segment-anything.com/) and [SEEM](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once): The mask-based ROIs are generated using these two awesome works. SAM is also used to produce the segmentation results in the demo.
+- [TOPIQ](https://github.com/chaofengc/IQA-PyTorch): The quality and importance scores for the ROIs are generated using this great FR-IQA method.
+
+
+## Citation 🖊️
+If our work is useful to your research, we would be grateful if you cite our paper:
+```
+@misc{chen2024seagull,
+ title={SEAGULL: No-reference Image Quality Assessment for Regions of Interest via Vision-Language Instruction Tuning},
+ author={Zewen Chen and Juan Wang and Wen Wang and Sunhan Xu and Hang Xiong and Yun Zeng and Jian Guo and Shuxun Wang and Chunfeng Yuan and Bing Li and Weiming Hu},
+ year={2024},
+ eprint={2411.10161},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2411.10161},
+}
+```
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f96affb9084634b9ab9e87bc5aa11511cfd7e0d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,62 @@
+import argparse
+from demo.UI import Main_ui
+
+if __name__ == '__main__':
+    import subprocess
+    import sys
+
+    def run_python_module(command):
+        # Run a Python module (e.g. pip) through the current interpreter.
+        subprocess.check_call([sys.executable, '-m'] + command.split())
+
+    def run_shell(command):
+        # Run the whole line through bash so pipes, `source` and `cd` work;
+        # directory and environment changes only last for this one invocation.
+        subprocess.check_call(command, shell=True, executable='/bin/bash')
+
+    # Install the package in editable mode
+    run_python_module("pip install -e .")
+
+    # Install NVM (Node Version Manager) and Node.js 18.16.0;
+    # nvm must be sourced in the same shell invocation to be visible
+    run_shell("curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh | bash")
+    run_shell("source ~/.nvm/nvm.sh && nvm install v18.16.0")
+
+    # Install pnpm (package manager) and verify the installation
+    run_shell("curl -fsSL https://get.pnpm.io/install.sh | sh -")
+    run_shell("source ~/.bashrc && pnpm --version")
+
+    # Clone the Gradio BBox repository and build its frontend;
+    # the `cd` is chained so it applies to the build step
+    run_shell("git clone https://github.com/chencn2020/gradio-bbox.git")
+    run_shell("cd gradio-bbox && bash scripts/build_frontend.sh")
+
+    # Install the package again in editable mode
+    run_python_module("pip install -e .")
+
+    # Install the Segment Anything repository from GitHub
+    run_python_module("pip install git+https://github.com/facebookresearch/segment-anything.git")
+
+    # Download the SAM ViT-B checkpoint (create ./checkpoints first)
+    run_shell("mkdir -p ./checkpoints && curl -o ./checkpoints/sam_vit_b_01ec64.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")
+
+ parser = argparse.ArgumentParser(description='SEAGULL', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--model', help='path to seagull model', default='Zevin2023/SEAGULL-7B')
+ parser.add_argument('--example_path', help='path to examples', default='./imgs/Examples')
+ args = parser.parse_args()
+
+ demo = Main_ui(args).load_demo()
+ demo.launch(server_port=7530)
\ No newline at end of file
diff --git a/demo/UI.py b/demo/UI.py
new file mode 100644
index 0000000000000000000000000000000000000000..2826c2a5f7295460ec07916de977cd3cd6545c59
--- /dev/null
+++ b/demo/UI.py
@@ -0,0 +1,143 @@
+import os
+import gradio as gr
+from demo.sam_inference import SAM_Inference
+from demo.seagull_inference import Seagull
+from demo.mask_utils import ImageSketcher
+
+class Main_ui():
+ def __init__(self, args) -> None:
+ self.args = args
+ self.seagull = Seagull(model_path=args.model)
+
+ self.example_list = self.load_example()
+ self.sam = SAM_Inference()
+ # self.sam_predictor = get_sam_predictor()
+ # self.mask_generator = get_mask_generator()
+
+ def load_example(self):
+ examples = []
+ for file in sorted(os.listdir(self.args.example_path)):
+ examples.append([os.path.join(self.args.example_path, file)])
+ return examples
+
+ def load_demo(self):
+ with gr.Blocks() as demo:
+ preprocessed_img = gr.State(value=None)
+ binary_mask = gr.State(value=None)
+
+ with gr.Row():
+ gr.Markdown("""
+
+
+ ## 🔔 Usage
+
+ Firstly, upload an image and choose an analysis type **(quality score, importance score or distortion analysis)**.
+
+ Then click a point **(Points)** or draw a box **(BBox)** on the image to indicate the region of interest (ROI).
+
+ After that, the demo runs the following steps:
+
+ > 1. SAM extracts a mask-based ROI from the clicked point or drawn box.
+
+ > 2. Based on the uploaded image and the mask-based ROI, SEAGULL analyses the quality of the ROI.
+
+ """)
+
+ with gr.TabItem("Mask-based ROIs (Points)"):
+ with gr.Row():
+ input_image_ponit = gr.Image(type="numpy", label='Input image', height=512) # input image
+ output_mask_ponit = gr.Image(label='Mask-based ROI', height=512) # output binary mask
+
+ with gr.Row():
+ output_mask_point_on_img = gr.Image(label='Mask on image', height=512) # mask on image for better view
+
+ with gr.Column():
+ radio_point = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
+ output_text_point = gr.Textbox(label='Analysis Results')
+ point_seg_button = gr.Button('Analysis')
+
+ point_example = gr.Dataset(label='Examples', components=[input_image_ponit], samples=self.example_list)
+
+ with gr.TabItem("Mask-based ROIs (BBox)"):
+ with gr.Row():
+ input_image_BBOX = ImageSketcher(type="numpy", label='Input image', height=512)
+ output_mask_BBOX = gr.Image(label='Mask-based ROI', height=512)
+
+ with gr.Row():
+ output_BBOX_mask_on_img = gr.Image(label='Mask on image', height=512)
+
+ with gr.Column():
+ radio_BBOX = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
+ output_text_BBOX = gr.Textbox(label='ROI Quality Analysis')
+ box_seg_button = gr.Button('Generate mask and analysis')
+ box_analyse_button = gr.Button('Analysis')
+
+ BBOX_example = gr.Dataset(label='Examples', components=[input_image_BBOX], samples=self.example_list)
+
+ # click point
+ input_image_ponit.upload(
+ self.seagull.init_image,
+ [input_image_ponit],
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
+ )
+
+ point_example.click(
+ self.seagull.init_image,
+ [point_example],
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
+ )
+
+ # after clicking on the image
+ input_image_ponit.select(
+ self.sam.img_select_point,
+ [preprocessed_img],
+ [input_image_ponit, output_mask_ponit, output_mask_point_on_img, binary_mask]
+ ).then(
+ self.seagull.seagull_predict,
+ [preprocessed_img, binary_mask, radio_point],
+ [output_text_point]
+ )
+
+ point_seg_button.click(
+ self.seagull.seagull_predict,
+ [preprocessed_img, binary_mask, radio_point],
+ [output_text_point]
+ )
+
+ # draw frame
+ input_image_BBOX.upload(
+ self.seagull.init_image,
+ [input_image_BBOX],
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
+ )
+
+ BBOX_example.click(
+ self.seagull.init_image,
+ [BBOX_example],
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
+ )
+
+ # after drawing a frame on the image
+ input_image_BBOX.select(
+ self.sam.gen_box_seg,
+ [input_image_BBOX],
+ [output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
+ )
+
+ box_seg_button.click(
+ self.sam.gen_box_seg,
+ [input_image_BBOX],
+ [output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
+ ).then(
+ self.seagull.seagull_predict,
+ [preprocessed_img, binary_mask, radio_BBOX],
+ [output_text_BBOX]
+ )
+
+ box_analyse_button.click(
+ self.seagull.seagull_predict,
+ [preprocessed_img, binary_mask, radio_BBOX],
+ [output_text_BBOX]
+ )
+
+ return demo
\ No newline at end of file
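The usage text and the event wiring above describe a two-step flow: SAM turns a click or a drawn box into a mask-based ROI, then SEAGULL analyses that ROI. A headless sketch of the same flow using the classes from this repo; the example path and box coordinates are placeholders, and it assumes the SAM checkpoint downloaded by `app.py` plus a CUDA device, as the demo itself does:

```python
import numpy as np

from demo.sam_inference import SAM_Inference
from demo.seagull_inference import Seagull

seagull = Seagull(model_path='Zevin2023/SEAGULL-7B')  # same default as app.py
sam = SAM_Inference()                                 # expects ./checkpoints/sam_vit_b_01ec64.pth

# Load and resize the example exactly as the UI does on upload / example click.
preprocessed, _, _ = seagull.init_image(['./imgs/Examples/1.png'])

# Step 1: SAM turns a drawn box (placeholder x1, y1, x2, y2) into a mask-based ROI.
masks = sam.predict_box(preprocessed, np.array([50, 50, 300, 300]))
mask = masks[0][0]

# Step 2: SEAGULL analyses the ROI; the string matches the radio choices above.
print(seagull.seagull_predict(preprocessed, mask, 'Quality Score'))
```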
diff --git a/demo/__pycache__/UI.cpython-310.pyc b/demo/__pycache__/UI.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..264862da1e1392304018677e3c0d7f47f8a1fedb
Binary files /dev/null and b/demo/__pycache__/UI.cpython-310.pyc differ
diff --git a/demo/__pycache__/mask_utils.cpython-310.pyc b/demo/__pycache__/mask_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f65b53f287fb29600b64d1292916ec71e821d67
Binary files /dev/null and b/demo/__pycache__/mask_utils.cpython-310.pyc differ
diff --git a/demo/__pycache__/sam_inference.cpython-310.pyc b/demo/__pycache__/sam_inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ad4434b1d126e3acb959a0838aed252ed3cdb04
Binary files /dev/null and b/demo/__pycache__/sam_inference.cpython-310.pyc differ
diff --git a/demo/__pycache__/seagull_inference.cpython-310.pyc b/demo/__pycache__/seagull_inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ef54f8821aec386cdae7cb04dd92fd2011cc1d3
Binary files /dev/null and b/demo/__pycache__/seagull_inference.cpython-310.pyc differ
diff --git a/demo/mask_utils.py b/demo/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..54e0744f85efdc17f288e22c1e689ab089f16c2d
--- /dev/null
+++ b/demo/mask_utils.py
@@ -0,0 +1,144 @@
+import cv2
+from PIL import Image
+import numpy as np
+import torch
+import gradio as gr
+
+class ImageSketcher(gr.Image):
+ """
+ Code is from https://github.com/jshilong/GPT4RoI/blob/7c157b5f33914f21cfbc804fb301d3ce06324193/gpt4roi/app.py#L365
+
+ Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
+ """
+
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
+
+ def __init__(self, **kwargs):
+ super().__init__(tool='boxes', **kwargs)
+
+ def preprocess(self, x):
+ if x is None:
+ return x
+ if self.tool == 'boxes' and self.source in ['upload', 'webcam']:
+ if isinstance(x, str):
+ x = {'image': x, 'boxes': []}
+ else:
+ assert isinstance(x, dict)
+ assert isinstance(x['image'], str)
+ assert isinstance(x['boxes'], list)
+ x = super().preprocess(x)
+ return x
+
+def process_mask_to_show(mask):
+ '''
+ Process the mask to show on the gradio.Image
+ '''
+ mask = np.array(mask > 0.1, dtype=np.uint8) * 255
+ mask_stacked = np.stack([mask] * 3, axis=-1)
+
+ return mask_stacked
+
+def img_add_masks(img_, colored_mask, mask, linewidth=2):
+ if type(img_) is np.ndarray:
+ img = Image.fromarray(img_, mode='RGB').convert('RGBA')
+ else:
+ img = img_.copy()
+ h, w = img.height, img.width
+ # contour
+ temp = np.zeros((h, w, 1))
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ cv2.drawContours(temp, contours, -1, (255, 255, 255), linewidth)
+ color = np.array([1, 1, 1, 1])
+ contour_mask = temp * color.reshape(1, 1, -1)
+
+ overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
+ img.paste(overlay_inner, (0, 0), overlay_inner)
+
+ overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
+ img.paste(overlay_contour, (0, 0), overlay_contour)
+ return img
+
+def gen_colored_masks(
+ annotation,
+ random_color=False,
+):
+ """
+ Code is largely based on https://github.com/CASIA-IVA-Lab/FastSAM/blob/4d153e909f0ad9c8ecd7632566e5a24e21cf0071/utils/tools_gradio.py#L130
+ """
+ device = annotation.device
+ mask_sum = annotation.shape[0]
+ height = annotation.shape[1]
+ weight = annotation.shape[2]
+ areas = torch.sum(annotation, dim=(1, 2))
+ sorted_indices = torch.argsort(areas, descending=False)
+ annotation = annotation[sorted_indices]
+
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
+ if random_color:
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
+ else:
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
+ [30 / 255, 144 / 255, 255 / 255]
+ ).to(device)
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
+ visual = torch.cat([color, transparency], dim=-1)
+ mask_image = torch.unsqueeze(annotation, -1) * visual
+
+ mask = torch.zeros((height, weight, 4)).to(device)
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+
+ mask[h_indices, w_indices, :] = mask_image[indices]
+ mask_cpu = mask.cpu().numpy()
+
+ return mask_cpu, sorted_indices
+
+def mask_foreground(mask, trans=0.6, random_color=True):
+ if random_color:
+ color = np.concatenate([np.random.random(3) * 255, np.array([trans * 255])], axis=0)
+ else:
+ color = np.array([30, 144, 255, trans * 255])
+ h, w = mask.shape[-2:]
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+
+ return mask_image
+
+
+def mask_background(mask, trans=0.5):
+ h, w = mask.shape[-2:]
+ mask_image = (1 - mask.reshape(h, w, 1)) * np.array([0, 0, 0, trans * 255])
+
+ return mask_image
+
+
+def mask_select_point(all_masks, output_mask_2_raw, mask_order, evt: gr.SelectData):
+ h, w = output_mask_2_raw.height, output_mask_2_raw.width
+ pointed_mask = None
+ for i in range(len(mask_order)):
+ idx = mask_order[i]
+ msk = all_masks[idx]
+ if msk[evt.index[1], evt.index[0]] == 1:
+ pointed_mask = msk.copy()
+ break
+
+ if pointed_mask is not None:
+        ret = output_mask_2_raw.copy()
+
+        # draw the contour of the selected mask on top of the raw view
+        temp = np.zeros((h, w, 1))
+        contours, _ = cv2.findContours(pointed_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+        cv2.drawContours(temp, contours, -1, (255, 255, 255), 3)
+ color = np.array([1, 1, 1, 1])
+ contour_mask = temp * color.reshape(1, 1, -1)
+
+ colored_mask = mask_background(pointed_mask)
+
+ overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
+ ret.paste(overlay_inner, (0, 0), overlay_inner)
+
+ overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
+ ret.paste(overlay_contour, (0, 0), overlay_contour)
+
+ return ret, pointed_mask
+ else:
+ return output_mask_2_raw, None
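A small usage sketch for the overlay helpers above, on a synthetic image and mask (the sizes and the square ROI are arbitrary):

```python
import numpy as np

from demo.mask_utils import img_add_masks, mask_foreground, process_mask_to_show

img = np.zeros((256, 256, 3), dtype=np.uint8)   # dummy RGB image
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 1                        # square ROI

colored = mask_foreground(mask, trans=0.6)      # (H, W, 4) RGBA tint for the ROI
overlaid = img_add_masks(img, colored, mask)    # PIL RGBA image with the ROI contour drawn
binary_vis = process_mask_to_show(mask)         # 3-channel 0/255 mask for gr.Image

print(overlaid.size, binary_vis.shape)
```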
diff --git a/demo/sam_inference.py b/demo/sam_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ef7e08433105ed93745729f3a2529ad9854501e
--- /dev/null
+++ b/demo/sam_inference.py
@@ -0,0 +1,102 @@
+import gc
+
+import numpy as np
+import torch
+from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+import gradio as gr
+import cv2
+from demo.mask_utils import *
+
+class SAM_Inference:
+ def __init__(self, model_type='vit_b', device='cuda') -> None:
+ models = {
+ 'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
+ 'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
+ 'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
+ }
+
+ sam = sam_model_registry[model_type](checkpoint=models[model_type])
+ sam = sam.to(device)
+
+ self.predictor = SamPredictor(sam)
+ self.mask_generator = SamAutomaticMaskGenerator(model=sam)
+
+ def img_select_point(self, original_img: np.ndarray, evt: gr.SelectData):
+ img = original_img.copy()
+ sel_pix = [(evt.index, 1)] # append the foreground_point
+
+ masks = self.run_inference(original_img, sel_pix)
+ for point, label in sel_pix:
+ cv2.circle(img, point, 5, (240, 240, 240), -1, 0)
+ cv2.circle(img, point, 5, (30, 144, 255), 2, 0)
+
+ mask = masks[0][0]
+ colored_mask = mask_foreground(mask)
+ res = img_add_masks(original_img, colored_mask, mask)
+ return img, process_mask_to_show(mask), res, mask
+
+ def gen_box_seg(self, inp):
+ if inp is None:
+ raise gr.Error("Please upload an image first!")
+ image = inp['image']
+ if len(inp['boxes']) == 0:
+            raise gr.Error("Please draw a box on the image first!")
+ boxes = inp['boxes'][-1]
+
+ input_box = np.array([boxes[0], boxes[1], boxes[2], boxes[3]]).astype(int)
+
+ masks = self.predict_box(image, input_box)
+
+ mask = masks[0][0]
+ colored_mask = mask_foreground(mask)
+ res = img_add_masks(image, colored_mask, mask)
+
+ return process_mask_to_show(mask), res, mask
+
+ def run_inference(self, input_x, selected_points):
+ if len(selected_points) == 0:
+ return []
+
+ self.predictor.set_image(input_x)
+
+ points = torch.Tensor(
+ [p for p, _ in selected_points]
+ ).to(self.predictor.device).unsqueeze(0)
+
+ labels = torch.Tensor(
+ [int(l) for _, l in selected_points]
+ ).to(self.predictor.device).unsqueeze(0)
+
+ transformed_points = self.predictor.transform.apply_coords_torch(
+ points, input_x.shape[:2])
+
+        # predict the segmentation mask from the clicked points
+ masks, scores, logits = self.predictor.predict_torch(
+ point_coords=transformed_points,
+ point_labels=labels,
+ multimask_output=False,
+ )
+ masks = masks.cpu().detach().numpy()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return masks
+
+ def predict_box(self, input_x, input_box):
+ self.predictor.set_image(input_x)
+
+ input_boxes = torch.tensor(input_box[None, :], device=self.predictor.device)
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])
+
+ masks, _, _ = self.predictor.predict_torch(
+ point_coords=None,
+ point_labels=None,
+ boxes=transformed_boxes,
+ multimask_output=False
+ )
+ masks = masks.cpu().detach().numpy()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ return masks
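`run_inference` above is the point path used by the UI: it expects a list of `(point, label)` pairs, with `point` as `(x, y)` pixel coordinates and label `1` for a foreground click. A minimal sketch (the image path and click location are placeholders):

```python
import cv2

from demo.sam_inference import SAM_Inference

sam = SAM_Inference()                                       # loads ./checkpoints/sam_vit_b_01ec64.pth
image = cv2.imread('./imgs/Examples/1.png')[:, :, ::-1].copy()  # BGR -> RGB, as the demo does

# One foreground click at (x=200, y=150).
masks = sam.run_inference(image, [((200, 150), 1)])
mask = masks[0][0]                                          # boolean H x W mask of the clicked region

print(mask.shape, mask.dtype)
```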
diff --git a/demo/seagull_inference.py b/demo/seagull_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7cac2376869a98f40a7f1857e91199a3151874d
--- /dev/null
+++ b/demo/seagull_inference.py
@@ -0,0 +1,163 @@
+import torch
+from seagull.utils import disable_torch_init
+from transformers import AutoTokenizer, CLIPImageProcessor
+from seagull.model.language_model.seagull_llama import SeagullLlamaForCausalLM
+from seagull.mm_utils import tokenizer_image_token
+from seagull.conversation import conv_templates, SeparatorStyle
+from seagull.constants import IMAGE_TOKEN_INDEX
+from seagull.train.train import DataArguments
+
+from functools import partial
+import os
+import numpy as np
+import cv2
+from typing import List
+from PIL import Image
+
+class Seagull():
+ def __init__(self, model_path, device='cuda'):
+ disable_torch_init()
+ model_path = os.path.expanduser(model_path)
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=2048, padding_side="right", use_fast=True)
+ self.model = SeagullLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16,).to(device)
+ self.tokenizer.pad_token = self.tokenizer.unk_token
+
+ self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
+ do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
+
+        spi_tokens = ['<mask>', '<pos>']  # special tokens for the mask-based ROI features
+ self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
+
+ for m in self.model.modules():
+ m.tokenizer = self.tokenizer
+
+ vision_tower = self.model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ vision_tower.load_model()
+ vision_tower.to(dtype=torch.float16, device=device)
+
+        begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <mask><pos>. "
+
+ instruction = {
+ 'distortion analysis': 'Provide the distortion type of this region.',
+ 'quality score': 'Analyze the quality of this region.',
+ 'importance score': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
+ }
+
+ self.ids_input = {}
+ for ins_type, ins in instruction.items():
+ conv = conv_templates['seagull_v1'].copy()
+ qs = begin_str + ins
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ self.ids_input[ins_type] = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
+
+ self.stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+
+ def init_image(self, img):
+ if isinstance(img, dict):
+ img = img['image']
+ elif isinstance(img, List):
+ img = cv2.imread(img[0])
+ img = img[:, :, ::-1]
+ h_, w_ = img.shape[:2]
+ if h_ > 512:
+ ratio = 512 / h_
+ new_h, new_w = int(h_ * ratio), int(w_ * ratio)
+ preprocessed_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+ else:
+ preprocessed_img = img.copy()
+
+ return (preprocessed_img, preprocessed_img, preprocessed_img)
+
+ def preprocess(self, img):
+ image = self.image_processor.preprocess(img,
+ do_center_crop=False,
+ return_tensors='pt')['pixel_values'][0]
+
+ image = torch.nn.functional.interpolate(image.unsqueeze(0),
+ size=(512, 512),
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+
+ return image
+
+ def seagull_predict(self, img, mask, instruct_type):
+ image = self.preprocess(img)
+
+        mask = np.array(mask, dtype=np.uint8)  # np.int was removed from NumPy; uint8 also keeps cv2.resize happy
+ ys, xs = np.where(mask > 0)
+ if len(xs) > 0 and len(ys) > 0:
+ # Find the minimal bounding rectangle for the entire mask
+ x_min, x_max = np.min(xs), np.max(xs)
+ y_min, y_max = np.min(ys), np.max(ys)
+ w1 = x_max - x_min
+ h1 = y_max - y_min
+
+ bounding_box = (x_min, y_min, w1, h1)
+ else:
+ bounding_box = None
+
+ mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+ mask = np.array(mask > 0.1, dtype=np.uint8)
+ masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
+
+ input_ids = self.ids_input[instruct_type.lower()]
+
+        if bounding_box is None:
+            # empty mask: fall back to the full image so the crop below does not fail
+            bounding_box = (0, 0, img.shape[1], img.shape[0])
+        x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
+ cropped_img = img[y1:y1 + h1, x1:x1 + w1]
+ cropped_img = Image.fromarray(cropped_img)
+ cropped_img = self.preprocess(cropped_img)
+
+ with torch.inference_mode():
+
+ self.model.orig_forward = self.model.forward
+ self.model.forward = partial(self.model.orig_forward,
+ img_metas=[None],
+ masks=[masks.half()],
+ cropped_img=cropped_img.unsqueeze(0)
+ )
+ output_ids = self.model.generate(
+ input_ids,
+ images=image.unsqueeze(0).half().to(self.model.device),
+ do_sample=False,
+ temperature=1,
+ max_new_tokens=2048,
+ use_cache=True,
+ num_beams=1,
+                top_k=0,  # disable top-k sampling
+                top_p=1,  # keep the full cumulative probability (no nucleus filtering)
+ )
+
+ self.model.forward = self.model.orig_forward
+
+ input_token_len = input_ids.shape[1]
+ n_diff_input_output = (
+ input_ids != output_ids[:, :input_token_len]).sum().item()
+ if n_diff_input_output > 0:
+ print(
+ f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+ outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
+ skip_special_tokens=True)[0]
+
+ outputs = outputs.strip()
+ if outputs.endswith(self.stop_str):
+ outputs = outputs[:-len(self.stop_str)]
+ outputs = outputs.strip()
+ if ':' in outputs:
+ outputs = outputs.split(':')[1]
+
+ outputs_list = outputs.split('.')
+ outputs_list_final = []
+ outputs_str = ''
+ for output in outputs_list:
+ if output not in outputs_list_final:
+ if output=='':
+ continue
+ outputs_list_final.append(output)
+ outputs_str+=output+'.'
+ else:
+ break
+ return outputs_str
\ No newline at end of file
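`seagull_predict` above crops the image to the tightest bounding rectangle of the mask's non-zero pixels before feeding it to the model. The same computation in isolation, on a synthetic mask:

```python
import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 150:300] = 1                     # synthetic ROI

ys, xs = np.where(mask > 0)
x_min, x_max = xs.min(), xs.max()
y_min, y_max = ys.min(), ys.max()
bounding_box = (x_min, y_min, x_max - x_min, y_max - y_min)  # (x, y, w, h)

print(bounding_box)  # (150, 100, 149, 99)
```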
diff --git a/imgs/.DS_Store b/imgs/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..1e34149506e88262cee995f941024694f014652f
Binary files /dev/null and b/imgs/.DS_Store differ
diff --git a/imgs/Examples/1.png b/imgs/Examples/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e744e2810370a36f09a6c3fd1b10c089ee5fb904
Binary files /dev/null and b/imgs/Examples/1.png differ
diff --git a/imgs/Examples/2.png b/imgs/Examples/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..427f107e2ef0b67ec4f71e1fe0813f6c77565aa7
Binary files /dev/null and b/imgs/Examples/2.png differ
diff --git a/seagull/__init__.py b/seagull/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9117d4278fcd90b028ca3afe75935e7813810637
--- /dev/null
+++ b/seagull/__init__.py
@@ -0,0 +1 @@
+from .model import SeagullLlamaForCausalLM
diff --git a/seagull/__pycache__/__init__.cpython-310.pyc b/seagull/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6888a3826bf029a2bf99132b38e4306dda328af8
Binary files /dev/null and b/seagull/__pycache__/__init__.cpython-310.pyc differ
diff --git a/seagull/__pycache__/constants.cpython-310.pyc b/seagull/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14ac8b1a6edcd2c2d2f5932c52d39a32039b7c3e
Binary files /dev/null and b/seagull/__pycache__/constants.cpython-310.pyc differ
diff --git a/seagull/__pycache__/conversation.cpython-310.pyc b/seagull/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e1a5a99fca7d03b73349ed49ceaac7935448cbc
Binary files /dev/null and b/seagull/__pycache__/conversation.cpython-310.pyc differ
diff --git a/seagull/__pycache__/mm_utils.cpython-310.pyc b/seagull/__pycache__/mm_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c09ca03a51bc5e12172f5f2c3d16ac55fa155d2
Binary files /dev/null and b/seagull/__pycache__/mm_utils.cpython-310.pyc differ
diff --git a/seagull/__pycache__/utils.cpython-310.pyc b/seagull/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c4a294b7ee246db285cfb5635deee7ca4dc6095
Binary files /dev/null and b/seagull/__pycache__/utils.cpython-310.pyc differ
diff --git a/seagull/builder.py b/seagull/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd303e01651b3c3d6c43b65a44d805334af6afd7
--- /dev/null
+++ b/seagull/builder.py
@@ -0,0 +1,171 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from seagull.model import *
+from seagull.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
+ kwargs = {"device_map": device_map, **kwargs}
+
+ if device != "cuda":
+ kwargs['device_map'] = {"": device}
+
+ if load_8bit:
+ kwargs['load_in_8bit'] = True
+ elif load_4bit:
+ kwargs['load_in_4bit'] = True
+ kwargs['quantization_config'] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4'
+ )
+ else:
+ kwargs['torch_dtype'] = torch.float16
+
+ if use_flash_attn:
+ kwargs['attn_implementation'] = 'flash_attention_2'
+
+ if 'seagull' in model_name.lower() or True:
+ # Load LLaVA model
+ if 'lora' in model_name.lower() and model_base is None:
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+ if 'lora' in model_name.lower() and model_base is not None or True:
+ from seagull.model.language_model.seagull_llama import SeagullConfig
+ lora_cfg_pretrained = SeagullConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ print('Loading LLaVA from base model...')
+ model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
+ if model.lm_head.weight.shape[0] != token_num:
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+
+ print('Loading additional LLaVA weights...')
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder)
+ return torch.load(cache_file, map_location='cpu')
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+
+ for k, v in non_lora_trainables.items():
+ print(k)
+ print('print non lora')
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+ print('Loading LoRA weights...')
+ model = PeftModel.from_pretrained(model, model_path)
+ print('Merging LoRA weights...')
+ model = model.merge_and_unload()
+ print('Model is loaded...')
+ elif model_base is not None:
+ # this may be mm projector only
+ print('Loading LLaVA from base model...')
+ if 'mpt' in model_name.lower():
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ model = SeagullMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+ model.load_state_dict(mm_projector_weights, strict=False)
+ else:
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = SeagullMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ elif 'mistral' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = SeagullMistralForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = SeagullLlamaForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print('Convert to FP16...')
+ model.to(torch.float16)
+ else:
+ use_fast = False
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+ image_processor = None
+
+ if 'seagull' in model_name.lower() or True:
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ vision_tower = model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ vision_tower.load_model(device_map=device_map)
+ if device_map != 'auto':
+ vision_tower.to(device=device_map, dtype=torch.float16)
+ image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
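For reference, a loading sketch for the builder above. The demo itself constructs `SeagullLlamaForCausalLM` directly (see `demo/seagull_inference.py`), so this helper is mainly for offline scripts; the checkpoint directory and base model below are placeholders, and `model_base` is only needed for unmerged LoRA weights:

```python
from seagull.builder import load_pretrained_model
from seagull.mm_utils import get_model_name_from_path

model_path = './checkpoints/seagull-7b-lora'   # placeholder LoRA checkpoint directory
model_base = 'lmsys/vicuna-7b-v1.5'            # placeholder base LLM for the LoRA weights

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base,
    get_model_name_from_path(model_path),
    load_8bit=False,
    load_4bit=False,
)
print(type(model).__name__, context_len)
```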
diff --git a/seagull/constants.py b/seagull/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..be8cf0204969a6c973f442b383d8e425d684e826
--- /dev/null
+++ b/seagull/constants.py
@@ -0,0 +1,12 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
diff --git a/seagull/conversation.py b/seagull/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f246ac43dc5a2013b2ded2963fde3d2de6f30d07
--- /dev/null
+++ b/seagull/conversation.py
@@ -0,0 +1,381 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+ if 'mmtag' in self.version:
+ messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+ messages.insert(1, (self.roles[1], "Received."))
+ else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if longest_edge != max(image.size):
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<image>', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_seagull_llama_2 = Conversation(
+ system="You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_seagull_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_seagull_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_seagull_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_seagull_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_seagull_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+ version="v1_mmtag",
+)
+
+default_conversation = conv_vicuna_v0
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "llama_2": conv_llama_2,
+
+ "plain": conv_seagull_plain,
+ "v0_plain": conv_seagull_plain,
+ "seagull_v0": conv_seagull_v0,
+ "v0_mmtag": conv_seagull_v0_mmtag,
+ "seagull_v1": conv_seagull_v1,
+ "v1_mmtag": conv_seagull_v1_mmtag,
+ "seagull_llama_2": conv_seagull_llama_2,
+
+ "mpt": conv_mpt,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/seagull/mm_utils.py b/seagull/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d98b5f1d9ecdbe379e4546e270088ebbfd98b053
--- /dev/null
+++ b/seagull/mm_utils.py
@@ -0,0 +1,95 @@
+from PIL import Image
+from io import BytesIO
+import base64
+
+import torch
+from transformers import StoppingCriteria
+from seagull.constants import IMAGE_TOKEN_INDEX
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
+ offset = min(output_ids.shape[1] - self.start_len, 3)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
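A sketch of how the two helpers above are combined at inference time; the tokenizer checkpoint is a placeholder and the prompt is illustrative:

```python
from transformers import AutoTokenizer

from seagull.constants import IMAGE_TOKEN_INDEX
from seagull.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')  # placeholder checkpoint

prompt = 'USER: <image>\nDescribe the region. ASSISTANT:'
# '<image>' is replaced by IMAGE_TOKEN_INDEX (-200) so the model can splice in visual features.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)

# Stop generation once the separator string appears in the decoded output.
stopping = KeywordsStoppingCriteria(['</s>'], tokenizer, input_ids)
print(input_ids.shape)
```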
diff --git a/seagull/model/__init__.py b/seagull/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0791450c706cfca822d5bd7afc233ccffabf2555
--- /dev/null
+++ b/seagull/model/__init__.py
@@ -0,0 +1 @@
+from .language_model.seagull_llama import SeagullLlamaForCausalLM, SeagullConfig
diff --git a/seagull/model/__pycache__/Q_A.cpython-310.pyc b/seagull/model/__pycache__/Q_A.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89f9e59764f357e9c8bb5d50f73e245b68f9d2c8
Binary files /dev/null and b/seagull/model/__pycache__/Q_A.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc b/seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0235dd81ac448a00a58eaa10a22d765abb13df2
Binary files /dev/null and b/seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc b/seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..833ba8213254fffb99b267e44168ad5f53537dee
Binary files /dev/null and b/seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc b/seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ac642cfad7163a3400e2a008ee146beb2c49489
Binary files /dev/null and b/seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/__init__.cpython-310.pyc b/seagull/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2f7b511f41154c85768bfe53411a327c68c0e10
Binary files /dev/null and b/seagull/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/layer.cpython-310.pyc b/seagull/model/__pycache__/layer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48ee84195dc28f2c60d889ff81815724ad66af8d
Binary files /dev/null and b/seagull/model/__pycache__/layer.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/layer_osprey.cpython-310.pyc b/seagull/model/__pycache__/layer_osprey.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c7e079642977decc63c435b880dd710d1151c28
Binary files /dev/null and b/seagull/model/__pycache__/layer_osprey.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/osprey_arch.cpython-310.pyc b/seagull/model/__pycache__/osprey_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..445073450dfe8a8ecaa4de82ae2a00efd3fe0436
Binary files /dev/null and b/seagull/model/__pycache__/osprey_arch.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/seagull_arch.cpython-310.pyc b/seagull/model/__pycache__/seagull_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a93d56549569a24058f2b8d2d5dfccadb488851
Binary files /dev/null and b/seagull/model/__pycache__/seagull_arch.cpython-310.pyc differ
diff --git a/seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc b/seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d9f2c608f0c5c16fc8c52fecdb831c390d16e8b
Binary files /dev/null and b/seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc differ
diff --git a/seagull/model/consolidate.py b/seagull/model/consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab05e6b9639378c1d110fd1a937a1d5d9d9cbd49
--- /dev/null
+++ b/seagull/model/consolidate.py
@@ -0,0 +1,26 @@
+
+import argparse
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from seagull.model import *
+from seagull.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc b/seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2d10eaeefb2c41cd357c1537a7576ab4236917c
Binary files /dev/null and b/seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc differ
diff --git a/seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc b/seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abaf71e73c8ae9a618368547ed00a0a69bed7274
Binary files /dev/null and b/seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc differ
diff --git a/seagull/model/language_model/seagull_llama.py b/seagull/model/language_model/seagull_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc420fbf43232afa7875d5cba64eb284f0336b3
--- /dev/null
+++ b/seagull/model/language_model/seagull_llama.py
@@ -0,0 +1,128 @@
+from typing import List, Optional, Tuple, Union
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ LlamaConfig, LlamaModel, LlamaForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from ..seagull_arch import SeagullMetaModel, SeagullMetaForCausalLM
+from ..layer import MaskExtractor
+
+class SeagullConfig(LlamaConfig):
+ model_type = "seagull"
+
+class SeagullLlamaModel(SeagullMetaModel, LlamaModel):
+ config_class = SeagullConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(SeagullLlamaModel, self).__init__(config)
+
+class SeagullLlamaForCausalLM(LlamaForCausalLM, SeagullMetaForCausalLM):
+ config_class = SeagullConfig
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = SeagullLlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.mask_extractor = MaskExtractor()
+
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ img_metas = None,
+ masks = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ preprocessed_img_dict = None,
+ return_dict: Optional[bool] = None,
+ cropped_img: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
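+        # `masks` and `cropped_img` carry the per-image ROI inputs consumed by
+        # the MaskExtractor inside prepare_inputs_labels_for_multimodal below.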
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=preprocessed_img_dict, cropped_img=cropped_img)
+
+ if inputs_embeds is not None:
+ inputs_embeds = inputs_embeds.bfloat16()
+
+ self.model = self.model.bfloat16()
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ hidden_states = outputs[0]
+ self.lm_head = self.lm_head.to(hidden_states.dtype)
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "images": kwargs.get("images", None),
+ }
+ )
+ return model_inputs
+
+AutoConfig.register("seagull", SeagullConfig)
+AutoModelForCausalLM.register(SeagullConfig, SeagullLlamaForCausalLM)
diff --git a/seagull/model/layer.py b/seagull/model/layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..60f86799a96deb3a59766c89f6e6f53ec4270520
--- /dev/null
+++ b/seagull/model/layer.py
@@ -0,0 +1,250 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Type, Any
+from torch import Tensor
+import math
+import numpy as np
+from einops import rearrange
+
+class MLP(nn.Module):
+
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
+ num_layers: int) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class MaskExtractor(nn.Module): # Mask-based Feature Extractor
+ def __init__(self, mask_shape=112, embed_dim=1024, out_dim=4096, num_heads=8, mlp_dim=2048, downsample_rate=2, skip_first_layer_pe=False):
+ super(MaskExtractor, self).__init__()
+ self.mask_shape = mask_shape
+ self.mask_pooling = MaskPooling()
+ self.feat_linear = nn.Linear(embed_dim, out_dim)
+ self.cross_feat_linear = nn.Linear(embed_dim, out_dim)
+ self.mask_linear = MLP(mask_shape*mask_shape, embed_dim, out_dim, 3)
+
+ self.feature_name = ['res2', 'res3', 'res4', 'res5']
+
+ self.cross_att_res = CrossAttention(
+ embedding_dim=embed_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+            downsample_rate=downsample_rate,
+ skip_first_layer_pe=skip_first_layer_pe
+ )
+
+ self.res2 = nn.Linear(192, 1024)
+ self.res3 = nn.Linear(384, 1024)
+ self.res4 = nn.Linear(768, 1024)
+ self.res5 = nn.Linear(1536, 1024)
+
+ self.g_res2 = nn.Linear(16384, 1024) # h * w
+ self.g_res3 = nn.Linear(4096, 1024)
+ self.g_res4 = nn.Linear(1024, 1024)
+ self.g_res5 = nn.Linear(256, 1024)
+
+ self.final_mlp = nn.Linear(2 * out_dim, out_dim)
+
+ self.global_vit = nn.Sequential(
+ nn.Conv2d(3, 5, 1),
+ nn.GELU(),
+ nn.AvgPool2d(4, 4),
+
+ nn.Conv2d(5, 1, 1),
+ nn.GELU(),
+ nn.AvgPool2d(4, 4),
+ )
+ self.is_first = 0
+
+ self.sa = Attention(32 * 32, num_heads) # self-attention
+ self.mlp = MLP(32 * 32, 512, out_dim, 3)
+
+    def cal_global_local(self, mask_feat_raw, feat_new, res, g_res, cross_attention):
+ mask_feat_flatten = mask_feat_raw.to(device=res.weight.device, dtype=res.weight.dtype)
+ mask_feat = res(mask_feat_flatten) # (b, q, 1024)
+
+ feat_new = feat_new.to(device=g_res.weight.device, dtype=g_res.weight.dtype)
+ all_feat_new = g_res(feat_new) # (b, c, 1024)
+ global_mask = cross_attention(mask_feat, all_feat_new)
+ return mask_feat, global_mask
+
+ def forward(self, feats, masks, cropped_img):
+ global_features = []
+ local_features = []
+ num_imgs = len(masks)
+
+ for idx in range(num_imgs):
+ mask = masks[idx].unsqueeze(0).float() #(1, q, h, w)
+ cropped_ = cropped_img[idx] # (q, 3, h, w)
+
+ num_feats = len(self.feature_name)
+ mask_feats = mask.new_zeros(num_feats, mask.shape[1], 1024)
+ global_masks = mask.new_zeros(num_feats, mask.shape[1], 1024)
+
+ for i, name in enumerate(self.feature_name):
+ feat = feats[name][idx].unsqueeze(0)
+ feat = feat.to(mask.dtype)
+
+ mask_feat_raw = self.mask_pooling(feat, mask)
+ feat_new = rearrange(feat, 'b c h w -> b c (h w)')
+
+                mask_feat, global_mask = self.cal_global_local(mask_feat_raw, feat_new, res=getattr(self, name), g_res=getattr(self, 'g_{}'.format(name)), cross_attention=getattr(self, "cross_att_res"))
+
+ mask_feats[i] = mask_feat.squeeze(0) # (q, 1024)
+ global_masks[i] = global_mask.squeeze(0)
+            mask_feats = mask_feats.sum(0) # (q, 1024)
+            global_masks = global_masks.sum(0) # (q, 1024)
+            global_masks = global_masks.to(device=self.cross_feat_linear.weight.device, dtype=self.cross_feat_linear.weight.dtype)
+            global_masks_linear = self.cross_feat_linear(global_masks)
+            mask_feats = mask_feats.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
+            mask_feats_linear = self.feat_linear(mask_feats) # (q, 4096)
+
+ query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
+ global_features.append(query_feat) # global
+
+            cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
+            cropped_feat = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # (q, 1, 32, 32)
+            cropped_feat = cropped_feat.reshape(-1, 1, 32 * 32) # (q, 1, 32 * 32)
+            pos_feat = self.mlp(self.sa(cropped_feat, cropped_feat, cropped_feat).squeeze(1)) # (q, out_dim)
+
+            local_features.append(pos_feat) # local: one (q, out_dim) tensor per image
+
+ return global_features, local_features
+
+class MaskPooling(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, mask):
+
+        if x.shape[-2:] != mask.shape[-2:]:
+ # reshape mask to x
+ mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
+
+ mask = (mask > 0).to(mask.dtype)
+ denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
+
+ mask_pooled_x = torch.einsum(
+ "bchw,bqhw->bqc",
+ x,
+ mask / denorm,
+ )
+ return mask_pooled_x
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+        downsample_rate: int = 2,
+ activation: Type[nn.Module] = nn.ReLU,
+ skip_first_layer_pe: bool = False
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+ self.self_attn = Attention(embedding_dim, num_heads) # self-attention
+ self.skip_first_layer_pe = skip_first_layer_pe
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ # cross-attention
+        self.cross_attn = Attention(embedding_dim, num_heads, downsample_rate=downsample_rate)
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) # MLP
+
+ def forward(self, queries, keys):
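+        # queries: per-ROI pooled features; keys: projected global image features
+        # (as passed in by MaskExtractor above). Self-attention over the queries,
+        # then cross-attention into the keys (residual + LayerNorm each), then a
+        # residual MLP.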
+ attn_out = self.self_attn(queries, queries, queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ attn_out = self.cross_attn(q=queries, k=keys, v=keys)
+ queries = attn_out + queries
+ queries = self.norm2(queries)
+
+ # MLP
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ return queries
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
diff --git a/seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc b/seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34083dd640ef4f14b99e253437adbe2544998576
Binary files /dev/null and b/seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc differ
diff --git a/seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc b/seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..539e773e29035b0519df2973d8de48d5a5fde856
Binary files /dev/null and b/seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc differ
diff --git a/seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc b/seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdca9e1091ab685f2d1c1e51006ca220264144ae
Binary files /dev/null and b/seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc differ
diff --git a/seagull/model/multimodal_encoder/builder.py b/seagull/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff45da3a08c6597b233824491c7e8a0ffee01e1
--- /dev/null
+++ b/seagull/model/multimodal_encoder/builder.py
@@ -0,0 +1,7 @@
+import os
+from .clip_encoder import CLIPVisionTower
+
+
+def build_vision_tower(vision_tower_cfg, delay_load=False):
+
+ return CLIPVisionTower(args=vision_tower_cfg)
diff --git a/seagull/model/multimodal_encoder/clip.py b/seagull/model/multimodal_encoder/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..69431f1fd08cc9721c44e45629dbe4153af90fa8
--- /dev/null
+++ b/seagull/model/multimodal_encoder/clip.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+from open_clip.model import _build_vision_tower
+
+
+class CLIP(nn.Module):
+ def __init__(self):
+ super().__init__()
+ model_name = 'convnext_large'
+
+ vision_cfg = {'timm_model_name': model_name, 'timm_model_pretrained': False, 'timm_pool': '', 'timm_proj': 'mlp', 'timm_drop': 0.0, 'timm_drop_path': 0.1, 'image_size': 320}
+ self.visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg, quick_gelu=False)
+
+ self.eval()
+ self.freeze_everything()
+
+ def freeze_everything(self):
+ for param in self.visual.parameters():
+ param.requires_grad = False
+
+ def extract_features(self, x):
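+        # Run the ConvNeXt trunk stage by stage, keeping each intermediate map
+        # ('res2'..'res5') so that multi-scale features can be pooled later.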
+ out = {}
+ x = x.to(self.visual.trunk.stem.state_dict()['1.bias'].dtype)
+ x = self.visual.trunk.stem(x)
+ out['stem'] = x.contiguous()
+ for i in range(4):
+ x = self.visual.trunk.stages[i](x)
+ out[f'res{i+2}'] = x.contiguous()
+
+ x = self.visual.trunk.norm_pre(x)
+ out['clip_vis_dense'] = x.contiguous()
+ return out
+
+ def forward(self, x):
+ self.eval()
+ with torch.no_grad():
+ return self.extract_features(x)
+
\ No newline at end of file
diff --git a/seagull/model/multimodal_encoder/clip_encoder.py b/seagull/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3966ca3c231723b4604f247d4317cfc9048057
--- /dev/null
+++ b/seagull/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPImageProcessor
+from .clip import CLIP
+
+class CLIPVisionTower(nn.Module):
+ def __init__(self, args, img_size=512, delay_load=False):
+ super().__init__()
+
+ # test
+ if hasattr(args, 'mm_vision_tower'):
+ self.clip_model = args.mm_vision_tower
+ else: # train
+ self.clip_model = args.vision_tower
+ self.is_loaded = False
+ self.img_size = img_size
+
+ if not delay_load:
+ self.load_model()
+
+ def load_model(self):
+ self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":self.img_size}, resample=3, do_center_crop=True, crop_size={"height": self.img_size, "width": self.img_size},
+ do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
+
+ self.vision_tower = CLIP()
+
+ self.vision_tower.load_state_dict(torch.load(self.clip_model),strict=False)
+
+ self.is_loaded = True
+
+ @torch.no_grad()
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ image_features_dict = []
+ for image in images:
+ image_feature_dict = self.vision_tower(image.unsqueeze(0))
+ image_features_dict.append(image_feature_dict)
+ image_feature = image_feature_dict['res4']
+ image_feature = image_feature.reshape(*image_feature.shape[:2],-1).permute(0,2,1)
+ image_features.append(image_feature)
+ else:
+ # print(images.device)
+ # print(self.vision_tower.device)
+ image_features_dict = self.vision_tower(images)
+ image_features = image_features_dict['res4']
+ image_features = image_features.reshape(*image_features.shape[:2],-1).permute(0,2,1)
+
+ return image_features, image_features_dict
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
diff --git a/seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc b/seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82538a63fdc81052110b8b6a46ff641de9235f32
Binary files /dev/null and b/seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc differ
diff --git a/seagull/model/multimodal_projector/builder.py b/seagull/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e7c2b078c05c9f87d040501ec93e2be221630a0
--- /dev/null
+++ b/seagull/model/multimodal_projector/builder.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import re
+
+
+class IdentityMap(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": 'identity'}
+
+
+class SimpleResBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.Sequential(
+ nn.Linear(channels, channels),
+ nn.GELU(),
+ nn.Linear(channels, channels)
+ )
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+
+def build_vision_projector(config, delay_load=False, **kwargs):
+ mm_hidden_size = getattr(config, 'mm_hidden_size', 768)
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
+
+ if projector_type == 'linear':
+ return nn.Linear(mm_hidden_size, config.hidden_size)
+
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
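+    # e.g. 'mlp2x_gelu' builds Linear(mm_hidden_size, hidden_size) -> GELU -> Linear(hidden_size, hidden_size)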
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ return nn.Sequential(*modules)
+
+ if projector_type == 'identity':
+ return IdentityMap()
+
+ raise ValueError(f'Unknown projector type: {projector_type}')
diff --git a/seagull/model/seagull_arch.py b/seagull/model/seagull_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa57f25891eef840b6f614f398ca854bed39428e
--- /dev/null
+++ b/seagull/model/seagull_arch.py
@@ -0,0 +1,281 @@
+from abc import ABC, abstractmethod
+
+import torch
+
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_projector.builder import build_vision_projector
+
+from seagull.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+
+class SeagullMetaModel:
+
+ def __init__(self, config):
+ super(SeagullMetaModel, self).__init__(config)
+
+ if hasattr(config, "mm_vision_tower"):
+ self.vision_tower = build_vision_tower(config, delay_load=False)
+ self.mm_projector = build_vision_projector(config)
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, 'vision_tower', None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ def initialize_vision_modules(self, model_args, fsdp=None):
+
+ vision_tower = model_args.vision_tower
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+
+ if not hasattr(self.config, "mm_vision_tower"):
+ self.config.mm_vision_tower = vision_tower
+
+ vision_tower = build_vision_tower(model_args)
+
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [self.vision_tower]
+ else:
+ self.vision_tower = vision_tower
+
+ self.config.use_mm_proj = True
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
+
+ self.mm_projector = build_vision_projector(self.config)
+
+ if pretrain_mm_mlp_adapter is not None:
+ print("***********load projector_weights********")
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+ def get_w(weights, keyword):
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
+
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
+
+
+
+class SeagullMetaForCausalLM(ABC):
+ def __init__(self):
+ super(SeagullMetaForCausalLM, self).__init__()
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ def encode_images(self, images):
+ image_features, image_features_dict = self.get_model().get_vision_tower()(images)
+ self.get_model().mm_projector.to(device=image_features.device, dtype=image_features.dtype)
+ image_features = self.get_model().mm_projector(image_features)
+ return image_features, image_features_dict
+
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=None, cropped_img=None
+ ):
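+        # Builds inputs_embeds directly: every IMAGE_TOKEN_INDEX is replaced by
+        # projected image features and every <mask>/<pos> token by the matching
+        # ROI features from the mask extractor; input_ids is returned as None on
+        # the multimodal path.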
+ vision_tower = self.get_vision_tower()
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
+ return input_ids, attention_mask, past_key_values, None, labels
+
+ if preprocessed_img_dict is not None:
+ image_features, image_features_dict = images, preprocessed_img_dict
+ else:
+ if type(images) is list or images.ndim == 5:
+ concat_images = torch.cat([image for image in images], dim=0)
+ image_features, image_features_dict = self.encode_images(concat_images)
+ split_sizes = [image.shape[0] for image in images]
+ image_features = torch.split(image_features, split_sizes, dim=0)
+ image_features = [x.flatten(0, 1).to(concat_images.device) for x in image_features]
+ else:
+ image_features, image_features_dict = self.encode_images(images)
+
+
+ mask_feats, pos_feats = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
+
+ new_input_embeds = []
+ new_labels = [] if labels is not None else None
+ cur_image_idx = 0
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
+ # multimodal LLM, but the current sample is not multimodal
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
+ half_len = cur_input_ids.shape[0] // 2
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ if labels is not None:
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+ cur_new_input_embeds = []
+ if labels is not None:
+ cur_labels = labels[batch_idx]
+ cur_new_labels = []
+ assert cur_labels.shape == cur_input_ids.shape
+ while image_token_indices.numel() > 0:
+ cur_image_features = image_features[cur_image_idx]
+ image_token_start = image_token_indices[0]
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
+ if labels is not None:
+ cur_new_labels.append(cur_labels[:image_token_start])
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
+ cur_labels = cur_labels[image_token_start+2:]
+ else:
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
+ cur_new_input_embeds.append(cur_image_features)
+ if labels is not None:
+ cur_new_labels.append(cur_labels[:image_token_start])
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+ cur_labels = cur_labels[image_token_start+1:]
+ cur_image_idx += 1
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ cur_input_ids = cur_input_ids[image_token_start+2:]
+ else:
+ cur_input_ids = cur_input_ids[image_token_start+1:]
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+ if cur_input_ids.numel() > 0:
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
+
+ _l = 0
+ for i, idx in enumerate(mask_idx):
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
+ ## mask
+ cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
+ ## pos
+ cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
+ if labels is not None:
+ cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
+ _l = idx[0]+2
+ if _l< len(cur_input_ids):
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]).detach())
+
+ else:
+
+                mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
+ assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"
+
+ _l = 0
+ for i, idx in enumerate(mask_idx):
+ cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
+ cur_new_input_embeds.append(cur_raw_new_input_embeds)
+ ## mask
+ cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
+ ## pos
+ cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
+
+ if labels is not None:
+ cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
+
+ _l = idx[0]+2
+ if _l< len(cur_input_ids):
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]))
+
+ if labels is not None:
+ cur_new_labels.append(cur_labels)
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ if labels is not None:
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
+ new_labels.append(cur_new_labels)
+
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
+ max_len = max(x.shape[0] for x in new_input_embeds)
+
+ new_input_embeds_align = []
+ for cur_new_embed in new_input_embeds:
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
+ new_input_embeds_align.append(cur_new_embed)
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
+
+ if labels is not None:
+ new_labels_align = []
+ _new_labels = new_labels
+ for cur_new_label in new_labels:
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
+ new_labels_align.append(cur_new_label)
+ new_labels = torch.stack(new_labels_align, dim=0)
+
+ if attention_mask is not None:
+ new_attention_mask = []
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
+ new_attention_mask.append(cur_new_attention_mask)
+ attention_mask = torch.stack(new_attention_mask, dim=0)
+ assert attention_mask.shape == new_labels.shape
+ else:
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
+ if labels is not None:
+ new_labels = torch.stack(new_labels, dim=0)
+
+ if attention_mask is not None:
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
+ assert attention_mask.shape == new_input_embeds.shape[:2]
+
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+        mask_tokens = ['<mask>', '<pos>']
+ num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ for m in self.modules():
+ m.tokenizer = tokenizer
\ No newline at end of file
diff --git a/seagull/train/__pycache__/seagull_trainer.cpython-310.pyc b/seagull/train/__pycache__/seagull_trainer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76d4ccc828cee42d3b63805506cd3664685ac26a
Binary files /dev/null and b/seagull/train/__pycache__/seagull_trainer.cpython-310.pyc differ
diff --git a/seagull/train/__pycache__/train.cpython-310.pyc b/seagull/train/__pycache__/train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf4d8dde71f72f063a469dba14cc93692a7f5f98
Binary files /dev/null and b/seagull/train/__pycache__/train.cpython-310.pyc differ
diff --git a/seagull/train/llama_flash_attn_monkey_patch.py b/seagull/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff5165ceb0240ffc4ae79875acf04d4bb3994dd
--- /dev/null
+++ b/seagull/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,116 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+ # print("begin_#")
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ # print("OK_#")
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
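+# Typically called once at startup by the training entry point, so that the
+# stock LlamaAttention.forward is swapped for the flash-attention path above.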
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/seagull/train/seagull_trainer.py b/seagull/train/seagull_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c673b9e09d6e5170f269c2d975184d59e30d765a
--- /dev/null
+++ b/seagull/train/seagull_trainer.py
@@ -0,0 +1,175 @@
+import os
+import torch
+
+from torch.utils.data import Sampler
+
+from transformers import Trainer
+from transformers.trainer import (
+ has_length,
+)
+from typing import List, Optional
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+    Split a list of indices into `num_chunks` chunks of roughly equal total length.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ assert len(mm_indices) > 0, "Should have at least one multimodal sample."
+ assert len(lang_indices) > 0, "Should have at least one language sample."
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) >= megabatch_size:
+ megabatches = [additional_batch[:megabatch_size]] + megabatches
+ additional_batch = additional_batch[megabatch_size:]
+
+ if len(additional_batch) > 0:
+ megabatches.append(additional_batch)
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
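+# Summary of the grouping below: shuffle all indices, cut them into megabatches
+# of world_size * batch_size, sort each megabatch longest-first, then split it
+# into per-rank chunks of roughly equal total length.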
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class SeagullTrainer(Trainer):
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ world_size=self.args.world_size,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ['mm_projector', 'vision_resampler']
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ else:
+ super(SeagullTrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ pass
+ else:
+ super(SeagullTrainer, self)._save(output_dir, state_dict)
diff --git a/seagull/train/train.py b/seagull/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c749fc84e1ddd81e67a07d234717f2b680485433
--- /dev/null
+++ b/seagull/train/train.py
@@ -0,0 +1,743 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence
+
+import torch
+
+import transformers
+
+from seagull.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from .seagull_trainer import SeagullTrainer
+
+from seagull import conversation as conversation_lib
+from seagull.model import *
+from seagull.mm_utils import tokenizer_image_token
+
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=None)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default='linear')
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+
+@dataclass
+class DataArguments:
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ sep_image_conv_front: bool = False
+ image_token_len: int = 0
+ dataset_list: str = "LIVEC,BID,KONIQ,SPAQ"
+ image_aspect_ratio: str = 'square'
+ image_grid_pinpoints: Optional[str] = field(default=None)
+ dataset_config: Optional[str] = field(default='./seagull/configs/stage1.json',
+ metadata={'help': 'Path to the dataset config file.'})
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ mm_projector_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+        for k, t in maybe_lora_bias.items():
+            if k in lora_bias_names:
+                to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler', "mask_extractor"]
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ lora_module_names.add(name)
+    rank0_print('Lora Finetuning: ', lora_module_names)
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
+ # Only save Adapter
+ keys_to_match = ['mm_projector']
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ return
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments,
+ cur_token_len: int = 0
+) -> Dict:
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_llama_2(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ rank0_print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
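+    # As in preprocess_llama_2, only the assistant responses are supervised;
+    # the instruction part of each round is masked with IGNORE_INDEX.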
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ rank0_print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
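+    # Each plain sample is the image token followed by its caption; the image
+    # token prefix is masked below so that only the caption is supervised.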
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+    Given a list of sources, each being a conversation list, this transform:
+    1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
+    2. Concatenates the conversations together;
+    3. Tokenizes the concatenated conversation;
+    4. Makes a deep copy as the target and masks the human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+import time
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
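+    # When 4-/8-bit training is requested, build the BitsAndBytes quantization
+    # config; these kwargs are forwarded to from_pretrained below.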
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ if model_args.vision_tower is not None:
+ model = SeagullLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
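+    # Optionally add LoRA adapters to the language model; find_all_linear_names
+    # selects the linear layers to adapt.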
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=True,
+ )
+
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ elif model_args.version == "v0.5":
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
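+    # Initialize the vision tower and the multimodal projector, then decide
+    # which of them are trainable via tune_mm_mlp_adapter / freeze_mm_mlp_adapter.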
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_vision_modules(
+ model_args=model_args,
+ fsdp=training_args.fsdp
+ )
+
+ vision_tower = model.get_vision_tower()
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = vision_tower.image_processor
+
+ rank0_print(data_args.image_processor)
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ if model_args.tune_mm_mlp_adapter:
+ model.requires_grad_(False)
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ from ..datasets.data_modules import make_multitask_data_module
+ data_args.dataset_list = data_args.dataset_list.split(',')
+ rank0_print('Training on: ', data_args.dataset_list)
+ data_module = make_multitask_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+
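+    # Optional environment flags: TRAIN_MASK_MODULE restricts training to the
+    # mask_extractor parameters, while TRAIN_MASK_MODULE_STAGE3 forces both the
+    # mm_projector and the mask_extractor to be trainable.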
+ if os.environ.get('TRAIN_MASK_MODULE', None):
+ for n, p in model.named_parameters():
+ if 'mask_extractor' not in n:
+ p.requires_grad = False
+ else:
+ p.requires_grad = True
+
+ if os.environ.get('TRAIN_MASK_MODULE_STAGE3', None):
+ for n, p in model.named_parameters():
+ # if 'mm_projector' in n:
+ if 'mm_projector' in n or 'mask_extractor' in n:
+ p.requires_grad = True
+
+ trainer = SeagullTrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+
+ # time.sleep((local_rank) * 120)
+ trainer.save_state()
+
+ model.config.use_cache = True
+
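+    # With LoRA enabled, the adapter weights and the remaining trainable
+    # (non-LoRA) parameters are saved separately; otherwise the full model is
+    # saved through the trainer.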
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
diff --git a/seagull/train/train_mem.py b/seagull/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66bec5046f14c9b807d2e93f32d7064e649764f
--- /dev/null
+++ b/seagull/train/train_mem.py
@@ -0,0 +1,13 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from seagull.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+replace_llama_attn_with_flash_attn()
+
+from seagull.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/seagull/utils.py b/seagull/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fadb45fca366c9c0d1b188f90d296e5eb8c9c3df
--- /dev/null
+++ b/seagull/utils.py
@@ -0,0 +1,126 @@
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from seagull.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+    Check whether the text is flagged by the OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+    text = text.replace("\n", "")
+    # Let requests serialize the payload; hand-building the JSON string breaks
+    # when the text contains quotes or backslashes.
+    data = {"input": text}
+    try:
+        ret = requests.post(url, headers=headers, json=data, timeout=5)
+        flagged = ret.json()["results"][0]["flagged"]
+    except (requests.exceptions.RequestException, KeyError):
+        flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+