Zevin2023 committed on
Commit
8fa1f84
1 Parent(s): a25e4cb

add online demo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +2 -0
  2. README.md +30 -1
  3. app.py +62 -0
  4. demo/UI.py +143 -0
  5. demo/__pycache__/UI.cpython-310.pyc +0 -0
  6. demo/__pycache__/mask_utils.cpython-310.pyc +0 -0
  7. demo/__pycache__/sam_inference.cpython-310.pyc +0 -0
  8. demo/__pycache__/seagull_inference.cpython-310.pyc +0 -0
  9. demo/mask_utils.py +144 -0
  10. demo/sam_inference.py +102 -0
  11. demo/seagull_inference.py +163 -0
  12. imgs/.DS_Store +0 -0
  13. imgs/Examples/1.png +0 -0
  14. imgs/Examples/2.png +0 -0
  15. seagull/__init__.py +1 -0
  16. seagull/__pycache__/__init__.cpython-310.pyc +0 -0
  17. seagull/__pycache__/constants.cpython-310.pyc +0 -0
  18. seagull/__pycache__/conversation.cpython-310.pyc +0 -0
  19. seagull/__pycache__/mm_utils.cpython-310.pyc +0 -0
  20. seagull/__pycache__/utils.cpython-310.pyc +0 -0
  21. seagull/builder.py +171 -0
  22. seagull/constants.py +12 -0
  23. seagull/conversation.py +381 -0
  24. seagull/mm_utils.py +95 -0
  25. seagull/model/__init__.py +1 -0
  26. seagull/model/__pycache__/Q_A.cpython-310.pyc +0 -0
  27. seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc +0 -0
  28. seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc +0 -0
  29. seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc +0 -0
  30. seagull/model/__pycache__/__init__.cpython-310.pyc +0 -0
  31. seagull/model/__pycache__/layer.cpython-310.pyc +0 -0
  32. seagull/model/__pycache__/layer_osprey.cpython-310.pyc +0 -0
  33. seagull/model/__pycache__/osprey_arch.cpython-310.pyc +0 -0
  34. seagull/model/__pycache__/seagull_arch.cpython-310.pyc +0 -0
  35. seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc +0 -0
  36. seagull/model/consolidate.py +26 -0
  37. seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc +0 -0
  38. seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc +0 -0
  39. seagull/model/language_model/seagull_llama.py +128 -0
  40. seagull/model/layer.py +250 -0
  41. seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  42. seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc +0 -0
  43. seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  44. seagull/model/multimodal_encoder/builder.py +7 -0
  45. seagull/model/multimodal_encoder/clip.py +40 -0
  46. seagull/model/multimodal_encoder/clip_encoder.py +59 -0
  47. seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  48. seagull/model/multimodal_projector/builder.py +52 -0
  49. seagull/model/seagull_arch.py +281 -0
  50. seagull/train/__pycache__/seagull_trainer.cpython-310.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *.pth
2
+ *.bin
README.md CHANGED
@@ -9,4 +9,33 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
9
  pinned: false
10
  ---
11
 
12
+ <img src="https://github.com/chencn2020/SEAGULL/raw/main/imgs/Logo/logo.png" alt="SEAGULL" style="height: auto; width: 100%;">
13
+
14
+ <div style="display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; width: 100%;">
15
+ <a href=""><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm-dark.svg" alt="Open in Spaces" style="max-width: 100%; height: auto;"></a>
16
+ <a href="https://arxiv.org/abs/2411.10161"><img src="https://img.shields.io/badge/Arxiv-2411:10161-red" style="max-width: 100%; height: auto;"></a>
17
+ <a href="https://hits.seeyoufarm.com"><img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fdatasets%2FZevin2023%2FSEAGULL-100w&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false" style="max-width: 100%; height: auto;"></a>
18
+ <a href='https://github.com/chencn2020/SEAGULL/'><img src='https://img.shields.io/github/stars/chencn2020/Seagull.svg?style=social' style="max-width: 100%; height: auto;"></a>
19
+ </div>
20
+
21
+ ## Acknowledgement 💌
22
+ <div id="Acknowledgement"></div>
23
+ - [Osprey](https://github.com/CircleRadon/Osprey) and [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA): We build this repository based on them.
24
+ - [RAISE](http://loki.disi.unitn.it/RAISE/): The distorted images in SEAGULL-100w are constructed based on this dataset.
25
+ - [SAM](https://segment-anything.com/) and [SEEM](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once): The mask-based ROIs are generated using these two awesome works, and SAM is used to produce the segmentation results in the demo.
26
+ - [TOPIQ](https://github.com/chaofengc/IQA-PyTorch): The quality scores and importance scores for ROIs are generated using this great FR-IQA method.
27
+
28
+
29
+ ## Citation 🖊️
30
+ If our work is useful to your research, we would be grateful if you cite our paper:
31
+ ```
32
+ @misc{chen2024seagull,
33
+ title={SEAGULL: No-reference Image Quality Assessment for Regions of Interest via Vision-Language Instruction Tuning},
34
+ author={Zewen Chen and Juan Wang and Wen Wang and Sunhan Xu and Hang Xiong and Yun Zeng and Jian Guo and Shuxun Wang and Chunfeng Yuan and Bing Li and Weiming Hu},
35
+ year={2024},
36
+ eprint={2411.10161},
37
+ archivePrefix={arXiv},
38
+ primaryClass={cs.CV},
39
+ url={https://arxiv.org/abs/2411.10161},
40
+ }
41
+ ```
app.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+ from demo.UI import Main_ui
3
+
4
+ if __name__ == '__main__':
5
+ import subprocess
6
+ import sys
7
+ def run_command(command):
8
+ subprocess.check_call(command, shell=True)  # run the full command string in a shell so pipes and install scripts work
9
+
10
+ # Install the package in editable mode
11
+ run_command("pip install -e .")
12
+
13
+ # Install NVM (Node Version Manager)
14
+ run_command("curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh | bash")
15
+
16
+ # Source the appropriate shell configuration file
17
+ run_command("source ~/.bashrc") # You can change to ~/.zshrc based on your shell
18
+
19
+ # Install Node.js version 18.16.0
20
+ run_command("nvm install v18.16.0")
21
+
22
+ # Install pnpm (package manager)
23
+ run_command("curl -fsSL https://get.pnpm.io/install.sh | sh -")
24
+
25
+ # Source the shell configuration file again (for pnpm)
26
+ run_command("source ~/.bashrc") # You can change to ~/.zshrc based on your shell
27
+
28
+ # Verify if pnpm was installed correctly
29
+ run_command("pnpm --version")
30
+
31
+ # Clone the Gradio BBox repository
32
+ run_command("git clone https://github.com/chencn2020/gradio-bbox.git")
33
+
34
+ # Build the frontend inside the cloned repository
35
+ # (chained in a single shell command so that the cd takes effect
36
+ # for the build step)
37
+ run_command("cd gradio-bbox && bash scripts/build_frontend.sh")
38
+
39
+
40
+
41
+
42
+ # Change back to the previous directory
43
+ run_command("cd ..")
44
+
45
+ # Install the package again in editable mode
46
+ run_command("pip install -e .")
47
+
48
+ # Install Segment Anything repository from GitHub
49
+ run_command("pip install git+https://github.com/facebookresearch/segment-anything.git")
50
+
51
+ # Download the model checkpoint
52
+ run_command("mkdir -p ./checkpoints && curl -o ./checkpoints/sam_vit_b_01ec64.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")
53
+
54
+
55
+
56
+ parser = argparse.ArgumentParser(description='SEAGULL', formatter_class=argparse.RawTextHelpFormatter)
57
+ parser.add_argument('--model', help='path to seagull model', default='Zevin2023/SEAGULL-7B')
58
+ parser.add_argument('--example_path', help='path to examples', default='./imgs/Examples')
59
+ args = parser.parse_args()
60
+
61
+ demo = Main_ui(args).load_demo()
62
+ demo.launch(server_port=7530)
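For setups where shelling out to curl is awkward, the checkpoint download above can be reproduced in pure Python. This is a minimal sketch, not part of the commit: the helper name `ensure_sam_checkpoint` is illustrative, and it assumes network access and a writable working directory. Unlike a plain curl call, it creates `./checkpoints` before writing into it.

```python
import os
import urllib.request

# Same checkpoint URL and target path as the setup script above.
SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
CKPT_PATH = "./checkpoints/sam_vit_b_01ec64.pth"

def ensure_sam_checkpoint(url: str = SAM_URL, path: str = CKPT_PATH) -> str:
    """Download the SAM ViT-B checkpoint once; skip if it is already on disk."""
    os.makedirs(os.path.dirname(path), exist_ok=True)  # curl alone would fail if this folder is missing
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path

if __name__ == "__main__":
    print(ensure_sam_checkpoint())
```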
demo/UI.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import gradio as gr
3
+ from demo.sam_inference import SAM_Inference
4
+ from demo.seagull_inference import Seagull
5
+ from demo.mask_utils import ImageSketcher
6
+
7
+ class Main_ui():
8
+ def __init__(self, args) -> None:
9
+ self.args = args
10
+ self.seagull = Seagull(model_path=args.model)
11
+
12
+ self.example_list = self.load_example()
13
+ self.sam = SAM_Inference()
14
+ # self.sam_predictor = get_sam_predictor()
15
+ # self.mask_generator = get_mask_generator()
16
+
17
+ def load_example(self):
18
+ examples = []
19
+ for file in sorted(os.listdir(self.args.example_path)):
20
+ examples.append([os.path.join(self.args.example_path, file)])
21
+ return examples
22
+
23
+ def load_demo(self):
24
+ with gr.Blocks() as demo:
25
+ preprocessed_img = gr.State(value=None)
26
+ binary_mask = gr.State(value=None)
27
+
28
+ with gr.Row():
29
+ gr.Markdown("""
30
+ <img src="https://github.com/chencn2020/SEAGULL/raw/main/imgs/Logo/logo.png" alt="SEAGULL" style="height: auto; width: 100%; margin-bottom: 3%;">
31
+
32
+ ## 🔔 Usage
33
+
34
+ First, upload an image and choose the analysis type **(quality score, importance score or distortion analysis)**.
35
+
36
+ Then click **(points)** or draw a box **(bbox)** on the image to indicate the regions of interest (ROIs).
37
+
38
+ After that, the demo performs the following steps:
39
+
40
+ > 1. SAM extracts the mask-based ROIs from your clicked points or drawn box.
41
+
42
+ > 2. Based on the uploaded image and mask-based ROIs, SEAGULL analyses the quality of the ROIs.
43
+
44
+ """)
45
+
46
+ with gr.TabItem("Mask-based ROIs (Points)"):
47
+ with gr.Row():
48
+ input_image_ponit = gr.Image(type="numpy", label='Input image', height=512) # input image
49
+ output_mask_ponit = gr.Image(label='Mask-based ROI', height=512) # output binary mask
50
+
51
+ with gr.Row():
52
+ output_mask_point_on_img = gr.Image(label='Mask on image', height=512) # mask on image for better view
53
+
54
+ with gr.Column():
55
+ radio_point = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
56
+ output_text_point = gr.Textbox(label='Analysis Results')
57
+ point_seg_button = gr.Button('Analysis')
58
+
59
+ point_example = gr.Dataset(label='Examples', components=[input_image_ponit], samples=self.example_list)
60
+
61
+ with gr.TabItem("Mask-based ROIs (BBox)"):
62
+ with gr.Row():
63
+ input_image_BBOX = ImageSketcher(type="numpy", label='Input image', height=512)
64
+ output_mask_BBOX = gr.Image(label='Mask-based ROI', height=512)
65
+
66
+ with gr.Row():
67
+ output_BBOX_mask_on_img = gr.Image(label='Mask on image', height=512)
68
+
69
+ with gr.Column():
70
+ radio_BBOX = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
71
+ output_text_BBOX = gr.Textbox(label='ROI Quality Analysis')
72
+ box_seg_button = gr.Button('Generate mask and analysis')
73
+ box_analyse_button = gr.Button('Analysis')
74
+
75
+ BBOX_example = gr.Dataset(label='Examples', components=[input_image_BBOX], samples=self.example_list)
76
+
77
+ # click point
78
+ input_image_ponit.upload(
79
+ self.seagull.init_image,
80
+ [input_image_ponit],
81
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
82
+ )
83
+
84
+ point_example.click(
85
+ self.seagull.init_image,
86
+ [point_example],
87
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
88
+ )
89
+
90
+ # after clicking on the image
91
+ input_image_ponit.select(
92
+ self.sam.img_select_point,
93
+ [preprocessed_img],
94
+ [input_image_ponit, output_mask_ponit, output_mask_point_on_img, binary_mask]
95
+ ).then(
96
+ self.seagull.seagull_predict,
97
+ [preprocessed_img, binary_mask, radio_point],
98
+ [output_text_point]
99
+ )
100
+
101
+ point_seg_button.click(
102
+ self.seagull.seagull_predict,
103
+ [preprocessed_img, binary_mask, radio_point],
104
+ [output_text_point]
105
+ )
106
+
107
+ # draw frame
108
+ input_image_BBOX.upload(
109
+ self.seagull.init_image,
110
+ [input_image_BBOX],
111
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
112
+ )
113
+
114
+ BBOX_example.click(
115
+ self.seagull.init_image,
116
+ [BBOX_example],
117
+ [preprocessed_img, input_image_ponit, input_image_BBOX]
118
+ )
119
+
120
+ # after drawing a frame on the image
121
+ input_image_BBOX.select(
122
+ self.sam.gen_box_seg,
123
+ [input_image_BBOX],
124
+ [output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
125
+ )
126
+
127
+ box_seg_button.click(
128
+ self.sam.gen_box_seg,
129
+ [input_image_BBOX],
130
+ [output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
131
+ ).then(
132
+ self.seagull.seagull_predict,
133
+ [preprocessed_img, binary_mask, radio_BBOX],
134
+ [output_text_BBOX]
135
+ )
136
+
137
+ box_analyse_button.click(
138
+ self.seagull.seagull_predict,
139
+ [preprocessed_img, binary_mask, radio_BBOX],
140
+ [output_text_BBOX]
141
+ )
142
+
143
+ return demo
demo/__pycache__/UI.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
demo/__pycache__/mask_utils.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
demo/__pycache__/sam_inference.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
demo/__pycache__/seagull_inference.cpython-310.pyc ADDED
Binary file (5.48 kB). View file
 
demo/mask_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ import cv2
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+
7
+ class ImageSketcher(gr.Image):
8
+ """
9
+ Code is from https://github.com/jshilong/GPT4RoI/blob/7c157b5f33914f21cfbc804fb301d3ce06324193/gpt4roi/app.py#L365
10
+
11
+ Fixes the bug in gradio.Image that prevents uploading when tool == 'sketch'.
12
+ """
13
+
14
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
15
+
16
+ def __init__(self, **kwargs):
17
+ super().__init__(tool='boxes', **kwargs)
18
+
19
+ def preprocess(self, x):
20
+ if x is None:
21
+ return x
22
+ if self.tool == 'boxes' and self.source in ['upload', 'webcam']:
23
+ if isinstance(x, str):
24
+ x = {'image': x, 'boxes': []}
25
+ else:
26
+ assert isinstance(x, dict)
27
+ assert isinstance(x['image'], str)
28
+ assert isinstance(x['boxes'], list)
29
+ x = super().preprocess(x)
30
+ return x
31
+
32
+ def process_mask_to_show(mask):
33
+ '''
34
+ Process the mask to show on the gradio.Image
35
+ '''
36
+ mask = np.array(mask > 0.1, dtype=np.uint8) * 255
37
+ mask_stacked = np.stack([mask] * 3, axis=-1)
38
+
39
+ return mask_stacked
40
+
41
+ def img_add_masks(img_, colored_mask, mask, linewidth=2):
42
+ if type(img_) is np.ndarray:
43
+ img = Image.fromarray(img_, mode='RGB').convert('RGBA')
44
+ else:
45
+ img = img_.copy()
46
+ h, w = img.height, img.width
47
+ # contour
48
+ temp = np.zeros((h, w, 1))
49
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
50
+ cv2.drawContours(temp, contours, -1, (255, 255, 255), linewidth)
51
+ color = np.array([1, 1, 1, 1])
52
+ contour_mask = temp * color.reshape(1, 1, -1)
53
+
54
+ overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
55
+ img.paste(overlay_inner, (0, 0), overlay_inner)
56
+
57
+ overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
58
+ img.paste(overlay_contour, (0, 0), overlay_contour)
59
+ return img
60
+
61
+ def gen_colored_masks(
62
+ annotation,
63
+ random_color=False,
64
+ ):
65
+ """
66
+ Code is largely based on https://github.com/CASIA-IVA-Lab/FastSAM/blob/4d153e909f0ad9c8ecd7632566e5a24e21cf0071/utils/tools_gradio.py#L130
67
+ """
68
+ device = annotation.device
69
+ mask_sum = annotation.shape[0]
70
+ height = annotation.shape[1]
71
+ weight = annotation.shape[2]
72
+ areas = torch.sum(annotation, dim=(1, 2))
73
+ sorted_indices = torch.argsort(areas, descending=False)
74
+ annotation = annotation[sorted_indices]
75
+
76
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
77
+ if random_color:
78
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
79
+ else:
80
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
81
+ [30 / 255, 144 / 255, 255 / 255]
82
+ ).to(device)
83
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
84
+ visual = torch.cat([color, transparency], dim=-1)
85
+ mask_image = torch.unsqueeze(annotation, -1) * visual
86
+
87
+ mask = torch.zeros((height, weight, 4)).to(device)
88
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
89
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
90
+
91
+ mask[h_indices, w_indices, :] = mask_image[indices]
92
+ mask_cpu = mask.cpu().numpy()
93
+
94
+ return mask_cpu, sorted_indices
95
+
96
+ def mask_foreground(mask, trans=0.6, random_color=True):
97
+ if random_color:
98
+ color = np.concatenate([np.random.random(3) * 255, np.array([trans * 255])], axis=0)
99
+ else:
100
+ color = np.array([30, 144, 255, trans * 255])
101
+ h, w = mask.shape[-2:]
102
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
103
+
104
+ return mask_image
105
+
106
+
107
+ def mask_background(mask, trans=0.5):
108
+ h, w = mask.shape[-2:]
109
+ mask_image = (1 - mask.reshape(h, w, 1)) * np.array([0, 0, 0, trans * 255])
110
+
111
+ return mask_image
112
+
113
+
114
+ def mask_select_point(all_masks, output_mask_2_raw, mask_order, evt: gr.SelectData):
115
+ h, w = output_mask_2_raw.height, output_mask_2_raw.width
116
+ pointed_mask = None
117
+ for i in range(len(mask_order)):
118
+ idx = mask_order[i]
119
+ msk = all_masks[idx]
120
+ if msk[evt.index[1], evt.index[0]] == 1:
121
+ pointed_mask = msk.copy()
122
+ break
123
+
124
+ if pointed_mask is not None:
125
+ contours, hierarchy = cv2.findContours(pointed_mask.astype("uint8"), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
126
+ ret = output_mask_2_raw.copy()
127
+
128
+ temp = np.zeros((h, w, 1))
129
+ contours, _ = cv2.findContours(msk.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
130
+ cv2.drawContours(temp, contours, -1, (255, 255, 255), 3)
131
+ color = np.array([1, 1, 1, 1])
132
+ contour_mask = temp * color.reshape(1, 1, -1)
133
+
134
+ colored_mask = mask_background(pointed_mask)
135
+
136
+ overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
137
+ ret.paste(overlay_inner, (0, 0), overlay_inner)
138
+
139
+ overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
140
+ ret.paste(overlay_contour, (0, 0), overlay_contour)
141
+
142
+ return ret, pointed_mask
143
+ else:
144
+ return output_mask_2_raw, None
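These helpers are how the demo visualises SAM output: `mask_foreground` builds an RGBA tint for a binary mask, `img_add_masks` pastes the tint plus a contour outline onto the image, and `process_mask_to_show` turns the mask into a 3-channel image for `gr.Image`. A minimal usage sketch with a synthetic image and mask (the array sizes and output filename are illustrative, not from the commit):

```python
import numpy as np
from demo.mask_utils import mask_foreground, img_add_masks, process_mask_to_show

img = np.zeros((256, 256, 3), dtype=np.uint8)   # placeholder RGB image
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 1                        # a square "ROI"

colored = mask_foreground(mask, trans=0.6)      # RGBA tint for the ROI
overlay = img_add_masks(img, colored, mask)     # PIL image with tint + contour outline
binary_view = process_mask_to_show(mask)        # 3-channel 0/255 mask for gr.Image

overlay.save("overlay_preview.png")
```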
demo/sam_inference.py ADDED
@@ -0,0 +1,102 @@
1
+ import gc
2
+
3
+ import numpy as np
4
+ import torch
5
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
6
+ import gradio as gr
7
+ import cv2
8
+ from demo.mask_utils import *
9
+
10
+ class SAM_Inference:
11
+ def __init__(self, model_type='vit_b', device='cuda') -> None:
12
+ models = {
13
+ 'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
14
+ 'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
15
+ 'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
16
+ }
17
+
18
+ sam = sam_model_registry[model_type](checkpoint=models[model_type])
19
+ sam = sam.to(device)
20
+
21
+ self.predictor = SamPredictor(sam)
22
+ self.mask_generator = SamAutomaticMaskGenerator(model=sam)
23
+
24
+ def img_select_point(self, original_img: np.ndarray, evt: gr.SelectData):
25
+ img = original_img.copy()
26
+ sel_pix = [(evt.index, 1)] # append the foreground_point
27
+
28
+ masks = self.run_inference(original_img, sel_pix)
29
+ for point, label in sel_pix:
30
+ cv2.circle(img, point, 5, (240, 240, 240), -1, 0)
31
+ cv2.circle(img, point, 5, (30, 144, 255), 2, 0)
32
+
33
+ mask = masks[0][0]
34
+ colored_mask = mask_foreground(mask)
35
+ res = img_add_masks(original_img, colored_mask, mask)
36
+ return img, process_mask_to_show(mask), res, mask
37
+
38
+ def gen_box_seg(self, inp):
39
+ if inp is None:
40
+ raise gr.Error("Please upload an image first!")
41
+ image = inp['image']
42
+ if len(inp['boxes']) == 0:
43
+ raise gr.Error("Please draw a box on the image first!")
44
+ boxes = inp['boxes'][-1]
45
+
46
+ input_box = np.array([boxes[0], boxes[1], boxes[2], boxes[3]]).astype(int)
47
+
48
+ masks = self.predict_box(image, input_box)
49
+
50
+ mask = masks[0][0]
51
+ colored_mask = mask_foreground(mask)
52
+ res = img_add_masks(image, colored_mask, mask)
53
+
54
+ return process_mask_to_show(mask), res, mask
55
+
56
+ def run_inference(self, input_x, selected_points):
57
+ if len(selected_points) == 0:
58
+ return []
59
+
60
+ self.predictor.set_image(input_x)
61
+
62
+ points = torch.Tensor(
63
+ [p for p, _ in selected_points]
64
+ ).to(self.predictor.device).unsqueeze(0)
65
+
66
+ labels = torch.Tensor(
67
+ [int(l) for _, l in selected_points]
68
+ ).to(self.predictor.device).unsqueeze(0)
69
+
70
+ transformed_points = self.predictor.transform.apply_coords_torch(
71
+ points, input_x.shape[:2])
72
+
73
+ # predict segmentation from the selected points
74
+ masks, scores, logits = self.predictor.predict_torch(
75
+ point_coords=transformed_points,
76
+ point_labels=labels,
77
+ multimask_output=False,
78
+ )
79
+ masks = masks.cpu().detach().numpy()
80
+
81
+ gc.collect()
82
+ torch.cuda.empty_cache()
83
+
84
+ return masks
85
+
86
+ def predict_box(self, input_x, input_box):
87
+ self.predictor.set_image(input_x)
88
+
89
+ input_boxes = torch.tensor(input_box[None, :], device=self.predictor.device)
90
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])
91
+
92
+ masks, _, _ = self.predictor.predict_torch(
93
+ point_coords=None,
94
+ point_labels=None,
95
+ boxes=transformed_boxes,
96
+ multimask_output=False
97
+ )
98
+ masks = masks.cpu().detach().numpy()
99
+
100
+ gc.collect()
101
+ torch.cuda.empty_cache()
102
+ return masks
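A usage sketch for the box path outside Gradio. It assumes the SAM ViT-B checkpoint from the setup script is in `./checkpoints`, a CUDA device is available, and the example image exists; the box coordinates and output filename are illustrative:

```python
import cv2
import numpy as np
from demo.sam_inference import SAM_Inference
from demo.mask_utils import mask_foreground, img_add_masks

sam = SAM_Inference(model_type='vit_b', device='cuda')

# Load an example image as RGB and segment the region inside a hand-picked box.
image = cv2.cvtColor(cv2.imread('./imgs/Examples/1.png'), cv2.COLOR_BGR2RGB)
box = np.array([50, 50, 400, 300])        # x0, y0, x1, y1 in pixels (illustrative)

masks = sam.predict_box(image, box)       # boolean array of shape (1, 1, H, W)
mask = masks[0][0]
preview = img_add_masks(image, mask_foreground(mask), mask)
preview.save('box_roi_preview.png')
```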
demo/seagull_inference.py ADDED
@@ -0,0 +1,163 @@
1
+ import torch
2
+ from seagull.utils import disable_torch_init
3
+ from transformers import AutoTokenizer, CLIPImageProcessor
4
+ from seagull.model.language_model.seagull_llama import SeagullLlamaForCausalLM
5
+ from seagull.mm_utils import tokenizer_image_token
6
+ from seagull.conversation import conv_templates, SeparatorStyle
7
+ from seagull.constants import IMAGE_TOKEN_INDEX
8
+ from seagull.train.train import DataArguments
9
+
10
+ from functools import partial
11
+ import os
12
+ import numpy as np
13
+ import cv2
14
+ from typing import List
15
+ from PIL import Image
16
+
17
+ class Seagull():
18
+ def __init__(self, model_path, device='cuda'):
19
+ disable_torch_init()
20
+ model_path = os.path.expanduser(model_path)
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=2048, padding_side="right", use_fast=True)
22
+ self.model = SeagullLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16,).to(device)
23
+ self.tokenizer.pad_token = self.tokenizer.unk_token
24
+
25
+ self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
26
+ do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
27
+ image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
28
+
29
+ spi_tokens = ['<global>', '<local>']
30
+ self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
31
+
32
+ for m in self.model.modules():
33
+ m.tokenizer = self.tokenizer
34
+
35
+ vision_tower = self.model.get_vision_tower()
36
+ if not vision_tower.is_loaded:
37
+ vision_tower.load_model()
38
+ vision_tower.to(dtype=torch.float16, device=device)
39
+
40
+ begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
41
+
42
+ instruction = {
43
+ 'distortion analysis': 'Provide the distortion type of this region.',
44
+ 'quality score': 'Analyze the quality of this region.',
45
+ 'importance score': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
46
+ }
47
+
48
+ self.ids_input = {}
49
+ for ins_type, ins in instruction.items():
50
+ conv = conv_templates['seagull_v1'].copy()
51
+ qs = begin_str + ins
52
+ conv.append_message(conv.roles[0], qs)
53
+ conv.append_message(conv.roles[1], None)
54
+ prompt = conv.get_prompt()
55
+ self.ids_input[ins_type] = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
56
+
57
+ self.stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
58
+
59
+ def init_image(self, img):
60
+ if isinstance(img, dict):
61
+ img = img['image']
62
+ elif isinstance(img, List):
63
+ img = cv2.imread(img[0])
64
+ img = img[:, :, ::-1]
65
+ h_, w_ = img.shape[:2]
66
+ if h_ > 512:
67
+ ratio = 512 / h_
68
+ new_h, new_w = int(h_ * ratio), int(w_ * ratio)
69
+ preprocessed_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
70
+ else:
71
+ preprocessed_img = img.copy()
72
+
73
+ return (preprocessed_img, preprocessed_img, preprocessed_img)
74
+
75
+ def preprocess(self, img):
76
+ image = self.image_processor.preprocess(img,
77
+ do_center_crop=False,
78
+ return_tensors='pt')['pixel_values'][0]
79
+
80
+ image = torch.nn.functional.interpolate(image.unsqueeze(0),
81
+ size=(512, 512),
82
+ mode='bilinear',
83
+ align_corners=False).squeeze(0)
84
+
85
+ return image
86
+
87
+ def seagull_predict(self, img, mask, instruct_type):
88
+ image = self.preprocess(img)
89
+
90
+ mask = np.array(mask, dtype=np.uint8)  # np.int was removed in recent NumPy versions
91
+ ys, xs = np.where(mask > 0)
92
+ if len(xs) > 0 and len(ys) > 0:
93
+ # Find the minimal bounding rectangle for the entire mask
94
+ x_min, x_max = np.min(xs), np.max(xs)
95
+ y_min, y_max = np.min(ys), np.max(ys)
96
+ w1 = x_max - x_min
97
+ h1 = y_max - y_min
98
+
99
+ bounding_box = (x_min, y_min, w1, h1)
100
+ else:
101
+ bounding_box = None
102
+
103
+ mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
104
+ mask = np.array(mask > 0.1, dtype=np.uint8)
105
+ masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
106
+
107
+ input_ids = self.ids_input[instruct_type.lower()]
108
+
109
+ x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
110
+ cropped_img = img[y1:y1 + h1, x1:x1 + w1]
111
+ cropped_img = Image.fromarray(cropped_img)
112
+ cropped_img = self.preprocess(cropped_img)
113
+
114
+ with torch.inference_mode():
115
+
116
+ self.model.orig_forward = self.model.forward
117
+ self.model.forward = partial(self.model.orig_forward,
118
+ img_metas=[None],
119
+ masks=[masks.half()],
120
+ cropped_img=cropped_img.unsqueeze(0)
121
+ )
122
+ output_ids = self.model.generate(
123
+ input_ids,
124
+ images=image.unsqueeze(0).half().to(self.model.device),
125
+ do_sample=False,
126
+ temperature=1,
127
+ max_new_tokens=2048,
128
+ use_cache=True,
129
+ num_beams=1,
130
+ top_k = 0, # disable top-k filtering
131
+ top_p = 1, # cumulative probability threshold of 1 (no nucleus filtering)
132
+ )
133
+
134
+ self.model.forward = self.model.orig_forward
135
+
136
+ input_token_len = input_ids.shape[1]
137
+ n_diff_input_output = (
138
+ input_ids != output_ids[:, :input_token_len]).sum().item()
139
+ if n_diff_input_output > 0:
140
+ print(
141
+ f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
142
+ outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
143
+ skip_special_tokens=True)[0]
144
+
145
+ outputs = outputs.strip()
146
+ if outputs.endswith(self.stop_str):
147
+ outputs = outputs[:-len(self.stop_str)]
148
+ outputs = outputs.strip()
149
+ if ':' in outputs:
150
+ outputs = outputs.split(':')[1]
151
+
152
+ outputs_list = outputs.split('.')
153
+ outputs_list_final = []
154
+ outputs_str = ''
155
+ for output in outputs_list:
156
+ if output not in outputs_list_final:
157
+ if output=='':
158
+ continue
159
+ outputs_list_final.append(output)
160
+ outputs_str+=output+'.'
161
+ else:
162
+ break
163
+ return outputs_str
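Putting the two wrappers together the way `demo/UI.py` wires them: SAM turns a drawn box into a binary mask, and SEAGULL then scores that ROI. This is a sketch only; it assumes a CUDA device, that the `Zevin2023/SEAGULL-7B` weights and the SAM checkpoint are available, and the example path and box coordinates are illustrative:

```python
import cv2
from demo.seagull_inference import Seagull
from demo.sam_inference import SAM_Inference

seagull = Seagull(model_path='Zevin2023/SEAGULL-7B')
sam = SAM_Inference()

# init_image mimics the Gradio upload callback: it returns the (possibly resized)
# image three times because the UI feeds it to three components.
raw = cv2.cvtColor(cv2.imread('./imgs/Examples/1.png'), cv2.COLOR_BGR2RGB)
preprocessed, _, _ = seagull.init_image({'image': raw})

# Segment an ROI from a box, then ask SEAGULL for a quality analysis of that ROI.
_, _, binary_mask = sam.gen_box_seg({'image': preprocessed, 'boxes': [[50, 50, 300, 300]]})
print(seagull.seagull_predict(preprocessed, binary_mask, 'Quality Score'))
```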
imgs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
imgs/Examples/1.png ADDED
imgs/Examples/2.png ADDED
seagull/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import SeagullLlamaForCausalLM
seagull/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes). View file
 
seagull/__pycache__/constants.cpython-310.pyc ADDED
Binary file (450 Bytes). View file
 
seagull/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
seagull/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
seagull/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
seagull/builder.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from seagull.model import *
23
+ from seagull.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if use_flash_attn:
46
+ kwargs['attn_implementation'] = 'flash_attention_2'
47
+
48
+ if 'seagull' in model_name.lower() or True:
49
+ # Load LLaVA model
50
+ if 'lora' in model_name.lower() and model_base is None:
51
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52
+ if 'lora' in model_name.lower() and model_base is not None or True:
53
+ from seagull.model.language_model.seagull_llama import SeagullConfig
54
+ lora_cfg_pretrained = SeagullConfig.from_pretrained(model_path)
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56
+ print('Loading LLaVA from base model...')
57
+ model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
59
+ if model.lm_head.weight.shape[0] != token_num:
60
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
61
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
62
+
63
+ print('Loading additional LLaVA weights...')
64
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66
+ else:
67
+ # this is probably from HF Hub
68
+ from huggingface_hub import hf_hub_download
69
+ def load_from_hf(repo_id, filename, subfolder=None):
70
+ cache_file = hf_hub_download(
71
+ repo_id=repo_id,
72
+ filename=filename,
73
+ subfolder=subfolder)
74
+ return torch.load(cache_file, map_location='cpu')
75
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
76
+
77
+ for k, v in non_lora_trainables.items():
78
+ print(k)
79
+ print('print non lora')
80
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
81
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
82
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
83
+ model.load_state_dict(non_lora_trainables, strict=False)
84
+
85
+ from peft import PeftModel
86
+ print('Loading LoRA weights...')
87
+ model = PeftModel.from_pretrained(model, model_path)
88
+ print('Merging LoRA weights...')
89
+ model = model.merge_and_unload()
90
+ print('Model is loaded...')
91
+ elif model_base is not None:
92
+ # this may be mm projector only
93
+ print('Loading LLaVA from base model...')
94
+ if 'mpt' in model_name.lower():
95
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
96
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
97
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
98
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
99
+ model = SeagullMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
100
+ else:
101
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
102
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
103
+ model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
104
+
105
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
106
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
107
+ model.load_state_dict(mm_projector_weights, strict=False)
108
+ else:
109
+ if 'mpt' in model_name.lower():
110
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
111
+ model = SeagullMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
112
+ elif 'mistral' in model_name.lower():
113
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
114
+ model = SeagullMistralForCausalLM.from_pretrained(
115
+ model_path,
116
+ low_cpu_mem_usage=True,
117
+ **kwargs
118
+ )
119
+ else:
120
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
121
+ model = SeagullLlamaForCausalLM.from_pretrained(
122
+ model_path,
123
+ low_cpu_mem_usage=True,
124
+ **kwargs
125
+ )
126
+ else:
127
+ # Load language model
128
+ if model_base is not None:
129
+ # PEFT model
130
+ from peft import PeftModel
131
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
132
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
133
+ print(f"Loading LoRA weights from {model_path}")
134
+ model = PeftModel.from_pretrained(model, model_path)
135
+ print(f"Merging weights")
136
+ model = model.merge_and_unload()
137
+ print('Convert to FP16...')
138
+ model.to(torch.float16)
139
+ else:
140
+ use_fast = False
141
+ if 'mpt' in model_name.lower():
142
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
143
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
144
+ else:
145
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
146
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
147
+
148
+ image_processor = None
149
+
150
+ if 'seagull' in model_name.lower() or True:
151
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
152
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
153
+ if mm_use_im_patch_token:
154
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
155
+ if mm_use_im_start_end:
156
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
157
+ model.resize_token_embeddings(len(tokenizer))
158
+
159
+ vision_tower = model.get_vision_tower()
160
+ if not vision_tower.is_loaded:
161
+ vision_tower.load_model(device_map=device_map)
162
+ if device_map != 'auto':
163
+ vision_tower.to(device=device_map, dtype=torch.float16)
164
+ image_processor = vision_tower.image_processor
165
+
166
+ if hasattr(model.config, "max_sequence_length"):
167
+ context_len = model.config.max_sequence_length
168
+ else:
169
+ context_len = 2048
170
+
171
+ return tokenizer, model, image_processor, context_len
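A usage sketch for this loader. Both paths below are placeholders rather than values from the commit. Note that the `or True` guards mean the LoRA branch is always taken, so a base LLM checkpoint has to be supplied alongside the SEAGULL LoRA weights:

```python
from seagull.builder import load_pretrained_model
from seagull.mm_utils import get_model_name_from_path

model_path = '/path/to/seagull-lora-checkpoint'   # placeholder: directory with the LoRA + non_lora_trainables.bin
model_base = '/path/to/base-llama-model'          # placeholder: the base LLM the LoRA was trained on

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=model_base,
    model_name=get_model_name_from_path(model_path),
)
print(type(model).__name__, context_len)
```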
seagull/constants.py ADDED
@@ -0,0 +1,12 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
seagull/conversation.py ADDED
@@ -0,0 +1,381 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class Conversation:
17
+ """A class that keeps all conversation history."""
18
+ system: str
19
+ roles: List[str]
20
+ messages: List[List[str]]
21
+ offset: int
22
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23
+ sep: str = "###"
24
+ sep2: str = None
25
+ version: str = "Unknown"
26
+
27
+ skip_next: bool = False
28
+
29
+ def get_prompt(self):
30
+ messages = self.messages
31
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
32
+ messages = self.messages.copy()
33
+ init_role, init_msg = messages[0].copy()
34
+ init_msg = init_msg[0].replace("<image>", "").strip()
35
+ if 'mmtag' in self.version:
36
+ messages[0] = (init_role, init_msg)
37
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
38
+ messages.insert(1, (self.roles[1], "Received."))
39
+ else:
40
+ messages[0] = (init_role, "<image>\n" + init_msg)
41
+
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in messages:
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ ret = self.system + seps[0]
54
+ for i, (role, message) in enumerate(messages):
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + ": " + message + seps[i % 2]
59
+ else:
60
+ ret += role + ":"
61
+ elif self.sep_style == SeparatorStyle.MPT:
62
+ ret = self.system + self.sep
63
+ for role, message in messages:
64
+ if message:
65
+ if type(message) is tuple:
66
+ message, _, _ = message
67
+ ret += role + message + self.sep
68
+ else:
69
+ ret += role
70
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
71
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
72
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
73
+ ret = ""
74
+
75
+ for i, (role, message) in enumerate(messages):
76
+ if i == 0:
77
+ assert message, "first message should not be none"
78
+ assert role == self.roles[0], "first message should come from user"
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i == 0: message = wrap_sys(self.system) + message
83
+ if i % 2 == 0:
84
+ message = wrap_inst(message)
85
+ ret += self.sep + message
86
+ else:
87
+ ret += " " + message + " " + self.sep2
88
+ else:
89
+ ret += ""
90
+ ret = ret.lstrip(self.sep)
91
+ elif self.sep_style == SeparatorStyle.PLAIN:
92
+ seps = [self.sep, self.sep2]
93
+ ret = self.system
94
+ for i, (role, message) in enumerate(messages):
95
+ if message:
96
+ if type(message) is tuple:
97
+ message, _, _ = message
98
+ ret += message + seps[i % 2]
99
+ else:
100
+ ret += ""
101
+ else:
102
+ raise ValueError(f"Invalid style: {self.sep_style}")
103
+
104
+ return ret
105
+
106
+ def append_message(self, role, message):
107
+ self.messages.append([role, message])
108
+
109
+ def get_images(self, return_pil=False):
110
+ images = []
111
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
112
+ if i % 2 == 0:
113
+ if type(msg) is tuple:
114
+ import base64
115
+ from io import BytesIO
116
+ from PIL import Image
117
+ msg, image, image_process_mode = msg
118
+ if image_process_mode == "Pad":
119
+ def expand2square(pil_img, background_color=(122, 116, 104)):
120
+ width, height = pil_img.size
121
+ if width == height:
122
+ return pil_img
123
+ elif width > height:
124
+ result = Image.new(pil_img.mode, (width, width), background_color)
125
+ result.paste(pil_img, (0, (width - height) // 2))
126
+ return result
127
+ else:
128
+ result = Image.new(pil_img.mode, (height, height), background_color)
129
+ result.paste(pil_img, ((height - width) // 2, 0))
130
+ return result
131
+ image = expand2square(image)
132
+ elif image_process_mode in ["Default", "Crop"]:
133
+ pass
134
+ elif image_process_mode == "Resize":
135
+ image = image.resize((336, 336))
136
+ else:
137
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
138
+ max_hw, min_hw = max(image.size), min(image.size)
139
+ aspect_ratio = max_hw / min_hw
140
+ max_len, min_len = 800, 400
141
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
142
+ longest_edge = int(shortest_edge * aspect_ratio)
143
+ W, H = image.size
144
+ if longest_edge != max(image.size):
145
+ if H > W:
146
+ H, W = longest_edge, shortest_edge
147
+ else:
148
+ H, W = shortest_edge, longest_edge
149
+ image = image.resize((W, H))
150
+ if return_pil:
151
+ images.append(image)
152
+ else:
153
+ buffered = BytesIO()
154
+ image.save(buffered, format="PNG")
155
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
156
+ images.append(img_b64_str)
157
+ return images
158
+
159
+ def to_gradio_chatbot(self):
160
+ ret = []
161
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
162
+ if i % 2 == 0:
163
+ if type(msg) is tuple:
164
+ import base64
165
+ from io import BytesIO
166
+ msg, image, image_process_mode = msg
167
+ max_hw, min_hw = max(image.size), min(image.size)
168
+ aspect_ratio = max_hw / min_hw
169
+ max_len, min_len = 800, 400
170
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
171
+ longest_edge = int(shortest_edge * aspect_ratio)
172
+ W, H = image.size
173
+ if H > W:
174
+ H, W = longest_edge, shortest_edge
175
+ else:
176
+ H, W = shortest_edge, longest_edge
177
+ image = image.resize((W, H))
178
+ buffered = BytesIO()
179
+ image.save(buffered, format="JPEG")
180
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
181
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
182
+ msg = img_str + msg.replace('<image>', '').strip()
183
+ ret.append([msg, None])
184
+ else:
185
+ ret.append([msg, None])
186
+ else:
187
+ ret[-1][-1] = msg
188
+ return ret
189
+
190
+ def copy(self):
191
+ return Conversation(
192
+ system=self.system,
193
+ roles=self.roles,
194
+ messages=[[x, y] for x, y in self.messages],
195
+ offset=self.offset,
196
+ sep_style=self.sep_style,
197
+ sep=self.sep,
198
+ sep2=self.sep2,
199
+ version=self.version)
200
+
201
+ def dict(self):
202
+ if len(self.get_images()) > 0:
203
+ return {
204
+ "system": self.system,
205
+ "roles": self.roles,
206
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
207
+ "offset": self.offset,
208
+ "sep": self.sep,
209
+ "sep2": self.sep2,
210
+ }
211
+ return {
212
+ "system": self.system,
213
+ "roles": self.roles,
214
+ "messages": self.messages,
215
+ "offset": self.offset,
216
+ "sep": self.sep,
217
+ "sep2": self.sep2,
218
+ }
219
+
220
+
221
+ conv_vicuna_v0 = Conversation(
222
+ system="A chat between a curious human and an artificial intelligence assistant. "
223
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
224
+ roles=("Human", "Assistant"),
225
+ messages=(
226
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
227
+ ("Assistant",
228
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
229
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
230
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
231
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
232
+ "renewable and non-renewable energy sources:\n"
233
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
234
+ "energy sources are finite and will eventually run out.\n"
235
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
236
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
237
+ "and other negative effects.\n"
238
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
239
+ "have lower operational costs than non-renewable sources.\n"
240
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
241
+ "locations than non-renewable sources.\n"
242
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
243
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
244
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
245
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
246
+ ),
247
+ offset=2,
248
+ sep_style=SeparatorStyle.SINGLE,
249
+ sep="###",
250
+ )
251
+
252
+ conv_vicuna_v1 = Conversation(
253
+ system="A chat between a curious user and an artificial intelligence assistant. "
254
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
255
+ roles=("USER", "ASSISTANT"),
256
+ version="v1",
257
+ messages=(),
258
+ offset=0,
259
+ sep_style=SeparatorStyle.TWO,
260
+ sep=" ",
261
+ sep2="</s>",
262
+ )
263
+
264
+ conv_llama_2 = Conversation(
265
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
266
+
267
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="llama_v2",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.LLAMA_2,
273
+ sep="<s>",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_seagull_llama_2 = Conversation(
278
+ system="You are a helpful language and vision assistant. "
279
+ "You are able to understand the visual content that the user provides, "
280
+ "and assist the user with a variety of tasks using natural language.",
281
+ roles=("USER", "ASSISTANT"),
282
+ version="llama_v2",
283
+ messages=(),
284
+ offset=0,
285
+ sep_style=SeparatorStyle.LLAMA_2,
286
+ sep="<s>",
287
+ sep2="</s>",
288
+ )
289
+
290
+ conv_mpt = Conversation(
291
+ system="""<|im_start|>system
292
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
293
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
294
+ version="mpt",
295
+ messages=(),
296
+ offset=0,
297
+ sep_style=SeparatorStyle.MPT,
298
+ sep="<|im_end|>",
299
+ )
300
+
301
+ conv_seagull_plain = Conversation(
302
+ system="",
303
+ roles=("", ""),
304
+ messages=(
305
+ ),
306
+ offset=0,
307
+ sep_style=SeparatorStyle.PLAIN,
308
+ sep="\n",
309
+ )
310
+
311
+ conv_seagull_v0 = Conversation(
312
+ system="A chat between a curious human and an artificial intelligence assistant. "
313
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
314
+ roles=("Human", "Assistant"),
315
+ messages=(
316
+ ),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.SINGLE,
319
+ sep="###",
320
+ )
321
+
322
+ conv_seagull_v0_mmtag = Conversation(
323
+ system="A chat between a curious user and an artificial intelligence assistant. "
324
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
325
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
326
+ roles=("Human", "Assistant"),
327
+ messages=(
328
+ ),
329
+ offset=0,
330
+ sep_style=SeparatorStyle.SINGLE,
331
+ sep="###",
332
+ version="v0_mmtag",
333
+ )
334
+
335
+ conv_seagull_v1 = Conversation(
336
+ system="A chat between a curious human and an artificial intelligence assistant. "
337
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
338
+ roles=("USER", "ASSISTANT"),
339
+ version="v1",
340
+ messages=(),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.TWO,
343
+ sep=" ",
344
+ sep2="</s>",
345
+ )
346
+
347
+ conv_seagull_v1_mmtag = Conversation(
348
+ system="A chat between a curious user and an artificial intelligence assistant. "
349
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
350
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
351
+ roles=("USER", "ASSISTANT"),
352
+ messages=(),
353
+ offset=0,
354
+ sep_style=SeparatorStyle.TWO,
355
+ sep=" ",
356
+ sep2="</s>",
357
+ version="v1_mmtag",
358
+ )
359
+
360
+ default_conversation = conv_vicuna_v0
361
+ conv_templates = {
362
+ "default": conv_vicuna_v0,
363
+ "v0": conv_vicuna_v0,
364
+ "v1": conv_vicuna_v1,
365
+ "vicuna_v1": conv_vicuna_v1,
366
+ "llama_2": conv_llama_2,
367
+
368
+ "plain": conv_seagull_plain,
369
+ "v0_plain": conv_seagull_plain,
370
+ "seagull_v0": conv_seagull_v0,
371
+ "v0_mmtag": conv_seagull_v0_mmtag,
372
+ "seagull_v1": conv_seagull_v1,
373
+ "v1_mmtag": conv_seagull_v1_mmtag,
374
+ "seagull_llama_2": conv_seagull_llama_2,
375
+
376
+ "mpt": conv_mpt,
377
+ }
378
+
379
+
380
+ if __name__ == "__main__":
381
+ print(default_conversation.get_prompt())
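A short sketch of how the demo consumes these templates (mirroring `demo/seagull_inference.py`); the question text is illustrative:

```python
from seagull.conversation import conv_templates, SeparatorStyle

conv = conv_templates['seagull_v1'].copy()
conv.append_message(conv.roles[0], "<image>\nAnalyze the quality of this region.")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open for generation
prompt = conv.get_prompt()

# The TWO separator style ends each assistant turn with sep2, so that is the stop string.
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
print(prompt)
print('stop string:', repr(stop_str))      # '</s>'
```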
seagull/mm_utils.py ADDED
@@ -0,0 +1,95 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from seagull.constants import IMAGE_TOKEN_INDEX
8
+
9
+
10
+ def load_image_from_base64(image):
11
+ return Image.open(BytesIO(base64.b64decode(image)))
12
+
13
+ def expand2square(pil_img, background_color):
14
+ width, height = pil_img.size
15
+ if width == height:
16
+ return pil_img
17
+ elif width > height:
18
+ result = Image.new(pil_img.mode, (width, width), background_color)
19
+ result.paste(pil_img, (0, (width - height) // 2))
20
+ return result
21
+ else:
22
+ result = Image.new(pil_img.mode, (height, height), background_color)
23
+ result.paste(pil_img, ((height - width) // 2, 0))
24
+ return result
25
+
26
+
27
+ def process_images(images, image_processor, model_cfg):
28
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
29
+ new_images = []
30
+ if image_aspect_ratio == 'pad':
31
+ for image in images:
32
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
33
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
34
+ new_images.append(image)
35
+ else:
36
+ return image_processor(images, return_tensors='pt')['pixel_values']
37
+ if all(x.shape == new_images[0].shape for x in new_images):
38
+ new_images = torch.stack(new_images, dim=0)
39
+ return new_images
40
+
41
+
42
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
43
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
44
+
45
+ def insert_separator(X, sep):
46
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
47
+
48
+ input_ids = []
49
+ offset = 0
50
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
51
+ offset = 1
52
+ input_ids.append(prompt_chunks[0][0])
53
+
54
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
55
+ input_ids.extend(x[offset:])
56
+
57
+ if return_tensors is not None:
58
+ if return_tensors == 'pt':
59
+ return torch.tensor(input_ids, dtype=torch.long)
60
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
61
+ return input_ids
62
+
63
+
64
+ def get_model_name_from_path(model_path):
65
+ model_path = model_path.strip("/")
66
+ model_paths = model_path.split("/")
67
+ if model_paths[-1].startswith('checkpoint-'):
68
+ return model_paths[-2] + "_" + model_paths[-1]
69
+ else:
70
+ return model_paths[-1]
71
+
72
+ class KeywordsStoppingCriteria(StoppingCriteria):
73
+ def __init__(self, keywords, tokenizer, input_ids):
74
+ self.keywords = keywords
75
+ self.keyword_ids = []
76
+ for keyword in keywords:
77
+ cur_keyword_ids = tokenizer(keyword).input_ids
78
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
79
+ cur_keyword_ids = cur_keyword_ids[1:]
80
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
81
+ self.tokenizer = tokenizer
82
+ self.start_len = input_ids.shape[1]
83
+
84
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
85
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
86
+ offset = min(output_ids.shape[1] - self.start_len, 3)
87
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
88
+ for keyword_id in self.keyword_ids:
89
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():  # compare the whole keyword sequence, not an element-wise tensor
90
+ return True
91
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
92
+ for keyword in self.keywords:
93
+ if keyword in outputs:
94
+ return True
95
+ return False
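
For orientation, here is a minimal, illustrative sketch of how these helpers fit together at inference time. The tokenizer path, prompt, and checkpoint path below are placeholders, not part of this commit:

```python
# Illustrative only: assumes a LLaMA-style Hugging Face tokenizer saved locally (path is hypothetical).
from transformers import AutoTokenizer
from seagull.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-tokenizer", use_fast=False)

# '<image>' is spliced out and replaced by IMAGE_TOKEN_INDEX placeholders between the text chunks.
prompt = "<image>\nDescribe the quality of this region."
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0)

# Stop generation as soon as the conversation separator shows up in the decoded tail.
stopping_criteria = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)

print(get_model_name_from_path("checkpoints/SEAGULL-7B/checkpoint-1000"))  # SEAGULL-7B_checkpoint-1000
```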
seagull/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .language_model.seagull_llama import SeagullLlamaForCausalLM, SeagullConfig
seagull/model/__pycache__/Q_A.cpython-310.pyc ADDED
Binary file (957 Bytes). View file
 
seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc ADDED
Binary file (2.32 kB). View file
 
seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc ADDED
Binary file (5.74 kB). View file
 
seagull/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (243 Bytes). View file
 
seagull/model/__pycache__/layer.cpython-310.pyc ADDED
Binary file (8.2 kB). View file
 
seagull/model/__pycache__/layer_osprey.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
seagull/model/__pycache__/osprey_arch.cpython-310.pyc ADDED
Binary file (9.29 kB). View file
 
seagull/model/__pycache__/seagull_arch.cpython-310.pyc ADDED
Binary file (9.26 kB). View file
 
seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc ADDED
Binary file (3.7 kB). View file
 
seagull/model/consolidate.py ADDED
@@ -0,0 +1,26 @@
1
+
2
+ import argparse
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from seagull.model import *
7
+ from seagull.model.utils import auto_upgrade
8
+
9
+
10
+ def consolidate_ckpt(src_path, dst_path):
11
+ print("Loading model")
12
+ auto_upgrade(src_path)
13
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
14
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
15
+ src_model.save_pretrained(dst_path)
16
+ src_tokenizer.save_pretrained(dst_path)
17
+
18
+
19
+ if __name__ == "__main__":
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--src", type=str, required=True)
22
+ parser.add_argument("--dst", type=str, required=True)
23
+
24
+ args = parser.parse_args()
25
+
26
+ consolidate_ckpt(args.src, args.dst)
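
This is a thin command-line wrapper around `save_pretrained`; a hedged sketch of calling it programmatically (the paths are placeholders, and it assumes `seagull.model.utils.auto_upgrade` is present in the full repository):

```python
# Paths are placeholders; consolidate_ckpt simply re-saves the model and tokenizer in standard HF format.
from seagull.model.consolidate import consolidate_ckpt

consolidate_ckpt(src_path="path/to/raw-checkpoint", dst_path="path/to/consolidated-checkpoint")
```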
seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc ADDED
Binary file (3.87 kB). View file
 
seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc ADDED
Binary file (3.82 kB). View file
 
seagull/model/language_model/seagull_llama.py ADDED
@@ -0,0 +1,128 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import CrossEntropyLoss
5
+ from transformers import AutoConfig, AutoModelForCausalLM, \
6
+ LlamaConfig, LlamaModel, LlamaForCausalLM
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from ..seagull_arch import SeagullMetaModel, SeagullMetaForCausalLM
9
+ from ..layer import MaskExtractor
10
+
11
+ class SeagullConfig(LlamaConfig):
12
+ model_type = "seagull"
13
+
14
+ class SeagullLlamaModel(SeagullMetaModel, LlamaModel):
15
+ config_class = SeagullConfig
16
+
17
+ def __init__(self, config: LlamaConfig):
18
+ super(SeagullLlamaModel, self).__init__(config)
19
+
20
+ class SeagullLlamaForCausalLM(LlamaForCausalLM, SeagullMetaForCausalLM):
21
+ config_class = SeagullConfig
22
+
23
+ def __init__(self, config):
24
+ super(LlamaForCausalLM, self).__init__(config)
25
+ self.model = SeagullLlamaModel(config)
26
+
27
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
28
+ self.mask_extractor = MaskExtractor()
29
+
30
+ self.post_init()
31
+
32
+ def get_model(self):
33
+ return self.model
34
+
35
+ def forward(
36
+ self,
37
+ input_ids: torch.LongTensor = None,
38
+ attention_mask: Optional[torch.Tensor] = None,
39
+ img_metas = None,
40
+ masks = None,
41
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
42
+ inputs_embeds: Optional[torch.FloatTensor] = None,
43
+ labels: Optional[torch.LongTensor] = None,
44
+ use_cache: Optional[bool] = None,
45
+ output_attentions: Optional[bool] = None,
46
+ output_hidden_states: Optional[bool] = None,
47
+ images: Optional[torch.FloatTensor] = None,
48
+ preprocessed_img_dict = None,
49
+ return_dict: Optional[bool] = None,
50
+ cropped_img: Optional[torch.FloatTensor] = None,
51
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
52
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
53
+ output_hidden_states = (
54
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
55
+ )
56
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
+
58
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=preprocessed_img_dict, cropped_img=cropped_img)
59
+
60
+ if inputs_embeds is not None:
61
+ inputs_embeds = inputs_embeds.bfloat16()
62
+
63
+ self.model = self.model.bfloat16()
64
+
65
+ outputs = self.model(
66
+ input_ids=input_ids,
67
+ attention_mask=attention_mask,
68
+ past_key_values=past_key_values,
69
+ inputs_embeds=inputs_embeds,
70
+ use_cache=use_cache,
71
+ output_attentions=output_attentions,
72
+ output_hidden_states=output_hidden_states,
73
+ return_dict=return_dict
74
+ )
75
+
76
+ hidden_states = outputs[0]
77
+ self.lm_head = self.lm_head.to(hidden_states.dtype)
78
+ logits = self.lm_head(hidden_states)
79
+
80
+ loss = None
81
+ if labels is not None:
82
+ # Shift so that tokens < n predict n
83
+ shift_logits = logits[..., :-1, :].contiguous()
84
+ shift_labels = labels[..., 1:].contiguous()
85
+ # Flatten the tokens
86
+ loss_fct = CrossEntropyLoss()
87
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
88
+ shift_labels = shift_labels.view(-1)
89
+ # Enable model/pipeline parallelism
90
+ shift_labels = shift_labels.to(shift_logits.device)
91
+ loss = loss_fct(shift_logits, shift_labels)
92
+
93
+ if not return_dict:
94
+ output = (logits,) + outputs[1:]
95
+ return (loss,) + output if loss is not None else output
96
+
97
+ return CausalLMOutputWithPast(
98
+ loss=loss,
99
+ logits=logits,
100
+ past_key_values=outputs.past_key_values,
101
+ hidden_states=outputs.hidden_states,
102
+ attentions=outputs.attentions,
103
+ )
104
+
105
+ def prepare_inputs_for_generation(
106
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
107
+ ):
108
+ if past_key_values:
109
+ input_ids = input_ids[:, -1:]
110
+
111
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
112
+ if inputs_embeds is not None and past_key_values is None:
113
+ model_inputs = {"inputs_embeds": inputs_embeds}
114
+ else:
115
+ model_inputs = {"input_ids": input_ids}
116
+
117
+ model_inputs.update(
118
+ {
119
+ "past_key_values": past_key_values,
120
+ "use_cache": kwargs.get("use_cache"),
121
+ "attention_mask": attention_mask,
122
+ "images": kwargs.get("images", None),
123
+ }
124
+ )
125
+ return model_inputs
126
+
127
+ AutoConfig.register("seagull", SeagullConfig)
128
+ AutoModelForCausalLM.register(SeagullConfig, SeagullLlamaForCausalLM)
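
Because the final two lines register `SeagullConfig` and `SeagullLlamaForCausalLM` with the Auto classes, a checkpoint whose config declares the `seagull` model type loads through the standard Auto API. A minimal sketch, assuming a local checkpoint directory whose config also points at valid vision-tower weights (the path is illustrative):

```python
# Illustrative sketch: the checkpoint path is hypothetical.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import seagull.model  # importing the package runs the AutoConfig/AutoModelForCausalLM registration above

model_path = "path/to/seagull-7b"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
# model.forward(...) expects `masks` and `cropped_img` in addition to the usual input_ids/images.
```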
seagull/model/layer.py ADDED
@@ -0,0 +1,250 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, Type, Any
5
+ from torch import Tensor
6
+ import math
7
+ import numpy as np
8
+ from einops import rearrange
9
+
10
+ class MLP(nn.Module):
11
+
12
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
13
+ num_layers: int) -> None:
14
+ super().__init__()
15
+ self.num_layers = num_layers
16
+ h = [hidden_dim] * (num_layers - 1)
17
+ self.layers = nn.ModuleList(
18
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
19
+
20
+ def forward(self, x):
21
+ for i, layer in enumerate(self.layers):
22
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
23
+ return x
24
+
25
+ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
26
+ def __init__(self, mask_shape=112, embed_dim=1024, out_dim=4096, num_heads=8, mlp_dim=2048, downsample_rate=2, skip_first_layer_pe=False):
27
+ super(MaskExtractor, self).__init__()
28
+ self.mask_shape = mask_shape
29
+ self.mask_pooling = MaskPooling()
30
+ self.feat_linear = nn.Linear(embed_dim, out_dim)
31
+ self.cross_feat_linear = nn.Linear(embed_dim, out_dim)
32
+ self.mask_linear = MLP(mask_shape*mask_shape, embed_dim, out_dim, 3)
33
+
34
+ self.feature_name = ['res2', 'res3', 'res4', 'res5']
35
+
36
+ self.cross_att_res = CrossAttention(
37
+ embedding_dim=embed_dim,
38
+ num_heads=num_heads,
39
+ mlp_dim=mlp_dim,
40
+ downsample_rate=downsample_rate,
41
+ skip_first_layer_pe=skip_first_layer_pe
42
+ )
43
+
44
+ self.res2 = nn.Linear(192, 1024)
45
+ self.res3 = nn.Linear(384, 1024)
46
+ self.res4 = nn.Linear(768, 1024)
47
+ self.res5 = nn.Linear(1536, 1024)
48
+
49
+ self.g_res2 = nn.Linear(16384, 1024) # h * w
50
+ self.g_res3 = nn.Linear(4096, 1024)
51
+ self.g_res4 = nn.Linear(1024, 1024)
52
+ self.g_res5 = nn.Linear(256, 1024)
53
+
54
+ self.final_mlp = nn.Linear(2 * out_dim, out_dim)
55
+
56
+ self.global_vit = nn.Sequential(
57
+ nn.Conv2d(3, 5, 1),
58
+ nn.GELU(),
59
+ nn.AvgPool2d(4, 4),
60
+
61
+ nn.Conv2d(5, 1, 1),
62
+ nn.GELU(),
63
+ nn.AvgPool2d(4, 4),
64
+ )
65
+ self.is_first = 0
66
+
67
+ self.sa = Attention(32 * 32, num_heads) # self-attention
68
+ self.mlp = MLP(32 * 32, 512, out_dim, 3)
69
+
70
+ def cal_global_local(self, mask_feat_raw, feat_new, res, g_res, cross_attention):
71
+ mask_feat_flatten = mask_feat_raw.to(device=res.weight.device, dtype=res.weight.dtype)
72
+ mask_feat = res(mask_feat_flatten) # (b, q, 1024)
73
+
74
+ feat_new = feat_new.to(device=g_res.weight.device, dtype=g_res.weight.dtype)
75
+ all_feat_new = g_res(feat_new) # (b, c, 1024)
76
+ global_mask = cross_attention(mask_feat, all_feat_new)
77
+ return mask_feat, global_mask
78
+
79
+ def forward(self, feats, masks, cropped_img):
80
+ global_features = []
81
+ local_features = []
82
+ num_imgs = len(masks)
83
+
84
+ for idx in range(num_imgs):
85
+ mask = masks[idx].unsqueeze(0).float() #(1, q, h, w)
86
+ cropped_ = cropped_img[idx] # (q, 3, h, w)
87
+
88
+ num_feats = len(self.feature_name)
89
+ mask_feats = mask.new_zeros(num_feats, mask.shape[1], 1024)
90
+ global_masks = mask.new_zeros(num_feats, mask.shape[1], 1024)
91
+
92
+ for i, name in enumerate(self.feature_name):
93
+ feat = feats[name][idx].unsqueeze(0)
94
+ feat = feat.to(mask.dtype)
95
+
96
+ mask_feat_raw = self.mask_pooling(feat, mask)
97
+ feat_new = rearrange(feat, 'b c h w -> b c (h w)')
98
+
99
+ mask_feat, global_mask = self.cal_global_local(mask_feat_raw, feat_new, res=getattr(self, name), g_res=getattr(self, 'g_{}'.format(name)), cross_attention=self.cross_att_res)
100
+
101
+ mask_feats[i] = mask_feat.squeeze(0) # (q, 1024)
102
+ global_masks[i] = global_mask.squeeze(0)
103
+ mask_feats = mask_feats.sum(0) # (1, q, 1024)
104
+ global_masks = global_masks.sum(0) # (1, q, 1024)
105
+ global_masks = global_masks.to(device=self.cross_feat_linear.weight.device, dtype=self.cross_feat_linear.weight.dtype)
106
+ global_masks_linear = self.cross_feat_linear(global_masks)
107
+ mask_feats = mask_feats.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
108
+ mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
109
+
110
+ query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
111
+ global_features.append(query_feat) # global
112
+
113
+ cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
114
+ cropped_feats = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # (q, 1, 32, 32); separate name so the per-image `global_features` list is not overwritten
115
+ cropped_feats = cropped_feats.reshape(-1, 1, 32 * 32) # (q, 1, 32 * 32)
116
+ pos_feat = self.mlp(self.sa(cropped_feats, cropped_feats, cropped_feats).squeeze(1)) # (q, out_dim)
117
+
118
+ local_features.append(pos_feat) #(imgs_num, 1, q, 4096) # local
119
+
120
+ return global_features, local_features
121
+
122
+ class MaskPooling(nn.Module):
123
+ def __init__(self):
124
+ super().__init__()
125
+
126
+ def forward(self, x, mask):
127
+
128
+ if not x.shape[-2:] == mask.shape[-2:]:
129
+ # reshape mask to x
130
+ mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
131
+
132
+ mask = (mask > 0).to(mask.dtype)
133
+ denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
134
+
135
+ mask_pooled_x = torch.einsum(
136
+ "bchw,bqhw->bqc",
137
+ x,
138
+ mask / denorm,
139
+ )
140
+ return mask_pooled_x
141
+
142
+
143
+ class CrossAttention(nn.Module):
144
+ def __init__(
145
+ self,
146
+ embedding_dim: int,
147
+ num_heads: int,
148
+ mlp_dim: int = 2048,
149
+ downsample_rate: int = 2,
150
+ activation: Type[nn.Module] = nn.ReLU,
151
+ skip_first_layer_pe: bool = False
152
+ ) -> None:
153
+ super().__init__()
154
+ self.embedding_dim = embedding_dim
155
+ self.num_heads = num_heads
156
+ self.self_attn = Attention(embedding_dim, num_heads) # self-attention
157
+ self.skip_first_layer_pe = skip_first_layer_pe
158
+ self.norm1 = nn.LayerNorm(embedding_dim)
159
+
160
+ # cross-attention
161
+ self.cross_attn = Attention(embedding_dim, num_heads, downsample_rate=downsample_rate)
162
+ self.norm2 = nn.LayerNorm(embedding_dim)
163
+
164
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) # MLP
165
+
166
+ def forward(self, queries, keys):
167
+ attn_out = self.self_attn(queries, queries, queries)
168
+ queries = queries + attn_out
169
+ queries = self.norm1(queries)
170
+
171
+ attn_out = self.cross_attn(q=queries, k=keys, v=keys)
172
+ queries = attn_out + queries
173
+ queries = self.norm2(queries)
174
+
175
+ # MLP
176
+ mlp_out = self.mlp(queries)
177
+ queries = queries + mlp_out
178
+ return queries
179
+
180
+ class MLPBlock(nn.Module):
181
+ def __init__(
182
+ self,
183
+ embedding_dim: int,
184
+ mlp_dim: int,
185
+ act: Type[nn.Module] = nn.GELU,
186
+ ) -> None:
187
+ super().__init__()
188
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
189
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
190
+ self.act = act()
191
+
192
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
193
+ return self.lin2(self.act(self.lin1(x)))
194
+
195
+ class Attention(nn.Module):
196
+ """
197
+ An attention layer that allows for downscaling the size of the embedding
198
+ after projection to queries, keys, and values.
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ embedding_dim: int,
204
+ num_heads: int,
205
+ downsample_rate: int = 1,
206
+ ) -> None:
207
+ super().__init__()
208
+ self.embedding_dim = embedding_dim
209
+ self.internal_dim = embedding_dim // downsample_rate
210
+ self.num_heads = num_heads
211
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
212
+
213
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
214
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
215
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
216
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
217
+
218
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
219
+ b, n, c = x.shape
220
+ x = x.reshape(b, n, num_heads, c // num_heads)
221
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
222
+
223
+ def _recombine_heads(self, x: Tensor) -> Tensor:
224
+ b, n_heads, n_tokens, c_per_head = x.shape
225
+ x = x.transpose(1, 2)
226
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
227
+
228
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
229
+ # Input projections
230
+ q = self.q_proj(q)
231
+ k = self.k_proj(k)
232
+ v = self.v_proj(v)
233
+
234
+ # Separate into heads
235
+ q = self._separate_heads(q, self.num_heads)
236
+ k = self._separate_heads(k, self.num_heads)
237
+ v = self._separate_heads(v, self.num_heads)
238
+
239
+ # Attention
240
+ _, _, _, c_per_head = q.shape
241
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
242
+ attn = attn / math.sqrt(c_per_head)
243
+ attn = torch.softmax(attn, dim=-1)
244
+
245
+ # Get output
246
+ out = attn @ v
247
+ out = self._recombine_heads(out)
248
+ out = self.out_proj(out)
249
+
250
+ return out
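
To make the tensor shapes in the attention blocks concrete, here is a small self-contained check with random inputs; the dimensions are chosen to match the defaults used by `MaskExtractor` above:

```python
# Shape sanity check for the attention blocks above (random data, no training involved).
import torch
from seagull.model.layer import Attention, CrossAttention

queries = torch.randn(1, 8, 1024)    # (batch, num_region_queries, embedding_dim)
keys = torch.randn(1, 256, 1024)     # (batch, num_feature_tokens, embedding_dim)

attn = Attention(embedding_dim=1024, num_heads=8, downsample_rate=2)
print(attn(queries, keys, keys).shape)   # torch.Size([1, 8, 1024])

cross_attn = CrossAttention(embedding_dim=1024, num_heads=8, mlp_dim=2048)
print(cross_attn(queries, keys).shape)   # torch.Size([1, 8, 1024])
```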
seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (392 Bytes). View file
 
seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc ADDED
Binary file (1.76 kB). View file
 
seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
seagull/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,7 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+
4
+
5
+ def build_vision_tower(vision_tower_cfg, delay_load=False):
6
+
7
+ return CLIPVisionTower(args=vision_tower_cfg)
seagull/model/multimodal_encoder/clip.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+
5
+ from open_clip.model import _build_vision_tower
6
+
7
+
8
+ class CLIP(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ model_name = 'convnext_large'
12
+
13
+ vision_cfg = {'timm_model_name': model_name, 'timm_model_pretrained': False, 'timm_pool': '', 'timm_proj': 'mlp', 'timm_drop': 0.0, 'timm_drop_path': 0.1, 'image_size': 320}
14
+ self.visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg, quick_gelu=False)
15
+
16
+ self.eval()
17
+ self.freeze_everything()
18
+
19
+ def freeze_everything(self):
20
+ for param in self.visual.parameters():
21
+ param.requires_grad = False
22
+
23
+ def extract_features(self, x):
24
+ out = {}
25
+ x = x.to(self.visual.trunk.stem.state_dict()['1.bias'].dtype)
26
+ x = self.visual.trunk.stem(x)
27
+ out['stem'] = x.contiguous()
28
+ for i in range(4):
29
+ x = self.visual.trunk.stages[i](x)
30
+ out[f'res{i+2}'] = x.contiguous()
31
+
32
+ x = self.visual.trunk.norm_pre(x)
33
+ out['clip_vis_dense'] = x.contiguous()
34
+ return out
35
+
36
+ def forward(self, x):
37
+ self.eval()
38
+ with torch.no_grad():
39
+ return self.extract_features(x)
40
+
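
The dict returned by `extract_features` is the feature pyramid that `MaskExtractor` consumes: for a 512×512 input, `res2`–`res5` come out at strides 4/8/16/32 with ConvNeXt-Large channel widths 192/384/768/1536, which is what the `res*`/`g_res*` linear layers in `layer.py` expect. A rough sketch of inspecting those shapes (the tower is randomly initialized here, so the values are meaningless but the shapes are representative):

```python
# Inspect the multi-scale features of the frozen ConvNeXt-Large trunk (random weights in this sketch;
# the demo loads a pretrained state_dict through CLIPVisionTower instead).
import torch
from seagull.model.multimodal_encoder.clip import CLIP

backbone = CLIP()
features = backbone(torch.randn(1, 3, 512, 512))
for name in ['res2', 'res3', 'res4', 'res5']:
    print(name, tuple(features[name].shape))
# Expected: res2 (1, 192, 128, 128), res3 (1, 384, 64, 64),
#           res4 (1, 768, 32, 32),  res5 (1, 1536, 16, 16)
```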
seagull/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPImageProcessor
5
+ from .clip import CLIP
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, args, img_size=512, delay_load=False):
9
+ super().__init__()
10
+
11
+ # test
12
+ if hasattr(args, 'mm_vision_tower'):
13
+ self.clip_model = args.mm_vision_tower
14
+ else: # train
15
+ self.clip_model = args.vision_tower
16
+ self.is_loaded = False
17
+ self.img_size = img_size
18
+
19
+ if not delay_load:
20
+ self.load_model()
21
+
22
+ def load_model(self):
23
+ self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":self.img_size}, resample=3, do_center_crop=True, crop_size={"height": self.img_size, "width": self.img_size},
24
+ do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
25
+ image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
26
+
27
+ self.vision_tower = CLIP()
28
+
29
+ self.vision_tower.load_state_dict(torch.load(self.clip_model), strict=False)
30
+
31
+ self.is_loaded = True
32
+
33
+ @torch.no_grad()
34
+ def forward(self, images):
35
+ if type(images) is list:
36
+ image_features = []
37
+ image_features_dict = []
38
+ for image in images:
39
+ image_feature_dict = self.vision_tower(image.unsqueeze(0))
40
+ image_features_dict.append(image_feature_dict)
41
+ image_feature = image_feature_dict['res4']
42
+ image_feature = image_feature.reshape(*image_feature.shape[:2],-1).permute(0,2,1)
43
+ image_features.append(image_feature)
44
+ else:
45
+ # print(images.device)
46
+ # print(self.vision_tower.device)
47
+ image_features_dict = self.vision_tower(images)
48
+ image_features = image_features_dict['res4']
49
+ image_features = image_features.reshape(*image_features.shape[:2],-1).permute(0,2,1)
50
+
51
+ return image_features, image_features_dict
52
+
53
+ @property
54
+ def dtype(self):
55
+ return self.vision_tower.dtype
56
+
57
+ @property
58
+ def device(self):
59
+ return self.vision_tower.device
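
A hedged sketch of using the tower on its own. `build_vision_tower` only needs an object exposing `mm_vision_tower` (or `vision_tower` for training) with the path to the ConvNeXt-CLIP weights; the checkpoint and image paths below are placeholders:

```python
# Illustrative only: checkpoint and image paths are placeholders, not shipped with this commit.
from types import SimpleNamespace
import torch
from PIL import Image
from seagull.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(mm_vision_tower="path/to/convnext_large_clip.pth")
tower = build_vision_tower(cfg)  # builds CLIPVisionTower and loads the state_dict

image = Image.open("path/to/image.png").convert("RGB")
pixel_values = tower.image_processor.preprocess(image, return_tensors='pt')['pixel_values']

with torch.no_grad():
    image_features, feature_dict = tower(pixel_values)
print(image_features.shape)         # torch.Size([1, 1024, 768]): flattened res4 tokens
print(sorted(feature_dict.keys()))  # ['clip_vis_dense', 'res2', 'res3', 'res4', 'res5', 'stem']
```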
seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.04 kB). View file
 
seagull/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": 'identity'}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels),
25
+ nn.GELU(),
26
+ nn.Linear(channels, channels)
27
+ )
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ mm_hidden_size = getattr(config, 'mm_hidden_size', 768)
35
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
36
+
37
+ if projector_type == 'linear':
38
+ return nn.Linear(mm_hidden_size, config.hidden_size)
39
+
40
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+
49
+ if projector_type == 'identity':
50
+ return IdentityMap()
51
+
52
+ raise ValueError(f'Unknown projector type: {projector_type}')
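
`mm_projector_type` selects a single linear layer, an `mlpNx_gelu` stack, or an identity map. A small sketch of the MLP variant, using a stand-in config object (only the attributes read by `build_vision_projector` are provided):

```python
# Stand-in config: mm_hidden_size matches the vision-tower output width, hidden_size the LLM width.
from types import SimpleNamespace
import torch
from seagull.model.multimodal_projector.builder import build_vision_projector

config = SimpleNamespace(mm_hidden_size=768, hidden_size=4096, mm_projector_type='mlp2x_gelu')
projector = build_vision_projector(config)   # Linear(768->4096) -> GELU -> Linear(4096->4096)

image_tokens = torch.randn(1, 1024, 768)     # e.g. the flattened res4 features from the vision tower
print(projector(image_tokens).shape)         # torch.Size([1, 1024, 4096])
```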
seagull/model/seagull_arch.py ADDED
@@ -0,0 +1,281 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+
5
+ from .multimodal_encoder.builder import build_vision_tower
6
+ from .multimodal_projector.builder import build_vision_projector
7
+
8
+ from seagull.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+
10
+
11
+ class SeagullMetaModel:
12
+
13
+ def __init__(self, config):
14
+ super(SeagullMetaModel, self).__init__(config)
15
+
16
+ if hasattr(config, "mm_vision_tower"):
17
+ self.vision_tower = build_vision_tower(config, delay_load=False)
18
+ self.mm_projector = build_vision_projector(config)
19
+
20
+ def get_vision_tower(self):
21
+ vision_tower = getattr(self, 'vision_tower', None)
22
+ if type(vision_tower) is list:
23
+ vision_tower = vision_tower[0]
24
+ return vision_tower
25
+
26
+ def initialize_vision_modules(self, model_args, fsdp=None):
27
+
28
+ vision_tower = model_args.vision_tower
29
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
30
+
31
+ if not hasattr(self.config, "mm_vision_tower"):
32
+ self.config.mm_vision_tower = vision_tower
33
+
34
+ vision_tower = build_vision_tower(model_args)
35
+
36
+ if fsdp is not None and len(fsdp) > 0:
37
+ self.vision_tower = [vision_tower]  # wrap the freshly built tower when using FSDP
38
+ else:
39
+ self.vision_tower = vision_tower
40
+
41
+ self.config.use_mm_proj = True
42
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
43
+
44
+ self.mm_projector = build_vision_projector(self.config)
45
+
46
+ if pretrain_mm_mlp_adapter is not None:
47
+ print("***********load projector_weights********")
48
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
49
+ def get_w(weights, keyword):
50
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
51
+
52
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
53
+
54
+
55
+
56
+ class SeagullMetaForCausalLM(ABC):
57
+ def __init__(self):
58
+ super(SeagullMetaForCausalLM, self).__init__()
59
+
60
+ @abstractmethod
61
+ def get_model(self):
62
+ pass
63
+
64
+ def get_vision_tower(self):
65
+ return self.get_model().get_vision_tower()
66
+
67
+ def encode_images(self, images):
68
+ image_features, image_features_dict = self.get_model().get_vision_tower()(images)
69
+ self.get_model().mm_projector.to(device=image_features.device, dtype=image_features.dtype)
70
+ image_features = self.get_model().mm_projector(image_features)
71
+ return image_features, image_features_dict
72
+
73
+ def prepare_inputs_labels_for_multimodal(
74
+ self, input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=None, cropped_img=None
75
+ ):
76
+ vision_tower = self.get_vision_tower()
77
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
78
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
79
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
80
+ return input_ids, attention_mask, past_key_values, None, labels
81
+
82
+ if preprocessed_img_dict is not None:
83
+ image_features, image_features_dict = images, preprocessed_img_dict
84
+ else:
85
+ if type(images) is list or images.ndim == 5:
86
+ concat_images = torch.cat([image for image in images], dim=0)
87
+ image_features, image_features_dict = self.encode_images(concat_images)
88
+ split_sizes = [image.shape[0] for image in images]
89
+ image_features = torch.split(image_features, split_sizes, dim=0)
90
+ image_features = [x.flatten(0, 1).to(concat_images.device) for x in image_features]
91
+ else:
92
+ image_features, image_features_dict = self.encode_images(images)
93
+
94
+
95
+ mask_feats, pos_feats = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
96
+
97
+ new_input_embeds = []
98
+ new_labels = [] if labels is not None else None
99
+ cur_image_idx = 0
100
+ for batch_idx, cur_input_ids in enumerate(input_ids):
101
+
102
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
103
+ # multimodal LLM, but the current sample is not multimodal
104
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
105
+ half_len = cur_input_ids.shape[0] // 2
106
+ cur_image_features = image_features[cur_image_idx]
107
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
108
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
109
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
110
+ new_input_embeds.append(cur_input_embeds)
111
+ if labels is not None:
112
+ new_labels.append(labels[batch_idx])
113
+ cur_image_idx += 1
114
+ continue
115
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
116
+ cur_new_input_embeds = []
117
+ if labels is not None:
118
+ cur_labels = labels[batch_idx]
119
+ cur_new_labels = []
120
+ assert cur_labels.shape == cur_input_ids.shape
121
+ while image_token_indices.numel() > 0:
122
+ cur_image_features = image_features[cur_image_idx]
123
+ image_token_start = image_token_indices[0]
124
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
125
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
126
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
127
+ cur_new_input_embeds.append(cur_image_features)
128
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
129
+ if labels is not None:
130
+ cur_new_labels.append(cur_labels[:image_token_start])
131
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
132
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
133
+ cur_labels = cur_labels[image_token_start+2:]
134
+ else:
135
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
136
+ cur_new_input_embeds.append(cur_image_features)
137
+ if labels is not None:
138
+ cur_new_labels.append(cur_labels[:image_token_start])
139
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
140
+ cur_labels = cur_labels[image_token_start+1:]
141
+ cur_image_idx += 1
142
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
143
+ cur_input_ids = cur_input_ids[image_token_start+2:]
144
+ else:
145
+ cur_input_ids = cur_input_ids[image_token_start+1:]
146
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
147
+ if cur_input_ids.numel() > 0:
148
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
149
+ mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
150
+
151
+ _l = 0
152
+ for i, idx in enumerate(mask_idx):
153
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
154
+ ## mask
155
+ cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
156
+ ## pos
157
+ cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
158
+ if labels is not None:
159
+ cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
160
+ _l = idx[0]+2
161
+ if _l< len(cur_input_ids):
162
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]).detach())
163
+
164
+ else:
165
+
166
+ mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
167
+ assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"
168
+
169
+ _l = 0
170
+ for i, idx in enumerate(mask_idx):
171
+ cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
172
+ cur_new_input_embeds.append(cur_raw_new_input_embeds)
173
+ ## mask
174
+ cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
175
+ ## pos
176
+ cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
177
+
178
+ if labels is not None:
179
+ cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
180
+
181
+ _l = idx[0]+2
182
+ if _l< len(cur_input_ids):
183
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]))
184
+
185
+ if labels is not None:
186
+ cur_new_labels.append(cur_labels)
187
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
188
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
189
+
190
+ new_input_embeds.append(cur_new_input_embeds)
191
+ if labels is not None:
192
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
193
+ new_labels.append(cur_new_labels)
194
+
195
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
196
+ max_len = max(x.shape[0] for x in new_input_embeds)
197
+
198
+ new_input_embeds_align = []
199
+ for cur_new_embed in new_input_embeds:
200
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
201
+ new_input_embeds_align.append(cur_new_embed)
202
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
203
+
204
+ if labels is not None:
205
+ new_labels_align = []
206
+ _new_labels = new_labels
207
+ for cur_new_label in new_labels:
208
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
209
+ new_labels_align.append(cur_new_label)
210
+ new_labels = torch.stack(new_labels_align, dim=0)
211
+
212
+ if attention_mask is not None:
213
+ new_attention_mask = []
214
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
215
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
216
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
217
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
218
+ new_attention_mask.append(cur_new_attention_mask)
219
+ attention_mask = torch.stack(new_attention_mask, dim=0)
220
+ assert attention_mask.shape == new_labels.shape
221
+ else:
222
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
223
+ if labels is not None:
224
+ new_labels = torch.stack(new_labels, dim=0)
225
+
226
+ if attention_mask is not None:
227
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
228
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
229
+ assert attention_mask.shape == new_input_embeds.shape[:2]
230
+
231
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
232
+
233
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
234
+ if model_args.mm_use_im_patch_token:
235
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
236
+ self.resize_token_embeddings(len(tokenizer))
237
+
238
+ mask_tokens = ['<global>', '<pos>']
239
+ num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
240
+
241
+ if model_args.mm_use_im_start_end:
242
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
243
+ self.resize_token_embeddings(len(tokenizer))
244
+
245
+ if num_new_tokens > 0:
246
+ input_embeddings = self.get_input_embeddings().weight.data
247
+ output_embeddings = self.get_output_embeddings().weight.data
248
+
249
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
250
+ dim=0, keepdim=True)
251
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
252
+ dim=0, keepdim=True)
253
+
254
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
255
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
256
+
257
+ if model_args.tune_mm_mlp_adapter:
258
+ for p in self.get_input_embeddings().parameters():
259
+ p.requires_grad = True
260
+ for p in self.get_output_embeddings().parameters():
261
+ p.requires_grad = False
262
+
263
+ if model_args.pretrain_mm_mlp_adapter:
264
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
265
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
266
+ assert num_new_tokens == 2
267
+ if input_embeddings.shape == embed_tokens_weight.shape:
268
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
269
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
270
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
271
+ else:
272
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
273
+ elif model_args.mm_use_im_patch_token:
274
+ if model_args.tune_mm_mlp_adapter:
275
+ for p in self.get_input_embeddings().parameters():
276
+ p.requires_grad = False
277
+ for p in self.get_output_embeddings().parameters():
278
+ p.requires_grad = False
279
+
280
+ for m in self.modules():
281
+ m.tokenizer = tokenizer
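
`initialize_vision_tokenizer` registers the `<global>` and `<pos>` tokens and initializes their embeddings from the mean of the existing rows; at input-building time, `prepare_inputs_labels_for_multimodal` replaces each `<global>` position with a (mask-level, crop-level) feature pair from `MaskExtractor`. A tokenizer-only illustration of the token side (the tokenizer path is a placeholder; embedding resizing is handled by the model code above):

```python
# Tokenizer-side illustration only; the tokenizer path is hypothetical.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-tokenizer", use_fast=False)
num_added = tokenizer.add_tokens(['<global>', '<pos>'], special_tokens=True)
print(num_added)  # 2 on a tokenizer that has not seen these tokens before

prompt = "Analyze the quality of the region <global><pos>."
ids = tokenizer(prompt).input_ids
print(tokenizer.convert_tokens_to_ids('<global>') in ids)  # True: this slot is swapped for mask features
```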
seagull/train/__pycache__/seagull_trainer.cpython-310.pyc ADDED
Binary file (8.35 kB). View file