add online demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- README.md +30 -1
- app.py +62 -0
- demo/UI.py +143 -0
- demo/__pycache__/UI.cpython-310.pyc +0 -0
- demo/__pycache__/mask_utils.cpython-310.pyc +0 -0
- demo/__pycache__/sam_inference.cpython-310.pyc +0 -0
- demo/__pycache__/seagull_inference.cpython-310.pyc +0 -0
- demo/mask_utils.py +144 -0
- demo/sam_inference.py +102 -0
- demo/seagull_inference.py +163 -0
- imgs/.DS_Store +0 -0
- imgs/Examples/1.png +0 -0
- imgs/Examples/2.png +0 -0
- seagull/__init__.py +1 -0
- seagull/__pycache__/__init__.cpython-310.pyc +0 -0
- seagull/__pycache__/constants.cpython-310.pyc +0 -0
- seagull/__pycache__/conversation.cpython-310.pyc +0 -0
- seagull/__pycache__/mm_utils.cpython-310.pyc +0 -0
- seagull/__pycache__/utils.cpython-310.pyc +0 -0
- seagull/builder.py +171 -0
- seagull/constants.py +12 -0
- seagull/conversation.py +381 -0
- seagull/mm_utils.py +95 -0
- seagull/model/__init__.py +1 -0
- seagull/model/__pycache__/Q_A.cpython-310.pyc +0 -0
- seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc +0 -0
- seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc +0 -0
- seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc +0 -0
- seagull/model/__pycache__/__init__.cpython-310.pyc +0 -0
- seagull/model/__pycache__/layer.cpython-310.pyc +0 -0
- seagull/model/__pycache__/layer_osprey.cpython-310.pyc +0 -0
- seagull/model/__pycache__/osprey_arch.cpython-310.pyc +0 -0
- seagull/model/__pycache__/seagull_arch.cpython-310.pyc +0 -0
- seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc +0 -0
- seagull/model/consolidate.py +26 -0
- seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc +0 -0
- seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc +0 -0
- seagull/model/language_model/seagull_llama.py +128 -0
- seagull/model/layer.py +250 -0
- seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
- seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc +0 -0
- seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
- seagull/model/multimodal_encoder/builder.py +7 -0
- seagull/model/multimodal_encoder/clip.py +40 -0
- seagull/model/multimodal_encoder/clip_encoder.py +59 -0
- seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
- seagull/model/multimodal_projector/builder.py +52 -0
- seagull/model/seagull_arch.py +281 -0
- seagull/train/__pycache__/seagull_trainer.cpython-310.pyc +0 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.pth
|
2 |
+
*.bin
|
README.md
CHANGED
@@ -9,4 +9,33 @@ app_file: app.py
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
<img src="https://github.com/chencn2020/SEAGULL/raw/main/imgs/Logo/logo.png" alt="SEAGULL" style="height: auto; width: 100%;">
|
13 |
+
|
14 |
+
<div style="display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; width: 100%;">
|
15 |
+
<a href=""><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm-dark.svg" alt="Open in Spaces" style="max-width: 100%; height: auto;"></a>
|
16 |
+
<a href="https://arxiv.org/abs/2411.10161"><img src="https://img.shields.io/badge/Arxiv-2411:10161-red" style="max-width: 100%; height: auto;"></a>
|
17 |
+
<a href="https://hits.seeyoufarm.com"><img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fdatasets%2FZevin2023%2FSEAGULL-100w&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false" style="max-width: 100%; height: auto;"></a>
|
18 |
+
<a href='https://github.com/chencn2020/SEAGULL/'><img src='https://img.shields.io/github/stars/chencn2020/Seagull.svg?style=social' style="max-width: 100%; height: auto;"></a>
|
19 |
+
</div>
|
20 |
+
|
21 |
+
## Acknowledgement 💌
|
22 |
+
<div id="Acknowledgement"></div>
|
23 |
+
- [Osprey](https://github.com/CircleRadon/Osprey) and [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA): We build this repostory based on them.
|
24 |
+
- [RAISE](http://loki.disi.unitn.it/RAISE/): The Dist. images in SEAGULL-100w are constructed based on this dataset.
|
25 |
+
- [SAM](https://segment-anything.com/) and [SEEM](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once): The mask-based ROIs are generated using these two awesome works. And SAM are used to get the segmentation result in the demo.
|
26 |
+
- [TOPIQ](https://github.com/chaofengc/IQA-PyTorch): The quality scores and importance scores for ROIs are generated using this great FR-IQA.
|
27 |
+
|
28 |
+
|
29 |
+
## Citation 🖊️
|
30 |
+
If our work is useful to your research, we will be grateful for you to cite our paper:
|
31 |
+
```
|
32 |
+
@misc{chen2024seagull,
|
33 |
+
title={SEAGULL: No-reference Image Quality Assessment for Regions of Interest via Vision-Language Instruction Tuning},
|
34 |
+
author={Zewen Chen and Juan Wang and Wen Wang and Sunhan Xu and Hang Xiong and Yun Zeng and Jian Guo and Shuxun Wang and Chunfeng Yuan and Bing Li and Weiming Hu},
|
35 |
+
year={2024},
|
36 |
+
eprint={2411.10161},
|
37 |
+
archivePrefix={arXiv},
|
38 |
+
primaryClass={cs.CV},
|
39 |
+
url={https://arxiv.org/abs/2411.10161},
|
40 |
+
}
|
41 |
+
```
|
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from demo.UI import Main_ui
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
def run_command(command):
|
8 |
+
subprocess.check_call([sys.executable, '-m'] + command.split(), shell=False)
|
9 |
+
|
10 |
+
# Install the package in editable mode
|
11 |
+
run_command("pip install -e .")
|
12 |
+
|
13 |
+
# Install NVM (Node Version Manager)
|
14 |
+
run_command("curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh | bash")
|
15 |
+
|
16 |
+
# Source the appropriate shell configuration file
|
17 |
+
run_command("source ~/.bashrc") # You can change to ~/.zshrc based on your shell
|
18 |
+
|
19 |
+
# Install Node.js version 18.16.0
|
20 |
+
run_command("nvm install v18.16.0")
|
21 |
+
|
22 |
+
# Install pnpm (package manager)
|
23 |
+
run_command("curl -fsSL https://get.pnpm.io/install.sh | sh -")
|
24 |
+
|
25 |
+
# Source the shell configuration file again (for pnpm)
|
26 |
+
run_command("source ~/.bashrc") # You can change to ~/.zshrc based on your shell
|
27 |
+
|
28 |
+
# Verify if pnpm was installed correctly
|
29 |
+
run_command("pnpm --version")
|
30 |
+
|
31 |
+
# Clone the Gradio BBox repository
|
32 |
+
run_command("git clone https://github.com/chencn2020/gradio-bbox.git")
|
33 |
+
|
34 |
+
# Change into the cloned repository directory
|
35 |
+
run_command("cd gradio-bbox")
|
36 |
+
|
37 |
+
# Build frontend
|
38 |
+
run_command("bash scripts/build_frontend.sh")
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
# Change back to the previous directory
|
43 |
+
run_command("cd ..")
|
44 |
+
|
45 |
+
# Install the package again in editable mode
|
46 |
+
run_command("pip install -e .")
|
47 |
+
|
48 |
+
# Install Segment Anything repository from GitHub
|
49 |
+
run_command("pip install git+https://github.com/facebookresearch/segment-anything.git")
|
50 |
+
|
51 |
+
# Download the model checkpoint
|
52 |
+
run_command("curl -o ./checkpoints/sam_vit_b_01ec64.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
parser = argparse.ArgumentParser(description='SEAGULL', formatter_class=argparse.RawTextHelpFormatter)
|
57 |
+
parser.add_argument('--model', help='path to seagull model', default='Zevin2023/SEAGULL-7B')
|
58 |
+
parser.add_argument('--example_path', help='path to examples', default='./imgs/Examples')
|
59 |
+
args = parser.parse_args()
|
60 |
+
|
61 |
+
demo = Main_ui(args).load_demo()
|
62 |
+
demo.launch(server_port=7530)
|
demo/UI.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from demo.sam_inference import SAM_Inference
|
4 |
+
from demo.seagull_inference import Seagull
|
5 |
+
from demo.mask_utils import ImageSketcher
|
6 |
+
|
7 |
+
class Main_ui():
|
8 |
+
def __init__(self, args) -> None:
|
9 |
+
self.args = args
|
10 |
+
self.seagull = Seagull(model_path=args.model)
|
11 |
+
|
12 |
+
self.example_list = self.load_example()
|
13 |
+
self.sam = SAM_Inference()
|
14 |
+
# self.sam_predictor = get_sam_predictor()
|
15 |
+
# self.mask_generator = get_mask_generator()
|
16 |
+
|
17 |
+
def load_example(self):
|
18 |
+
examples = []
|
19 |
+
for file in sorted(os.listdir(self.args.example_path)):
|
20 |
+
examples.append([os.path.join(self.args.example_path, file)])
|
21 |
+
return examples
|
22 |
+
|
23 |
+
def load_demo(self):
|
24 |
+
with gr.Blocks() as demo:
|
25 |
+
preprocessed_img = gr.State(value=None)
|
26 |
+
binary_mask = gr.State(value=None)
|
27 |
+
|
28 |
+
with gr.Row():
|
29 |
+
gr.Markdown("""
|
30 |
+
<img src="https://github.com/chencn2020/SEAGULL/raw/main/imgs/Logo/logo.png" alt="SEAGULL" style="height: auto; width: 100%; margin-bottom: 3%;">
|
31 |
+
|
32 |
+
## 🔔 Usage
|
33 |
+
|
34 |
+
Firstly, you need to upload an image and choose the analyse types **(quality score, importance score and distortion analysis)**.
|
35 |
+
|
36 |
+
Then you can click **(points)** or pull a frame **(bbox)** on the image to indicate the region of interest (ROIs).
|
37 |
+
|
38 |
+
After that, this demo process the following steps:
|
39 |
+
|
40 |
+
> 1. SAM extracts the mask-based ROIs based on your clicked points or frame.
|
41 |
+
|
42 |
+
> 2. Based on the uploaded image and mask-based ROIs, SEAGULL analyses the quality of the ROIs.
|
43 |
+
|
44 |
+
""")
|
45 |
+
|
46 |
+
with gr.TabItem("Mask-based ROIs (Points)"):
|
47 |
+
with gr.Row():
|
48 |
+
input_image_ponit = gr.Image(type="numpy", label='Input image', height=512) # input image
|
49 |
+
output_mask_ponit = gr.Image(label='Mask-based ROI', height=512) # output binary mask
|
50 |
+
|
51 |
+
with gr.Row():
|
52 |
+
output_mask_point_on_img = gr.Image(label='Mask on image', height=512) # mask on image for better view
|
53 |
+
|
54 |
+
with gr.Column():
|
55 |
+
radio_point = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
|
56 |
+
output_text_point = gr.Textbox(label='Analysis Results')
|
57 |
+
point_seg_button = gr.Button('Analysis')
|
58 |
+
|
59 |
+
point_example = gr.Dataset(label='Examples', components=[input_image_ponit], samples=self.example_list)
|
60 |
+
|
61 |
+
with gr.TabItem("Mask-based ROIs (BBox)"):
|
62 |
+
with gr.Row():
|
63 |
+
input_image_BBOX = ImageSketcher(type="numpy", label='Input image', height=512)
|
64 |
+
output_mask_BBOX = gr.Image(label='Mask-based ROI', height=512)
|
65 |
+
|
66 |
+
with gr.Row():
|
67 |
+
output_BBOX_mask_on_img = gr.Image(label='Mask on image', height=512)
|
68 |
+
|
69 |
+
with gr.Column():
|
70 |
+
radio_BBOX = gr.Radio(label='Analysis type', choices=['Quality Score', 'Importance Score', 'Distortion Analysis'], value='Quality Score')
|
71 |
+
output_text_BBOX = gr.Textbox(label='ROI Quality Analysis')
|
72 |
+
box_seg_button = gr.Button('Generate mask and analysis')
|
73 |
+
box_analyse_button = gr.Button('Analysis')
|
74 |
+
|
75 |
+
BBOX_example = gr.Dataset(label='Examples', components=[input_image_BBOX], samples=self.example_list)
|
76 |
+
|
77 |
+
# click point
|
78 |
+
input_image_ponit.upload(
|
79 |
+
self.seagull.init_image,
|
80 |
+
[input_image_ponit],
|
81 |
+
[preprocessed_img, input_image_ponit, input_image_BBOX]
|
82 |
+
)
|
83 |
+
|
84 |
+
point_example.click(
|
85 |
+
self.seagull.init_image,
|
86 |
+
[point_example],
|
87 |
+
[preprocessed_img, input_image_ponit, input_image_BBOX]
|
88 |
+
)
|
89 |
+
|
90 |
+
# after clicking on the image
|
91 |
+
input_image_ponit.select(
|
92 |
+
self.sam.img_select_point,
|
93 |
+
[preprocessed_img],
|
94 |
+
[input_image_ponit, output_mask_ponit, output_mask_point_on_img, binary_mask]
|
95 |
+
).then(
|
96 |
+
self.seagull.seagull_predict,
|
97 |
+
[preprocessed_img, binary_mask, radio_point],
|
98 |
+
[output_text_point]
|
99 |
+
)
|
100 |
+
|
101 |
+
point_seg_button.click(
|
102 |
+
self.seagull.seagull_predict,
|
103 |
+
[preprocessed_img, binary_mask, radio_point],
|
104 |
+
[output_text_point]
|
105 |
+
)
|
106 |
+
|
107 |
+
# draw frame
|
108 |
+
input_image_BBOX.upload(
|
109 |
+
self.seagull.init_image,
|
110 |
+
[input_image_BBOX],
|
111 |
+
[preprocessed_img, input_image_ponit, input_image_BBOX]
|
112 |
+
)
|
113 |
+
|
114 |
+
BBOX_example.click(
|
115 |
+
self.seagull.init_image,
|
116 |
+
[BBOX_example],
|
117 |
+
[preprocessed_img, input_image_ponit, input_image_BBOX]
|
118 |
+
)
|
119 |
+
|
120 |
+
# after drawing a frame on the image
|
121 |
+
input_image_BBOX.select(
|
122 |
+
self.sam.gen_box_seg,
|
123 |
+
[input_image_BBOX],
|
124 |
+
[output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
|
125 |
+
)
|
126 |
+
|
127 |
+
box_seg_button.click(
|
128 |
+
self.sam.gen_box_seg,
|
129 |
+
[input_image_BBOX],
|
130 |
+
[output_mask_BBOX, output_BBOX_mask_on_img, binary_mask]
|
131 |
+
).then(
|
132 |
+
self.seagull.seagull_predict,
|
133 |
+
[preprocessed_img, binary_mask, radio_BBOX],
|
134 |
+
[output_text_BBOX]
|
135 |
+
)
|
136 |
+
|
137 |
+
box_analyse_button.click(
|
138 |
+
self.seagull.seagull_predict,
|
139 |
+
[preprocessed_img, binary_mask, radio_BBOX],
|
140 |
+
[output_text_BBOX]
|
141 |
+
)
|
142 |
+
|
143 |
+
return demo
|
demo/__pycache__/UI.cpython-310.pyc
ADDED
Binary file (4.2 kB). View file
|
|
demo/__pycache__/mask_utils.cpython-310.pyc
ADDED
Binary file (4.72 kB). View file
|
|
demo/__pycache__/sam_inference.cpython-310.pyc
ADDED
Binary file (3.51 kB). View file
|
|
demo/__pycache__/seagull_inference.cpython-310.pyc
ADDED
Binary file (5.48 kB). View file
|
|
demo/mask_utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
class ImageSketcher(gr.Image):
|
8 |
+
"""
|
9 |
+
Code is from https://github.com/jshilong/GPT4RoI/blob/7c157b5f33914f21cfbc804fb301d3ce06324193/gpt4roi/app.py#L365
|
10 |
+
|
11 |
+
Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
|
12 |
+
"""
|
13 |
+
|
14 |
+
is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
|
15 |
+
|
16 |
+
def __init__(self, **kwargs):
|
17 |
+
super().__init__(tool='boxes', **kwargs)
|
18 |
+
|
19 |
+
def preprocess(self, x):
|
20 |
+
if x is None:
|
21 |
+
return x
|
22 |
+
if self.tool == 'boxes' and self.source in ['upload', 'webcam']:
|
23 |
+
if isinstance(x, str):
|
24 |
+
x = {'image': x, 'boxes': []}
|
25 |
+
else:
|
26 |
+
assert isinstance(x, dict)
|
27 |
+
assert isinstance(x['image'], str)
|
28 |
+
assert isinstance(x['boxes'], list)
|
29 |
+
x = super().preprocess(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
def process_mask_to_show(mask):
|
33 |
+
'''
|
34 |
+
Process the mask to show on the gradio.Image
|
35 |
+
'''
|
36 |
+
mask = np.array(mask > 0.1, dtype=np.uint8) * 255
|
37 |
+
mask_stacked = np.stack([mask] * 3, axis=-1)
|
38 |
+
|
39 |
+
return mask_stacked
|
40 |
+
|
41 |
+
def img_add_masks(img_, colored_mask, mask, linewidth=2):
|
42 |
+
if type(img_) is np.ndarray:
|
43 |
+
img = Image.fromarray(img_, mode='RGB').convert('RGBA')
|
44 |
+
else:
|
45 |
+
img = img_.copy()
|
46 |
+
h, w = img.height, img.width
|
47 |
+
# contour
|
48 |
+
temp = np.zeros((h, w, 1))
|
49 |
+
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
50 |
+
cv2.drawContours(temp, contours, -1, (255, 255, 255), linewidth)
|
51 |
+
color = np.array([1, 1, 1, 1])
|
52 |
+
contour_mask = temp * color.reshape(1, 1, -1)
|
53 |
+
|
54 |
+
overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
|
55 |
+
img.paste(overlay_inner, (0, 0), overlay_inner)
|
56 |
+
|
57 |
+
overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
|
58 |
+
img.paste(overlay_contour, (0, 0), overlay_contour)
|
59 |
+
return img
|
60 |
+
|
61 |
+
def gen_colored_masks(
|
62 |
+
annotation,
|
63 |
+
random_color=False,
|
64 |
+
):
|
65 |
+
"""
|
66 |
+
Code is largely based on https://github.com/CASIA-IVA-Lab/FastSAM/blob/4d153e909f0ad9c8ecd7632566e5a24e21cf0071/utils/tools_gradio.py#L130
|
67 |
+
"""
|
68 |
+
device = annotation.device
|
69 |
+
mask_sum = annotation.shape[0]
|
70 |
+
height = annotation.shape[1]
|
71 |
+
weight = annotation.shape[2]
|
72 |
+
areas = torch.sum(annotation, dim=(1, 2))
|
73 |
+
sorted_indices = torch.argsort(areas, descending=False)
|
74 |
+
annotation = annotation[sorted_indices]
|
75 |
+
|
76 |
+
index = (annotation != 0).to(torch.long).argmax(dim=0)
|
77 |
+
if random_color:
|
78 |
+
color = torch.rand((mask_sum, 1, 1, 3)).to(device)
|
79 |
+
else:
|
80 |
+
color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
|
81 |
+
[30 / 255, 144 / 255, 255 / 255]
|
82 |
+
).to(device)
|
83 |
+
transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
|
84 |
+
visual = torch.cat([color, transparency], dim=-1)
|
85 |
+
mask_image = torch.unsqueeze(annotation, -1) * visual
|
86 |
+
|
87 |
+
mask = torch.zeros((height, weight, 4)).to(device)
|
88 |
+
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
|
89 |
+
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
90 |
+
|
91 |
+
mask[h_indices, w_indices, :] = mask_image[indices]
|
92 |
+
mask_cpu = mask.cpu().numpy()
|
93 |
+
|
94 |
+
return mask_cpu, sorted_indices
|
95 |
+
|
96 |
+
def mask_foreground(mask, trans=0.6, random_color=True):
|
97 |
+
if random_color:
|
98 |
+
color = np.concatenate([np.random.random(3) * 255, np.array([trans * 255])], axis=0)
|
99 |
+
else:
|
100 |
+
color = np.array([30, 144, 255, trans * 255])
|
101 |
+
h, w = mask.shape[-2:]
|
102 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
103 |
+
|
104 |
+
return mask_image
|
105 |
+
|
106 |
+
|
107 |
+
def mask_background(mask, trans=0.5):
|
108 |
+
h, w = mask.shape[-2:]
|
109 |
+
mask_image = (1 - mask.reshape(h, w, 1)) * np.array([0, 0, 0, trans * 255])
|
110 |
+
|
111 |
+
return mask_image
|
112 |
+
|
113 |
+
|
114 |
+
def mask_select_point(all_masks, output_mask_2_raw, mask_order, evt: gr.SelectData):
|
115 |
+
h, w = output_mask_2_raw.height, output_mask_2_raw.width
|
116 |
+
pointed_mask = None
|
117 |
+
for i in range(len(mask_order)):
|
118 |
+
idx = mask_order[i]
|
119 |
+
msk = all_masks[idx]
|
120 |
+
if msk[evt.index[1], evt.index[0]] == 1:
|
121 |
+
pointed_mask = msk.copy()
|
122 |
+
break
|
123 |
+
|
124 |
+
if pointed_mask is not None:
|
125 |
+
contours, hierarchy = cv2.findContours(pointed_mask.astype("uint8"), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
126 |
+
ret = output_mask_2_raw.copy()
|
127 |
+
|
128 |
+
temp = np.zeros((h, w, 1))
|
129 |
+
contours, _ = cv2.findContours(msk.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
130 |
+
cv2.drawContours(temp, contours, -1, (255, 255, 255), 3)
|
131 |
+
color = np.array([1, 1, 1, 1])
|
132 |
+
contour_mask = temp * color.reshape(1, 1, -1)
|
133 |
+
|
134 |
+
colored_mask = mask_background(pointed_mask)
|
135 |
+
|
136 |
+
overlay_inner = Image.fromarray(colored_mask.astype(np.uint8), 'RGBA')
|
137 |
+
ret.paste(overlay_inner, (0, 0), overlay_inner)
|
138 |
+
|
139 |
+
overlay_contour = Image.fromarray(contour_mask.astype(np.uint8), 'RGBA')
|
140 |
+
ret.paste(overlay_contour, (0, 0), overlay_contour)
|
141 |
+
|
142 |
+
return ret, pointed_mask
|
143 |
+
else:
|
144 |
+
return output_mask_2_raw, None
|
demo/sam_inference.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
|
6 |
+
import gradio as gr
|
7 |
+
import cv2
|
8 |
+
from demo.mask_utils import *
|
9 |
+
|
10 |
+
class SAM_Inference:
|
11 |
+
def __init__(self, model_type='vit_b', device='cuda') -> None:
|
12 |
+
models = {
|
13 |
+
'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
|
14 |
+
'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
|
15 |
+
'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
|
16 |
+
}
|
17 |
+
|
18 |
+
sam = sam_model_registry[model_type](checkpoint=models[model_type])
|
19 |
+
sam = sam.to(device)
|
20 |
+
|
21 |
+
self.predictor = SamPredictor(sam)
|
22 |
+
self.mask_generator = SamAutomaticMaskGenerator(model=sam)
|
23 |
+
|
24 |
+
def img_select_point(self, original_img: np.ndarray, evt: gr.SelectData):
|
25 |
+
img = original_img.copy()
|
26 |
+
sel_pix = [(evt.index, 1)] # append the foreground_point
|
27 |
+
|
28 |
+
masks = self.run_inference(original_img, sel_pix)
|
29 |
+
for point, label in sel_pix:
|
30 |
+
cv2.circle(img, point, 5, (240, 240, 240), -1, 0)
|
31 |
+
cv2.circle(img, point, 5, (30, 144, 255), 2, 0)
|
32 |
+
|
33 |
+
mask = masks[0][0]
|
34 |
+
colored_mask = mask_foreground(mask)
|
35 |
+
res = img_add_masks(original_img, colored_mask, mask)
|
36 |
+
return img, process_mask_to_show(mask), res, mask
|
37 |
+
|
38 |
+
def gen_box_seg(self, inp):
|
39 |
+
if inp is None:
|
40 |
+
raise gr.Error("Please upload an image first!")
|
41 |
+
image = inp['image']
|
42 |
+
if len(inp['boxes']) == 0:
|
43 |
+
raise gr.Error("Please clear the raw boxes and draw a box first!")
|
44 |
+
boxes = inp['boxes'][-1]
|
45 |
+
|
46 |
+
input_box = np.array([boxes[0], boxes[1], boxes[2], boxes[3]]).astype(int)
|
47 |
+
|
48 |
+
masks = self.predict_box(image, input_box)
|
49 |
+
|
50 |
+
mask = masks[0][0]
|
51 |
+
colored_mask = mask_foreground(mask)
|
52 |
+
res = img_add_masks(image, colored_mask, mask)
|
53 |
+
|
54 |
+
return process_mask_to_show(mask), res, mask
|
55 |
+
|
56 |
+
def run_inference(self, input_x, selected_points):
|
57 |
+
if len(selected_points) == 0:
|
58 |
+
return []
|
59 |
+
|
60 |
+
self.predictor.set_image(input_x)
|
61 |
+
|
62 |
+
points = torch.Tensor(
|
63 |
+
[p for p, _ in selected_points]
|
64 |
+
).to(self.predictor.device).unsqueeze(0)
|
65 |
+
|
66 |
+
labels = torch.Tensor(
|
67 |
+
[int(l) for _, l in selected_points]
|
68 |
+
).to(self.predictor.device).unsqueeze(0)
|
69 |
+
|
70 |
+
transformed_points = self.predictor.transform.apply_coords_torch(
|
71 |
+
points, input_x.shape[:2])
|
72 |
+
|
73 |
+
# predict segmentation according to the boxes
|
74 |
+
masks, scores, logits = self.predictor.predict_torch(
|
75 |
+
point_coords=transformed_points,
|
76 |
+
point_labels=labels,
|
77 |
+
multimask_output=False,
|
78 |
+
)
|
79 |
+
masks = masks.cpu().detach().numpy()
|
80 |
+
|
81 |
+
gc.collect()
|
82 |
+
torch.cuda.empty_cache()
|
83 |
+
|
84 |
+
return masks
|
85 |
+
|
86 |
+
def predict_box(self, input_x, input_box):
|
87 |
+
self.predictor.set_image(input_x)
|
88 |
+
|
89 |
+
input_boxes = torch.tensor(input_box[None, :], device=self.predictor.device)
|
90 |
+
transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])
|
91 |
+
|
92 |
+
masks, _, _ = self.predictor.predict_torch(
|
93 |
+
point_coords=None,
|
94 |
+
point_labels=None,
|
95 |
+
boxes=transformed_boxes,
|
96 |
+
multimask_output=False
|
97 |
+
)
|
98 |
+
masks = masks.cpu().detach().numpy()
|
99 |
+
|
100 |
+
gc.collect()
|
101 |
+
torch.cuda.empty_cache()
|
102 |
+
return masks
|
demo/seagull_inference.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from seagull.utils import disable_torch_init
|
3 |
+
from transformers import AutoTokenizer, CLIPImageProcessor
|
4 |
+
from seagull.model.language_model.seagull_llama import SeagullLlamaForCausalLM
|
5 |
+
from seagull.mm_utils import tokenizer_image_token
|
6 |
+
from seagull.conversation import conv_templates, SeparatorStyle
|
7 |
+
from seagull.constants import IMAGE_TOKEN_INDEX
|
8 |
+
from seagull.train.train import DataArguments
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import cv2
|
14 |
+
from typing import List
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
class Seagull():
|
18 |
+
def __init__(self, model_path, device='cuda'):
|
19 |
+
disable_torch_init()
|
20 |
+
model_path = os.path.expanduser(model_path)
|
21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=2048, padding_side="right", use_fast=True)
|
22 |
+
self.model = SeagullLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16,).to(device)
|
23 |
+
self.tokenizer.pad_token = self.tokenizer.unk_token
|
24 |
+
|
25 |
+
self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
|
26 |
+
do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
|
27 |
+
image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
|
28 |
+
|
29 |
+
spi_tokens = ['<global>', '<local>']
|
30 |
+
self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
|
31 |
+
|
32 |
+
for m in self.model.modules():
|
33 |
+
m.tokenizer = self.tokenizer
|
34 |
+
|
35 |
+
vision_tower = self.model.get_vision_tower()
|
36 |
+
if not vision_tower.is_loaded:
|
37 |
+
vision_tower.load_model()
|
38 |
+
vision_tower.to(dtype=torch.float16, device=device)
|
39 |
+
|
40 |
+
begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
|
41 |
+
|
42 |
+
instruction = {
|
43 |
+
'distortion analysis': 'Provide the distortion type of this region.',
|
44 |
+
'quality score': 'Analyze the quality of this region.',
|
45 |
+
'importance score': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
|
46 |
+
}
|
47 |
+
|
48 |
+
self.ids_input = {}
|
49 |
+
for ins_type, ins in instruction.items():
|
50 |
+
conv = conv_templates['seagull_v1'].copy()
|
51 |
+
qs = begin_str + ins
|
52 |
+
conv.append_message(conv.roles[0], qs)
|
53 |
+
conv.append_message(conv.roles[1], None)
|
54 |
+
prompt = conv.get_prompt()
|
55 |
+
self.ids_input[ins_type] = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
|
56 |
+
|
57 |
+
self.stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
58 |
+
|
59 |
+
def init_image(self, img):
|
60 |
+
if isinstance(img, dict):
|
61 |
+
img = img['image']
|
62 |
+
elif isinstance(img, List):
|
63 |
+
img = cv2.imread(img[0])
|
64 |
+
img = img[:, :, ::-1]
|
65 |
+
h_, w_ = img.shape[:2]
|
66 |
+
if h_ > 512:
|
67 |
+
ratio = 512 / h_
|
68 |
+
new_h, new_w = int(h_ * ratio), int(w_ * ratio)
|
69 |
+
preprocessed_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
70 |
+
else:
|
71 |
+
preprocessed_img = img.copy()
|
72 |
+
|
73 |
+
return (preprocessed_img, preprocessed_img, preprocessed_img)
|
74 |
+
|
75 |
+
def preprocess(self, img):
|
76 |
+
image = self.image_processor.preprocess(img,
|
77 |
+
do_center_crop=False,
|
78 |
+
return_tensors='pt')['pixel_values'][0]
|
79 |
+
|
80 |
+
image = torch.nn.functional.interpolate(image.unsqueeze(0),
|
81 |
+
size=(512, 512),
|
82 |
+
mode='bilinear',
|
83 |
+
align_corners=False).squeeze(0)
|
84 |
+
|
85 |
+
return image
|
86 |
+
|
87 |
+
def seagull_predict(self, img, mask, instruct_type):
|
88 |
+
image = self.preprocess(img)
|
89 |
+
|
90 |
+
mask = np.array(mask, dtype=np.int)
|
91 |
+
ys, xs = np.where(mask > 0)
|
92 |
+
if len(xs) > 0 and len(ys) > 0:
|
93 |
+
# Find the minimal bounding rectangle for the entire mask
|
94 |
+
x_min, x_max = np.min(xs), np.max(xs)
|
95 |
+
y_min, y_max = np.min(ys), np.max(ys)
|
96 |
+
w1 = x_max - x_min
|
97 |
+
h1 = y_max - y_min
|
98 |
+
|
99 |
+
bounding_box = (x_min, y_min, w1, h1)
|
100 |
+
else:
|
101 |
+
bounding_box = None
|
102 |
+
|
103 |
+
mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
|
104 |
+
mask = np.array(mask > 0.1, dtype=np.uint8)
|
105 |
+
masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
|
106 |
+
|
107 |
+
input_ids = self.ids_input[instruct_type.lower()]
|
108 |
+
|
109 |
+
x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
|
110 |
+
cropped_img = img[y1:y1 + h1, x1:x1 + w1]
|
111 |
+
cropped_img = Image.fromarray(cropped_img)
|
112 |
+
cropped_img = self.preprocess(cropped_img)
|
113 |
+
|
114 |
+
with torch.inference_mode():
|
115 |
+
|
116 |
+
self.model.orig_forward = self.model.forward
|
117 |
+
self.model.forward = partial(self.model.orig_forward,
|
118 |
+
img_metas=[None],
|
119 |
+
masks=[masks.half()],
|
120 |
+
cropped_img=cropped_img.unsqueeze(0)
|
121 |
+
)
|
122 |
+
output_ids = self.model.generate(
|
123 |
+
input_ids,
|
124 |
+
images=image.unsqueeze(0).half().to(self.model.device),
|
125 |
+
do_sample=False,
|
126 |
+
temperature=1,
|
127 |
+
max_new_tokens=2048,
|
128 |
+
use_cache=True,
|
129 |
+
num_beams=1,
|
130 |
+
top_k = 0, # 不进行topk
|
131 |
+
top_p = 1, # 累计概率为
|
132 |
+
)
|
133 |
+
|
134 |
+
self.model.forward = self.model.orig_forward
|
135 |
+
|
136 |
+
input_token_len = input_ids.shape[1]
|
137 |
+
n_diff_input_output = (
|
138 |
+
input_ids != output_ids[:, :input_token_len]).sum().item()
|
139 |
+
if n_diff_input_output > 0:
|
140 |
+
print(
|
141 |
+
f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
142 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
|
143 |
+
skip_special_tokens=True)[0]
|
144 |
+
|
145 |
+
outputs = outputs.strip()
|
146 |
+
if outputs.endswith(self.stop_str):
|
147 |
+
outputs = outputs[:-len(self.stop_str)]
|
148 |
+
outputs = outputs.strip()
|
149 |
+
if ':' in outputs:
|
150 |
+
outputs = outputs.split(':')[1]
|
151 |
+
|
152 |
+
outputs_list = outputs.split('.')
|
153 |
+
outputs_list_final = []
|
154 |
+
outputs_str = ''
|
155 |
+
for output in outputs_list:
|
156 |
+
if output not in outputs_list_final:
|
157 |
+
if output=='':
|
158 |
+
continue
|
159 |
+
outputs_list_final.append(output)
|
160 |
+
outputs_str+=output+'.'
|
161 |
+
else:
|
162 |
+
break
|
163 |
+
return outputs_str
|
imgs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
imgs/Examples/1.png
ADDED
imgs/Examples/2.png
ADDED
seagull/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import SeagullLlamaForCausalLM
|
seagull/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (190 Bytes). View file
|
|
seagull/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (450 Bytes). View file
|
|
seagull/__pycache__/conversation.cpython-310.pyc
ADDED
Binary file (10.3 kB). View file
|
|
seagull/__pycache__/mm_utils.cpython-310.pyc
ADDED
Binary file (4.25 kB). View file
|
|
seagull/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.99 kB). View file
|
|
seagull/builder.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from seagull.model import *
|
23 |
+
from seagull.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
|
25 |
+
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
27 |
+
kwargs = {"device_map": device_map, **kwargs}
|
28 |
+
|
29 |
+
if device != "cuda":
|
30 |
+
kwargs['device_map'] = {"": device}
|
31 |
+
|
32 |
+
if load_8bit:
|
33 |
+
kwargs['load_in_8bit'] = True
|
34 |
+
elif load_4bit:
|
35 |
+
kwargs['load_in_4bit'] = True
|
36 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
37 |
+
load_in_4bit=True,
|
38 |
+
bnb_4bit_compute_dtype=torch.float16,
|
39 |
+
bnb_4bit_use_double_quant=True,
|
40 |
+
bnb_4bit_quant_type='nf4'
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
kwargs['torch_dtype'] = torch.float16
|
44 |
+
|
45 |
+
if use_flash_attn:
|
46 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
47 |
+
|
48 |
+
if 'seagull' in model_name.lower() or True:
|
49 |
+
# Load LLaVA model
|
50 |
+
if 'lora' in model_name.lower() and model_base is None:
|
51 |
+
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
52 |
+
if 'lora' in model_name.lower() and model_base is not None or True:
|
53 |
+
from seagull.model.language_model.seagull_llama import SeagullConfig
|
54 |
+
lora_cfg_pretrained = SeagullConfig.from_pretrained(model_path)
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
56 |
+
print('Loading LLaVA from base model...')
|
57 |
+
model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
58 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
59 |
+
if model.lm_head.weight.shape[0] != token_num:
|
60 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
61 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
62 |
+
|
63 |
+
print('Loading additional LLaVA weights...')
|
64 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
65 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
66 |
+
else:
|
67 |
+
# this is probably from HF Hub
|
68 |
+
from huggingface_hub import hf_hub_download
|
69 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
70 |
+
cache_file = hf_hub_download(
|
71 |
+
repo_id=repo_id,
|
72 |
+
filename=filename,
|
73 |
+
subfolder=subfolder)
|
74 |
+
return torch.load(cache_file, map_location='cpu')
|
75 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
76 |
+
|
77 |
+
for k, v in non_lora_trainables.items():
|
78 |
+
print(k)
|
79 |
+
print('print non lora')
|
80 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
81 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
82 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
83 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
84 |
+
|
85 |
+
from peft import PeftModel
|
86 |
+
print('Loading LoRA weights...')
|
87 |
+
model = PeftModel.from_pretrained(model, model_path)
|
88 |
+
print('Merging LoRA weights...')
|
89 |
+
model = model.merge_and_unload()
|
90 |
+
print('Model is loaded...')
|
91 |
+
elif model_base is not None:
|
92 |
+
# this may be mm projector only
|
93 |
+
print('Loading LLaVA from base model...')
|
94 |
+
if 'mpt' in model_name.lower():
|
95 |
+
if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
|
96 |
+
shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
|
97 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
98 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
99 |
+
model = SeagullMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
100 |
+
else:
|
101 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
102 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
103 |
+
model = SeagullLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
104 |
+
|
105 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
106 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
107 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
108 |
+
else:
|
109 |
+
if 'mpt' in model_name.lower():
|
110 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
111 |
+
model = SeagullMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
112 |
+
elif 'mistral' in model_name.lower():
|
113 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
114 |
+
model = SeagullMistralForCausalLM.from_pretrained(
|
115 |
+
model_path,
|
116 |
+
low_cpu_mem_usage=True,
|
117 |
+
**kwargs
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
121 |
+
model = SeagullLlamaForCausalLM.from_pretrained(
|
122 |
+
model_path,
|
123 |
+
low_cpu_mem_usage=True,
|
124 |
+
**kwargs
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
# Load language model
|
128 |
+
if model_base is not None:
|
129 |
+
# PEFT model
|
130 |
+
from peft import PeftModel
|
131 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
132 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
133 |
+
print(f"Loading LoRA weights from {model_path}")
|
134 |
+
model = PeftModel.from_pretrained(model, model_path)
|
135 |
+
print(f"Merging weights")
|
136 |
+
model = model.merge_and_unload()
|
137 |
+
print('Convert to FP16...')
|
138 |
+
model.to(torch.float16)
|
139 |
+
else:
|
140 |
+
use_fast = False
|
141 |
+
if 'mpt' in model_name.lower():
|
142 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
143 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
144 |
+
else:
|
145 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
146 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
147 |
+
|
148 |
+
image_processor = None
|
149 |
+
|
150 |
+
if 'seagull' in model_name.lower() or True:
|
151 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
152 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
153 |
+
if mm_use_im_patch_token:
|
154 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
155 |
+
if mm_use_im_start_end:
|
156 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
157 |
+
model.resize_token_embeddings(len(tokenizer))
|
158 |
+
|
159 |
+
vision_tower = model.get_vision_tower()
|
160 |
+
if not vision_tower.is_loaded:
|
161 |
+
vision_tower.load_model(device_map=device_map)
|
162 |
+
if device_map != 'auto':
|
163 |
+
vision_tower.to(device=device_map, dtype=torch.float16)
|
164 |
+
image_processor = vision_tower.image_processor
|
165 |
+
|
166 |
+
if hasattr(model.config, "max_sequence_length"):
|
167 |
+
context_len = model.config.max_sequence_length
|
168 |
+
else:
|
169 |
+
context_len = 2048
|
170 |
+
|
171 |
+
return tokenizer, model, image_processor, context_len
|
seagull/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
seagull/conversation.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
|
14 |
+
|
15 |
+
@dataclasses.dataclass
|
16 |
+
class Conversation:
|
17 |
+
"""A class that keeps all conversation history."""
|
18 |
+
system: str
|
19 |
+
roles: List[str]
|
20 |
+
messages: List[List[str]]
|
21 |
+
offset: int
|
22 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
23 |
+
sep: str = "###"
|
24 |
+
sep2: str = None
|
25 |
+
version: str = "Unknown"
|
26 |
+
|
27 |
+
skip_next: bool = False
|
28 |
+
|
29 |
+
def get_prompt(self):
|
30 |
+
messages = self.messages
|
31 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
32 |
+
messages = self.messages.copy()
|
33 |
+
init_role, init_msg = messages[0].copy()
|
34 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
35 |
+
if 'mmtag' in self.version:
|
36 |
+
messages[0] = (init_role, init_msg)
|
37 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
38 |
+
messages.insert(1, (self.roles[1], "Received."))
|
39 |
+
else:
|
40 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
41 |
+
|
42 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
43 |
+
ret = self.system + self.sep
|
44 |
+
for role, message in messages:
|
45 |
+
if message:
|
46 |
+
if type(message) is tuple:
|
47 |
+
message, _, _ = message
|
48 |
+
ret += role + ": " + message + self.sep
|
49 |
+
else:
|
50 |
+
ret += role + ":"
|
51 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
52 |
+
seps = [self.sep, self.sep2]
|
53 |
+
ret = self.system + seps[0]
|
54 |
+
for i, (role, message) in enumerate(messages):
|
55 |
+
if message:
|
56 |
+
if type(message) is tuple:
|
57 |
+
message, _, _ = message
|
58 |
+
ret += role + ": " + message + seps[i % 2]
|
59 |
+
else:
|
60 |
+
ret += role + ":"
|
61 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
62 |
+
ret = self.system + self.sep
|
63 |
+
for role, message in messages:
|
64 |
+
if message:
|
65 |
+
if type(message) is tuple:
|
66 |
+
message, _, _ = message
|
67 |
+
ret += role + message + self.sep
|
68 |
+
else:
|
69 |
+
ret += role
|
70 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
71 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
72 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
73 |
+
ret = ""
|
74 |
+
|
75 |
+
for i, (role, message) in enumerate(messages):
|
76 |
+
if i == 0:
|
77 |
+
assert message, "first message should not be none"
|
78 |
+
assert role == self.roles[0], "first message should come from user"
|
79 |
+
if message:
|
80 |
+
if type(message) is tuple:
|
81 |
+
message, _, _ = message
|
82 |
+
if i == 0: message = wrap_sys(self.system) + message
|
83 |
+
if i % 2 == 0:
|
84 |
+
message = wrap_inst(message)
|
85 |
+
ret += self.sep + message
|
86 |
+
else:
|
87 |
+
ret += " " + message + " " + self.sep2
|
88 |
+
else:
|
89 |
+
ret += ""
|
90 |
+
ret = ret.lstrip(self.sep)
|
91 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
92 |
+
seps = [self.sep, self.sep2]
|
93 |
+
ret = self.system
|
94 |
+
for i, (role, message) in enumerate(messages):
|
95 |
+
if message:
|
96 |
+
if type(message) is tuple:
|
97 |
+
message, _, _ = message
|
98 |
+
ret += message + seps[i % 2]
|
99 |
+
else:
|
100 |
+
ret += ""
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
103 |
+
|
104 |
+
return ret
|
105 |
+
|
106 |
+
def append_message(self, role, message):
|
107 |
+
self.messages.append([role, message])
|
108 |
+
|
109 |
+
def get_images(self, return_pil=False):
|
110 |
+
images = []
|
111 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
112 |
+
if i % 2 == 0:
|
113 |
+
if type(msg) is tuple:
|
114 |
+
import base64
|
115 |
+
from io import BytesIO
|
116 |
+
from PIL import Image
|
117 |
+
msg, image, image_process_mode = msg
|
118 |
+
if image_process_mode == "Pad":
|
119 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
120 |
+
width, height = pil_img.size
|
121 |
+
if width == height:
|
122 |
+
return pil_img
|
123 |
+
elif width > height:
|
124 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
125 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
126 |
+
return result
|
127 |
+
else:
|
128 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
129 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
130 |
+
return result
|
131 |
+
image = expand2square(image)
|
132 |
+
elif image_process_mode in ["Default", "Crop"]:
|
133 |
+
pass
|
134 |
+
elif image_process_mode == "Resize":
|
135 |
+
image = image.resize((336, 336))
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
138 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
139 |
+
aspect_ratio = max_hw / min_hw
|
140 |
+
max_len, min_len = 800, 400
|
141 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
142 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
143 |
+
W, H = image.size
|
144 |
+
if longest_edge != max(image.size):
|
145 |
+
if H > W:
|
146 |
+
H, W = longest_edge, shortest_edge
|
147 |
+
else:
|
148 |
+
H, W = shortest_edge, longest_edge
|
149 |
+
image = image.resize((W, H))
|
150 |
+
if return_pil:
|
151 |
+
images.append(image)
|
152 |
+
else:
|
153 |
+
buffered = BytesIO()
|
154 |
+
image.save(buffered, format="PNG")
|
155 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
156 |
+
images.append(img_b64_str)
|
157 |
+
return images
|
158 |
+
|
159 |
+
def to_gradio_chatbot(self):
|
160 |
+
ret = []
|
161 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
162 |
+
if i % 2 == 0:
|
163 |
+
if type(msg) is tuple:
|
164 |
+
import base64
|
165 |
+
from io import BytesIO
|
166 |
+
msg, image, image_process_mode = msg
|
167 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
168 |
+
aspect_ratio = max_hw / min_hw
|
169 |
+
max_len, min_len = 800, 400
|
170 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
171 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
172 |
+
W, H = image.size
|
173 |
+
if H > W:
|
174 |
+
H, W = longest_edge, shortest_edge
|
175 |
+
else:
|
176 |
+
H, W = shortest_edge, longest_edge
|
177 |
+
image = image.resize((W, H))
|
178 |
+
buffered = BytesIO()
|
179 |
+
image.save(buffered, format="JPEG")
|
180 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
181 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
182 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
183 |
+
ret.append([msg, None])
|
184 |
+
else:
|
185 |
+
ret.append([msg, None])
|
186 |
+
else:
|
187 |
+
ret[-1][-1] = msg
|
188 |
+
return ret
|
189 |
+
|
190 |
+
def copy(self):
|
191 |
+
return Conversation(
|
192 |
+
system=self.system,
|
193 |
+
roles=self.roles,
|
194 |
+
messages=[[x, y] for x, y in self.messages],
|
195 |
+
offset=self.offset,
|
196 |
+
sep_style=self.sep_style,
|
197 |
+
sep=self.sep,
|
198 |
+
sep2=self.sep2,
|
199 |
+
version=self.version)
|
200 |
+
|
201 |
+
def dict(self):
|
202 |
+
if len(self.get_images()) > 0:
|
203 |
+
return {
|
204 |
+
"system": self.system,
|
205 |
+
"roles": self.roles,
|
206 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
207 |
+
"offset": self.offset,
|
208 |
+
"sep": self.sep,
|
209 |
+
"sep2": self.sep2,
|
210 |
+
}
|
211 |
+
return {
|
212 |
+
"system": self.system,
|
213 |
+
"roles": self.roles,
|
214 |
+
"messages": self.messages,
|
215 |
+
"offset": self.offset,
|
216 |
+
"sep": self.sep,
|
217 |
+
"sep2": self.sep2,
|
218 |
+
}
|
219 |
+
|
220 |
+
|
221 |
+
conv_vicuna_v0 = Conversation(
|
222 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
223 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
224 |
+
roles=("Human", "Assistant"),
|
225 |
+
messages=(
|
226 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
227 |
+
("Assistant",
|
228 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
229 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
230 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
231 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
232 |
+
"renewable and non-renewable energy sources:\n"
|
233 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
234 |
+
"energy sources are finite and will eventually run out.\n"
|
235 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
236 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
237 |
+
"and other negative effects.\n"
|
238 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
239 |
+
"have lower operational costs than non-renewable sources.\n"
|
240 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
241 |
+
"locations than non-renewable sources.\n"
|
242 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
243 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
244 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
245 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
246 |
+
),
|
247 |
+
offset=2,
|
248 |
+
sep_style=SeparatorStyle.SINGLE,
|
249 |
+
sep="###",
|
250 |
+
)
|
251 |
+
|
252 |
+
conv_vicuna_v1 = Conversation(
|
253 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
254 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
255 |
+
roles=("USER", "ASSISTANT"),
|
256 |
+
version="v1",
|
257 |
+
messages=(),
|
258 |
+
offset=0,
|
259 |
+
sep_style=SeparatorStyle.TWO,
|
260 |
+
sep=" ",
|
261 |
+
sep2="</s>",
|
262 |
+
)
|
263 |
+
|
264 |
+
conv_llama_2 = Conversation(
|
265 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
266 |
+
|
267 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
268 |
+
roles=("USER", "ASSISTANT"),
|
269 |
+
version="llama_v2",
|
270 |
+
messages=(),
|
271 |
+
offset=0,
|
272 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
273 |
+
sep="<s>",
|
274 |
+
sep2="</s>",
|
275 |
+
)
|
276 |
+
|
277 |
+
conv_seagull_llama_2 = Conversation(
|
278 |
+
system="You are a helpful language and vision assistant. "
|
279 |
+
"You are able to understand the visual content that the user provides, "
|
280 |
+
"and assist the user with a variety of tasks using natural language.",
|
281 |
+
roles=("USER", "ASSISTANT"),
|
282 |
+
version="llama_v2",
|
283 |
+
messages=(),
|
284 |
+
offset=0,
|
285 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
286 |
+
sep="<s>",
|
287 |
+
sep2="</s>",
|
288 |
+
)
|
289 |
+
|
290 |
+
conv_mpt = Conversation(
|
291 |
+
system="""<|im_start|>system
|
292 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
293 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
294 |
+
version="mpt",
|
295 |
+
messages=(),
|
296 |
+
offset=0,
|
297 |
+
sep_style=SeparatorStyle.MPT,
|
298 |
+
sep="<|im_end|>",
|
299 |
+
)
|
300 |
+
|
301 |
+
conv_seagull_plain = Conversation(
|
302 |
+
system="",
|
303 |
+
roles=("", ""),
|
304 |
+
messages=(
|
305 |
+
),
|
306 |
+
offset=0,
|
307 |
+
sep_style=SeparatorStyle.PLAIN,
|
308 |
+
sep="\n",
|
309 |
+
)
|
310 |
+
|
311 |
+
conv_seagull_v0 = Conversation(
|
312 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
313 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
314 |
+
roles=("Human", "Assistant"),
|
315 |
+
messages=(
|
316 |
+
),
|
317 |
+
offset=0,
|
318 |
+
sep_style=SeparatorStyle.SINGLE,
|
319 |
+
sep="###",
|
320 |
+
)
|
321 |
+
|
322 |
+
conv_seagull_v0_mmtag = Conversation(
|
323 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
324 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
325 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
326 |
+
roles=("Human", "Assistant"),
|
327 |
+
messages=(
|
328 |
+
),
|
329 |
+
offset=0,
|
330 |
+
sep_style=SeparatorStyle.SINGLE,
|
331 |
+
sep="###",
|
332 |
+
version="v0_mmtag",
|
333 |
+
)
|
334 |
+
|
335 |
+
conv_seagull_v1 = Conversation(
|
336 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
337 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
338 |
+
roles=("USER", "ASSISTANT"),
|
339 |
+
version="v1",
|
340 |
+
messages=(),
|
341 |
+
offset=0,
|
342 |
+
sep_style=SeparatorStyle.TWO,
|
343 |
+
sep=" ",
|
344 |
+
sep2="</s>",
|
345 |
+
)
|
346 |
+
|
347 |
+
conv_seagull_v1_mmtag = Conversation(
|
348 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
349 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
350 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
351 |
+
roles=("USER", "ASSISTANT"),
|
352 |
+
messages=(),
|
353 |
+
offset=0,
|
354 |
+
sep_style=SeparatorStyle.TWO,
|
355 |
+
sep=" ",
|
356 |
+
sep2="</s>",
|
357 |
+
version="v1_mmtag",
|
358 |
+
)
|
359 |
+
|
360 |
+
default_conversation = conv_vicuna_v0
|
361 |
+
conv_templates = {
|
362 |
+
"default": conv_vicuna_v0,
|
363 |
+
"v0": conv_vicuna_v0,
|
364 |
+
"v1": conv_vicuna_v1,
|
365 |
+
"vicuna_v1": conv_vicuna_v1,
|
366 |
+
"llama_2": conv_llama_2,
|
367 |
+
|
368 |
+
"plain": conv_seagull_plain,
|
369 |
+
"v0_plain": conv_seagull_plain,
|
370 |
+
"seagull_v0": conv_seagull_v0,
|
371 |
+
"v0_mmtag": conv_seagull_v0_mmtag,
|
372 |
+
"seagull_v1": conv_seagull_v1,
|
373 |
+
"v1_mmtag": conv_seagull_v1_mmtag,
|
374 |
+
"seagull_llama_2": conv_seagull_llama_2,
|
375 |
+
|
376 |
+
"mpt": conv_mpt,
|
377 |
+
}
|
378 |
+
|
379 |
+
|
380 |
+
if __name__ == "__main__":
|
381 |
+
print(default_conversation.get_prompt())
|
seagull/mm_utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
from seagull.constants import IMAGE_TOKEN_INDEX
|
8 |
+
|
9 |
+
|
10 |
+
def load_image_from_base64(image):
|
11 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
12 |
+
|
13 |
+
def expand2square(pil_img, background_color):
|
14 |
+
width, height = pil_img.size
|
15 |
+
if width == height:
|
16 |
+
return pil_img
|
17 |
+
elif width > height:
|
18 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
19 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
20 |
+
return result
|
21 |
+
else:
|
22 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
23 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
24 |
+
return result
|
25 |
+
|
26 |
+
|
27 |
+
def process_images(images, image_processor, model_cfg):
|
28 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
29 |
+
new_images = []
|
30 |
+
if image_aspect_ratio == 'pad':
|
31 |
+
for image in images:
|
32 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
33 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
34 |
+
new_images.append(image)
|
35 |
+
else:
|
36 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
37 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
38 |
+
new_images = torch.stack(new_images, dim=0)
|
39 |
+
return new_images
|
40 |
+
|
41 |
+
|
42 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
43 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
44 |
+
|
45 |
+
def insert_separator(X, sep):
|
46 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
47 |
+
|
48 |
+
input_ids = []
|
49 |
+
offset = 0
|
50 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
51 |
+
offset = 1
|
52 |
+
input_ids.append(prompt_chunks[0][0])
|
53 |
+
|
54 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
55 |
+
input_ids.extend(x[offset:])
|
56 |
+
|
57 |
+
if return_tensors is not None:
|
58 |
+
if return_tensors == 'pt':
|
59 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
60 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
61 |
+
return input_ids
|
62 |
+
|
63 |
+
|
64 |
+
def get_model_name_from_path(model_path):
|
65 |
+
model_path = model_path.strip("/")
|
66 |
+
model_paths = model_path.split("/")
|
67 |
+
if model_paths[-1].startswith('checkpoint-'):
|
68 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
69 |
+
else:
|
70 |
+
return model_paths[-1]
|
71 |
+
|
72 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
73 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
74 |
+
self.keywords = keywords
|
75 |
+
self.keyword_ids = []
|
76 |
+
for keyword in keywords:
|
77 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
78 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
79 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
80 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
81 |
+
self.tokenizer = tokenizer
|
82 |
+
self.start_len = input_ids.shape[1]
|
83 |
+
|
84 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
85 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
86 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
87 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
88 |
+
for keyword_id in self.keyword_ids:
|
89 |
+
if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
|
90 |
+
return True
|
91 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
92 |
+
for keyword in self.keywords:
|
93 |
+
if keyword in outputs:
|
94 |
+
return True
|
95 |
+
return False
|
seagull/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .language_model.seagull_llama import SeagullLlamaForCausalLM, SeagullConfig
|
seagull/model/__pycache__/Q_A.cpython-310.pyc
ADDED
Binary file (957 Bytes). View file
|
|
seagull/model/__pycache__/Q_A_pretrain.cpython-310.pyc
ADDED
Binary file (2.32 kB). View file
|
|
seagull/model/__pycache__/Q_A_pretrain_level.cpython-310.pyc
ADDED
Binary file (2.62 kB). View file
|
|
seagull/model/__pycache__/Q_A_stage3.cpython-310.pyc
ADDED
Binary file (5.74 kB). View file
|
|
seagull/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (243 Bytes). View file
|
|
seagull/model/__pycache__/layer.cpython-310.pyc
ADDED
Binary file (8.2 kB). View file
|
|
seagull/model/__pycache__/layer_osprey.cpython-310.pyc
ADDED
Binary file (3.72 kB). View file
|
|
seagull/model/__pycache__/osprey_arch.cpython-310.pyc
ADDED
Binary file (9.29 kB). View file
|
|
seagull/model/__pycache__/seagull_arch.cpython-310.pyc
ADDED
Binary file (9.26 kB). View file
|
|
seagull/model/__pycache__/stage2_distrotion_maker.cpython-310.pyc
ADDED
Binary file (3.7 kB). View file
|
|
seagull/model/consolidate.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
6 |
+
from seagull.model import *
|
7 |
+
from seagull.model.utils import auto_upgrade
|
8 |
+
|
9 |
+
|
10 |
+
def consolidate_ckpt(src_path, dst_path):
|
11 |
+
print("Loading model")
|
12 |
+
auto_upgrade(src_path)
|
13 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
14 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
15 |
+
src_model.save_pretrained(dst_path)
|
16 |
+
src_tokenizer.save_pretrained(dst_path)
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == "__main__":
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--src", type=str, required=True)
|
22 |
+
parser.add_argument("--dst", type=str, required=True)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
consolidate_ckpt(args.src, args.dst)
|
seagull/model/language_model/__pycache__/osprey_llama.cpython-310.pyc
ADDED
Binary file (3.87 kB). View file
|
|
seagull/model/language_model/__pycache__/seagull_llama.cpython-310.pyc
ADDED
Binary file (3.82 kB). View file
|
|
seagull/model/language_model/seagull_llama.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import CrossEntropyLoss
|
5 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
6 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
7 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
8 |
+
from ..seagull_arch import SeagullMetaModel, SeagullMetaForCausalLM
|
9 |
+
from ..layer import MaskExtractor
|
10 |
+
|
11 |
+
class SeagullConfig(LlamaConfig):
|
12 |
+
model_type = "seagull"
|
13 |
+
|
14 |
+
class SeagullLlamaModel(SeagullMetaModel, LlamaModel):
|
15 |
+
config_class = SeagullConfig
|
16 |
+
|
17 |
+
def __init__(self, config: LlamaConfig):
|
18 |
+
super(SeagullLlamaModel, self).__init__(config)
|
19 |
+
|
20 |
+
class SeagullLlamaForCausalLM(LlamaForCausalLM, SeagullMetaForCausalLM):
|
21 |
+
config_class = SeagullConfig
|
22 |
+
|
23 |
+
def __init__(self, config):
|
24 |
+
super(LlamaForCausalLM, self).__init__(config)
|
25 |
+
self.model = SeagullLlamaModel(config)
|
26 |
+
|
27 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
28 |
+
self.mask_extractor = MaskExtractor()
|
29 |
+
|
30 |
+
self.post_init()
|
31 |
+
|
32 |
+
def get_model(self):
|
33 |
+
return self.model
|
34 |
+
|
35 |
+
def forward(
|
36 |
+
self,
|
37 |
+
input_ids: torch.LongTensor = None,
|
38 |
+
attention_mask: Optional[torch.Tensor] = None,
|
39 |
+
img_metas = None,
|
40 |
+
masks = None,
|
41 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
42 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
43 |
+
labels: Optional[torch.LongTensor] = None,
|
44 |
+
use_cache: Optional[bool] = None,
|
45 |
+
output_attentions: Optional[bool] = None,
|
46 |
+
output_hidden_states: Optional[bool] = None,
|
47 |
+
images: Optional[torch.FloatTensor] = None,
|
48 |
+
preprocessed_img_dict = None,
|
49 |
+
return_dict: Optional[bool] = None,
|
50 |
+
cropped_img: Optional[torch.FloatTensor] = None,
|
51 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
52 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
53 |
+
output_hidden_states = (
|
54 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
55 |
+
)
|
56 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
57 |
+
|
58 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=preprocessed_img_dict, cropped_img=cropped_img)
|
59 |
+
|
60 |
+
if inputs_embeds is not None:
|
61 |
+
inputs_embeds = inputs_embeds.bfloat16()
|
62 |
+
|
63 |
+
self.model = self.model.bfloat16()
|
64 |
+
|
65 |
+
outputs = self.model(
|
66 |
+
input_ids=input_ids,
|
67 |
+
attention_mask=attention_mask,
|
68 |
+
past_key_values=past_key_values,
|
69 |
+
inputs_embeds=inputs_embeds,
|
70 |
+
use_cache=use_cache,
|
71 |
+
output_attentions=output_attentions,
|
72 |
+
output_hidden_states=output_hidden_states,
|
73 |
+
return_dict=return_dict
|
74 |
+
)
|
75 |
+
|
76 |
+
hidden_states = outputs[0]
|
77 |
+
self.lm_head = self.lm_head.to(hidden_states.dtype)
|
78 |
+
logits = self.lm_head(hidden_states)
|
79 |
+
|
80 |
+
loss = None
|
81 |
+
if labels is not None:
|
82 |
+
# Shift so that tokens < n predict n
|
83 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
84 |
+
shift_labels = labels[..., 1:].contiguous()
|
85 |
+
# Flatten the tokens
|
86 |
+
loss_fct = CrossEntropyLoss()
|
87 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
88 |
+
shift_labels = shift_labels.view(-1)
|
89 |
+
# Enable model/pipeline parallelism
|
90 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
91 |
+
loss = loss_fct(shift_logits, shift_labels)
|
92 |
+
|
93 |
+
if not return_dict:
|
94 |
+
output = (logits,) + outputs[1:]
|
95 |
+
return (loss,) + output if loss is not None else output
|
96 |
+
|
97 |
+
return CausalLMOutputWithPast(
|
98 |
+
loss=loss,
|
99 |
+
logits=logits,
|
100 |
+
past_key_values=outputs.past_key_values,
|
101 |
+
hidden_states=outputs.hidden_states,
|
102 |
+
attentions=outputs.attentions,
|
103 |
+
)
|
104 |
+
|
105 |
+
def prepare_inputs_for_generation(
|
106 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
107 |
+
):
|
108 |
+
if past_key_values:
|
109 |
+
input_ids = input_ids[:, -1:]
|
110 |
+
|
111 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
112 |
+
if inputs_embeds is not None and past_key_values is None:
|
113 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
114 |
+
else:
|
115 |
+
model_inputs = {"input_ids": input_ids}
|
116 |
+
|
117 |
+
model_inputs.update(
|
118 |
+
{
|
119 |
+
"past_key_values": past_key_values,
|
120 |
+
"use_cache": kwargs.get("use_cache"),
|
121 |
+
"attention_mask": attention_mask,
|
122 |
+
"images": kwargs.get("images", None),
|
123 |
+
}
|
124 |
+
)
|
125 |
+
return model_inputs
|
126 |
+
|
127 |
+
AutoConfig.register("seagull", SeagullConfig)
|
128 |
+
AutoModelForCausalLM.register(SeagullConfig, SeagullLlamaForCausalLM)
|
seagull/model/layer.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from typing import Optional, Tuple, Type, Any
|
5 |
+
from torch import Tensor
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
class MLP(nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
|
13 |
+
num_layers: int) -> None:
|
14 |
+
super().__init__()
|
15 |
+
self.num_layers = num_layers
|
16 |
+
h = [hidden_dim] * (num_layers - 1)
|
17 |
+
self.layers = nn.ModuleList(
|
18 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
for i, layer in enumerate(self.layers):
|
22 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
23 |
+
return x
|
24 |
+
|
25 |
+
class MaskExtractor(nn.Module): # Mask-based Feature Extractor
|
26 |
+
def __init__(self, mask_shape=112, embed_dim=1024, out_dim=4096, num_heads=8, mlp_dim=2048, downsample_rate=2, skip_first_layer_pe=False):
|
27 |
+
super(MaskExtractor, self).__init__()
|
28 |
+
self.mask_shape = mask_shape
|
29 |
+
self.mask_pooling = MaskPooling()
|
30 |
+
self.feat_linear = nn.Linear(embed_dim, out_dim)
|
31 |
+
self.cross_feat_linear = nn.Linear(embed_dim, out_dim)
|
32 |
+
self.mask_linear = MLP(mask_shape*mask_shape, embed_dim, out_dim, 3)
|
33 |
+
|
34 |
+
self.feature_name = ['res2', 'res3', 'res4', 'res5']
|
35 |
+
|
36 |
+
self.cross_att_res = CrossAttention(
|
37 |
+
embedding_dim=embed_dim,
|
38 |
+
num_heads=num_heads,
|
39 |
+
mlp_dim=mlp_dim,
|
40 |
+
douwnsample_rate=downsample_rate,
|
41 |
+
skip_first_layer_pe=skip_first_layer_pe
|
42 |
+
)
|
43 |
+
|
44 |
+
self.res2 = nn.Linear(192, 1024)
|
45 |
+
self.res3 = nn.Linear(384, 1024)
|
46 |
+
self.res4 = nn.Linear(768, 1024)
|
47 |
+
self.res5 = nn.Linear(1536, 1024)
|
48 |
+
|
49 |
+
self.g_res2 = nn.Linear(16384, 1024) # h * w
|
50 |
+
self.g_res3 = nn.Linear(4096, 1024)
|
51 |
+
self.g_res4 = nn.Linear(1024, 1024)
|
52 |
+
self.g_res5 = nn.Linear(256, 1024)
|
53 |
+
|
54 |
+
self.final_mlp = nn.Linear(2 * out_dim, out_dim)
|
55 |
+
|
56 |
+
self.global_vit = nn.Sequential(
|
57 |
+
nn.Conv2d(3, 5, 1),
|
58 |
+
nn.GELU(),
|
59 |
+
nn.AvgPool2d(4, 4),
|
60 |
+
|
61 |
+
nn.Conv2d(5, 1, 1),
|
62 |
+
nn.GELU(),
|
63 |
+
nn.AvgPool2d(4, 4),
|
64 |
+
)
|
65 |
+
self.is_first = 0
|
66 |
+
|
67 |
+
self.sa = Attention(32 * 32, num_heads) # self-attention
|
68 |
+
self.mlp = MLP(32 * 32, 512, out_dim, 3)
|
69 |
+
|
70 |
+
def cal_globa_local(self, mask_feat_raw, feat_new, res, g_res, cross_attention):
|
71 |
+
mask_feat_flatten = mask_feat_raw.to(device=res.weight.device, dtype=res.weight.dtype)
|
72 |
+
mask_feat = res(mask_feat_flatten) # (b, q, 1024)
|
73 |
+
|
74 |
+
feat_new = feat_new.to(device=g_res.weight.device, dtype=g_res.weight.dtype)
|
75 |
+
all_feat_new = g_res(feat_new) # (b, c, 1024)
|
76 |
+
global_mask = cross_attention(mask_feat, all_feat_new)
|
77 |
+
return mask_feat, global_mask
|
78 |
+
|
79 |
+
def forward(self, feats, masks, cropped_img):
|
80 |
+
global_features = []
|
81 |
+
local_features = []
|
82 |
+
num_imgs = len(masks)
|
83 |
+
|
84 |
+
for idx in range(num_imgs):
|
85 |
+
mask = masks[idx].unsqueeze(0).float() #(1, q, h, w)
|
86 |
+
cropped_ = cropped_img[idx] # (q, 3, h, w)
|
87 |
+
|
88 |
+
num_feats = len(self.feature_name)
|
89 |
+
mask_feats = mask.new_zeros(num_feats, mask.shape[1], 1024)
|
90 |
+
global_masks = mask.new_zeros(num_feats, mask.shape[1], 1024)
|
91 |
+
|
92 |
+
for i, name in enumerate(self.feature_name):
|
93 |
+
feat = feats[name][idx].unsqueeze(0)
|
94 |
+
feat = feat.to(mask.dtype)
|
95 |
+
|
96 |
+
mask_feat_raw = self.mask_pooling(feat, mask)
|
97 |
+
feat_new = rearrange(feat, 'b c h w -> b c (h w)')
|
98 |
+
|
99 |
+
mask_feat, global_mask = self.cal_globa_local(mask_feat_raw, feat_new, res=getattr(self, name), g_res=getattr(self, 'g_{}'.format(name)), cross_attention=getattr(self,"cross_att_res"))
|
100 |
+
|
101 |
+
mask_feats[i] = mask_feat.squeeze(0) # (q, 1024)
|
102 |
+
global_masks[i] = global_mask.squeeze(0)
|
103 |
+
mask_feats = mask_feats.sum(0) # (1, q, 1024)
|
104 |
+
global_masks = global_masks.sum(0) # (1, q, 1024)
|
105 |
+
global_masks = global_masks.to(device=self.cross_feat_linear.weight.device, dtype=self.cross_feat_linear.weight.dtype)
|
106 |
+
global_masks_linear = self.cross_feat_linear(global_masks)
|
107 |
+
mask_feats = mask_feats.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
108 |
+
mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
|
109 |
+
|
110 |
+
query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
|
111 |
+
global_features.append(query_feat) # global
|
112 |
+
|
113 |
+
cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
114 |
+
global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
|
115 |
+
global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
|
116 |
+
pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
|
117 |
+
|
118 |
+
local_features.append(pos_feat) #(imgs_num, 1, q, 4096) # local
|
119 |
+
|
120 |
+
return global_features, local_features
|
121 |
+
|
122 |
+
class MaskPooling(nn.Module):
|
123 |
+
def __init__(self):
|
124 |
+
super().__init__()
|
125 |
+
|
126 |
+
def forward(self, x, mask):
|
127 |
+
|
128 |
+
if not x.shape[-2:] == mask.shape[-2:]:
|
129 |
+
# reshape mask to x
|
130 |
+
mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
|
131 |
+
|
132 |
+
mask = (mask > 0).to(mask.dtype)
|
133 |
+
denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
|
134 |
+
|
135 |
+
mask_pooled_x = torch.einsum(
|
136 |
+
"bchw,bqhw->bqc",
|
137 |
+
x,
|
138 |
+
mask / denorm,
|
139 |
+
)
|
140 |
+
return mask_pooled_x
|
141 |
+
|
142 |
+
|
143 |
+
class CrossAttention(nn.Module):
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
embedding_dim: int,
|
147 |
+
num_heads: int,
|
148 |
+
mlp_dim: int = 2048,
|
149 |
+
douwnsample_rate: int = 2,
|
150 |
+
activation: Type[nn.Module] = nn.ReLU,
|
151 |
+
skip_first_layer_pe: bool = False
|
152 |
+
) -> None:
|
153 |
+
super().__init__()
|
154 |
+
self.embedding_dim = embedding_dim
|
155 |
+
self.num_heads =num_heads
|
156 |
+
self.self_attn = Attention(embedding_dim, num_heads) # self-attention
|
157 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
158 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
159 |
+
|
160 |
+
# cross-attention
|
161 |
+
self.cross_attn = Attention(embedding_dim, num_heads, downsample_rate=douwnsample_rate)
|
162 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
163 |
+
|
164 |
+
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) # MLP
|
165 |
+
|
166 |
+
def forward(self, queries, keys):
|
167 |
+
attn_out = self.self_attn(queries, queries, queries)
|
168 |
+
queries = queries + attn_out
|
169 |
+
queries = self.norm1(queries)
|
170 |
+
|
171 |
+
attn_out = self.cross_attn(q=queries, k=keys, v=keys)
|
172 |
+
queries = attn_out + queries
|
173 |
+
queries = self.norm2(queries)
|
174 |
+
|
175 |
+
# MLP
|
176 |
+
mlp_out = self.mlp(queries)
|
177 |
+
queries = queries + mlp_out
|
178 |
+
return queries
|
179 |
+
|
180 |
+
class MLPBlock(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
embedding_dim: int,
|
184 |
+
mlp_dim: int,
|
185 |
+
act: Type[nn.Module] = nn.GELU,
|
186 |
+
) -> None:
|
187 |
+
super().__init__()
|
188 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
189 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
190 |
+
self.act = act()
|
191 |
+
|
192 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
193 |
+
return self.lin2(self.act(self.lin1(x)))
|
194 |
+
|
195 |
+
class Attention(nn.Module):
|
196 |
+
"""
|
197 |
+
An attention layer that allows for downscaling the size of the embedding
|
198 |
+
after projection to queries, keys, and values.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
embedding_dim: int,
|
204 |
+
num_heads: int,
|
205 |
+
downsample_rate: int = 1,
|
206 |
+
) -> None:
|
207 |
+
super().__init__()
|
208 |
+
self.embedding_dim = embedding_dim
|
209 |
+
self.internal_dim = embedding_dim // downsample_rate
|
210 |
+
self.num_heads = num_heads
|
211 |
+
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
212 |
+
|
213 |
+
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
214 |
+
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
215 |
+
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
216 |
+
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
217 |
+
|
218 |
+
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
219 |
+
b, n, c = x.shape
|
220 |
+
x = x.reshape(b, n, num_heads, c // num_heads)
|
221 |
+
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
222 |
+
|
223 |
+
def _recombine_heads(self, x: Tensor) -> Tensor:
|
224 |
+
b, n_heads, n_tokens, c_per_head = x.shape
|
225 |
+
x = x.transpose(1, 2)
|
226 |
+
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
227 |
+
|
228 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
229 |
+
# Input projections
|
230 |
+
q = self.q_proj(q)
|
231 |
+
k = self.k_proj(k)
|
232 |
+
v = self.v_proj(v)
|
233 |
+
|
234 |
+
# Separate into heads
|
235 |
+
q = self._separate_heads(q, self.num_heads)
|
236 |
+
k = self._separate_heads(k, self.num_heads)
|
237 |
+
v = self._separate_heads(v, self.num_heads)
|
238 |
+
|
239 |
+
# Attention
|
240 |
+
_, _, _, c_per_head = q.shape
|
241 |
+
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
242 |
+
attn = attn / math.sqrt(c_per_head)
|
243 |
+
attn = torch.softmax(attn, dim=-1)
|
244 |
+
|
245 |
+
# Get output
|
246 |
+
out = attn @ v
|
247 |
+
out = self._recombine_heads(out)
|
248 |
+
out = self.out_proj(out)
|
249 |
+
|
250 |
+
return out
|
seagull/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (392 Bytes). View file
|
|
seagull/model/multimodal_encoder/__pycache__/clip.cpython-310.pyc
ADDED
Binary file (1.76 kB). View file
|
|
seagull/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
ADDED
Binary file (2.21 kB). View file
|
|
seagull/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
|
4 |
+
|
5 |
+
def build_vision_tower(vision_tower_cfg, delay_load=False):
|
6 |
+
|
7 |
+
return CLIPVisionTower(args=vision_tower_cfg)
|
seagull/model/multimodal_encoder/clip.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from open_clip.model import _build_vision_tower
|
6 |
+
|
7 |
+
|
8 |
+
class CLIP(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
model_name = 'convnext_large'
|
12 |
+
|
13 |
+
vision_cfg = {'timm_model_name': model_name, 'timm_model_pretrained': False, 'timm_pool': '', 'timm_proj': 'mlp', 'timm_drop': 0.0, 'timm_drop_path': 0.1, 'image_size': 320}
|
14 |
+
self.visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg, quick_gelu=False)
|
15 |
+
|
16 |
+
self.eval()
|
17 |
+
self.freeze_everything()
|
18 |
+
|
19 |
+
def freeze_everything(self):
|
20 |
+
for param in self.visual.parameters():
|
21 |
+
param.requires_grad = False
|
22 |
+
|
23 |
+
def extract_features(self, x):
|
24 |
+
out = {}
|
25 |
+
x = x.to(self.visual.trunk.stem.state_dict()['1.bias'].dtype)
|
26 |
+
x = self.visual.trunk.stem(x)
|
27 |
+
out['stem'] = x.contiguous()
|
28 |
+
for i in range(4):
|
29 |
+
x = self.visual.trunk.stages[i](x)
|
30 |
+
out[f'res{i+2}'] = x.contiguous()
|
31 |
+
|
32 |
+
x = self.visual.trunk.norm_pre(x)
|
33 |
+
out['clip_vis_dense'] = x.contiguous()
|
34 |
+
return out
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
self.eval()
|
38 |
+
with torch.no_grad():
|
39 |
+
return self.extract_features(x)
|
40 |
+
|
seagull/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPImageProcessor
|
5 |
+
from .clip import CLIP
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, args, img_size=512, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
# test
|
12 |
+
if hasattr(args, 'mm_vision_tower'):
|
13 |
+
self.clip_model = args.mm_vision_tower
|
14 |
+
else: # train
|
15 |
+
self.clip_model = args.vision_tower
|
16 |
+
self.is_loaded = False
|
17 |
+
self.img_size = img_size
|
18 |
+
|
19 |
+
if not delay_load:
|
20 |
+
self.load_model()
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":self.img_size}, resample=3, do_center_crop=True, crop_size={"height": self.img_size, "width": self.img_size},
|
24 |
+
do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
|
25 |
+
image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
|
26 |
+
|
27 |
+
self.vision_tower = CLIP()
|
28 |
+
|
29 |
+
self.vision_tower.load_state_dict(torch.load(self.clip_model),strict=False)
|
30 |
+
|
31 |
+
self.is_loaded = True
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def forward(self, images):
|
35 |
+
if type(images) is list:
|
36 |
+
image_features = []
|
37 |
+
image_features_dict = []
|
38 |
+
for image in images:
|
39 |
+
image_feature_dict = self.vision_tower(image.unsqueeze(0))
|
40 |
+
image_features_dict.append(image_feature_dict)
|
41 |
+
image_feature = image_feature_dict['res4']
|
42 |
+
image_feature = image_feature.reshape(*image_feature.shape[:2],-1).permute(0,2,1)
|
43 |
+
image_features.append(image_feature)
|
44 |
+
else:
|
45 |
+
# print(images.device)
|
46 |
+
# print(self.vision_tower.device)
|
47 |
+
image_features_dict = self.vision_tower(images)
|
48 |
+
image_features = image_features_dict['res4']
|
49 |
+
image_features = image_features.reshape(*image_features.shape[:2],-1).permute(0,2,1)
|
50 |
+
|
51 |
+
return image_features, image_features_dict
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dtype(self):
|
55 |
+
return self.vision_tower.dtype
|
56 |
+
|
57 |
+
@property
|
58 |
+
def device(self):
|
59 |
+
return self.vision_tower.device
|
seagull/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (2.04 kB). View file
|
|
seagull/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
class IdentityMap(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
def forward(self, x, *args, **kwargs):
|
11 |
+
return x
|
12 |
+
|
13 |
+
@property
|
14 |
+
def config(self):
|
15 |
+
return {"mm_projector_type": 'identity'}
|
16 |
+
|
17 |
+
|
18 |
+
class SimpleResBlock(nn.Module):
|
19 |
+
def __init__(self, channels):
|
20 |
+
super().__init__()
|
21 |
+
self.pre_norm = nn.LayerNorm(channels)
|
22 |
+
|
23 |
+
self.proj = nn.Sequential(
|
24 |
+
nn.Linear(channels, channels),
|
25 |
+
nn.GELU(),
|
26 |
+
nn.Linear(channels, channels)
|
27 |
+
)
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.pre_norm(x)
|
30 |
+
return x + self.proj(x)
|
31 |
+
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
34 |
+
mm_hidden_size = getattr(config, 'mm_hidden_size', 768)
|
35 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
36 |
+
|
37 |
+
if projector_type == 'linear':
|
38 |
+
return nn.Linear(mm_hidden_size, config.hidden_size)
|
39 |
+
|
40 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
41 |
+
if mlp_gelu_match:
|
42 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
43 |
+
modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
|
44 |
+
for _ in range(1, mlp_depth):
|
45 |
+
modules.append(nn.GELU())
|
46 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
47 |
+
return nn.Sequential(*modules)
|
48 |
+
|
49 |
+
if projector_type == 'identity':
|
50 |
+
return IdentityMap()
|
51 |
+
|
52 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
seagull/model/seagull_arch.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .multimodal_encoder.builder import build_vision_tower
|
6 |
+
from .multimodal_projector.builder import build_vision_projector
|
7 |
+
|
8 |
+
from seagull.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
|
10 |
+
|
11 |
+
class SeagullMetaModel:
|
12 |
+
|
13 |
+
def __init__(self, config):
|
14 |
+
super(SeagullMetaModel, self).__init__(config)
|
15 |
+
|
16 |
+
if hasattr(config, "mm_vision_tower"):
|
17 |
+
self.vision_tower = build_vision_tower(config, delay_load=False)
|
18 |
+
self.mm_projector = build_vision_projector(config)
|
19 |
+
|
20 |
+
def get_vision_tower(self):
|
21 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
22 |
+
if type(vision_tower) is list:
|
23 |
+
vision_tower = vision_tower[0]
|
24 |
+
return vision_tower
|
25 |
+
|
26 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
27 |
+
|
28 |
+
vision_tower = model_args.vision_tower
|
29 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
30 |
+
|
31 |
+
if not hasattr(self.config, "mm_vision_tower"):
|
32 |
+
self.config.mm_vision_tower = vision_tower
|
33 |
+
|
34 |
+
vision_tower = build_vision_tower(model_args)
|
35 |
+
|
36 |
+
if fsdp is not None and len(fsdp) > 0:
|
37 |
+
self.vision_tower = [self.vision_tower]
|
38 |
+
else:
|
39 |
+
self.vision_tower = vision_tower
|
40 |
+
|
41 |
+
self.config.use_mm_proj = True
|
42 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
43 |
+
|
44 |
+
self.mm_projector = build_vision_projector(self.config)
|
45 |
+
|
46 |
+
if pretrain_mm_mlp_adapter is not None:
|
47 |
+
print("***********load projector_weights********")
|
48 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
49 |
+
def get_w(weights, keyword):
|
50 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
51 |
+
|
52 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
class SeagullMetaForCausalLM(ABC):
|
57 |
+
def __init__(self):
|
58 |
+
super(SeagullMetaForCausalLM, self).__init__()
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def get_model(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
def get_vision_tower(self):
|
65 |
+
return self.get_model().get_vision_tower()
|
66 |
+
|
67 |
+
def encode_images(self, images):
|
68 |
+
image_features, image_features_dict = self.get_model().get_vision_tower()(images)
|
69 |
+
self.get_model().mm_projector.to(device=image_features.device, dtype=image_features.dtype)
|
70 |
+
image_features = self.get_model().mm_projector(image_features)
|
71 |
+
return image_features, image_features_dict
|
72 |
+
|
73 |
+
def prepare_inputs_labels_for_multimodal(
|
74 |
+
self, input_ids, masks, attention_mask, past_key_values, labels, images, preprocessed_img_dict=None, cropped_img=None
|
75 |
+
):
|
76 |
+
vision_tower = self.get_vision_tower()
|
77 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
78 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
79 |
+
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
80 |
+
return input_ids, attention_mask, past_key_values, None, labels
|
81 |
+
|
82 |
+
if preprocessed_img_dict is not None:
|
83 |
+
image_features, image_features_dict = images, preprocessed_img_dict
|
84 |
+
else:
|
85 |
+
if type(images) is list or images.ndim == 5:
|
86 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
87 |
+
image_features, image_features_dict = self.encode_images(concat_images)
|
88 |
+
split_sizes = [image.shape[0] for image in images]
|
89 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
90 |
+
image_features = [x.flatten(0, 1).to(concat_images.device) for x in image_features]
|
91 |
+
else:
|
92 |
+
image_features, image_features_dict = self.encode_images(images)
|
93 |
+
|
94 |
+
|
95 |
+
mask_feats, pos_feats = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
|
96 |
+
|
97 |
+
new_input_embeds = []
|
98 |
+
new_labels = [] if labels is not None else None
|
99 |
+
cur_image_idx = 0
|
100 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
101 |
+
|
102 |
+
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
|
103 |
+
# multimodal LLM, but the current sample is not multimodal
|
104 |
+
# FIXME: this is a hacky fix, for deepspeed zero3 to work
|
105 |
+
half_len = cur_input_ids.shape[0] // 2
|
106 |
+
cur_image_features = image_features[cur_image_idx]
|
107 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
|
108 |
+
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
|
109 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
|
110 |
+
new_input_embeds.append(cur_input_embeds)
|
111 |
+
if labels is not None:
|
112 |
+
new_labels.append(labels[batch_idx])
|
113 |
+
cur_image_idx += 1
|
114 |
+
continue
|
115 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
116 |
+
cur_new_input_embeds = []
|
117 |
+
if labels is not None:
|
118 |
+
cur_labels = labels[batch_idx]
|
119 |
+
cur_new_labels = []
|
120 |
+
assert cur_labels.shape == cur_input_ids.shape
|
121 |
+
while image_token_indices.numel() > 0:
|
122 |
+
cur_image_features = image_features[cur_image_idx]
|
123 |
+
image_token_start = image_token_indices[0]
|
124 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
125 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
|
126 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
|
127 |
+
cur_new_input_embeds.append(cur_image_features)
|
128 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
|
129 |
+
if labels is not None:
|
130 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
131 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
132 |
+
cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
|
133 |
+
cur_labels = cur_labels[image_token_start+2:]
|
134 |
+
else:
|
135 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
|
136 |
+
cur_new_input_embeds.append(cur_image_features)
|
137 |
+
if labels is not None:
|
138 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
139 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
140 |
+
cur_labels = cur_labels[image_token_start+1:]
|
141 |
+
cur_image_idx += 1
|
142 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
143 |
+
cur_input_ids = cur_input_ids[image_token_start+2:]
|
144 |
+
else:
|
145 |
+
cur_input_ids = cur_input_ids[image_token_start+1:]
|
146 |
+
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
|
147 |
+
if cur_input_ids.numel() > 0:
|
148 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
149 |
+
mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
|
150 |
+
|
151 |
+
_l = 0
|
152 |
+
for i, idx in enumerate(mask_idx):
|
153 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
|
154 |
+
## mask
|
155 |
+
cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
|
156 |
+
## pos
|
157 |
+
cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
|
158 |
+
if labels is not None:
|
159 |
+
cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
|
160 |
+
_l = idx[0]+2
|
161 |
+
if _l< len(cur_input_ids):
|
162 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]).detach())
|
163 |
+
|
164 |
+
else:
|
165 |
+
|
166 |
+
mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
|
167 |
+
assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"
|
168 |
+
|
169 |
+
_l = 0
|
170 |
+
for i, idx in enumerate(mask_idx):
|
171 |
+
cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
|
172 |
+
cur_new_input_embeds.append(cur_raw_new_input_embeds)
|
173 |
+
## mask
|
174 |
+
cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
|
175 |
+
## pos
|
176 |
+
cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
|
177 |
+
|
178 |
+
if labels is not None:
|
179 |
+
cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
|
180 |
+
|
181 |
+
_l = idx[0]+2
|
182 |
+
if _l< len(cur_input_ids):
|
183 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]))
|
184 |
+
|
185 |
+
if labels is not None:
|
186 |
+
cur_new_labels.append(cur_labels)
|
187 |
+
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
188 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
189 |
+
|
190 |
+
new_input_embeds.append(cur_new_input_embeds)
|
191 |
+
if labels is not None:
|
192 |
+
cur_new_labels = torch.cat(cur_new_labels, dim=0)
|
193 |
+
new_labels.append(cur_new_labels)
|
194 |
+
|
195 |
+
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
196 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
197 |
+
|
198 |
+
new_input_embeds_align = []
|
199 |
+
for cur_new_embed in new_input_embeds:
|
200 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
201 |
+
new_input_embeds_align.append(cur_new_embed)
|
202 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
203 |
+
|
204 |
+
if labels is not None:
|
205 |
+
new_labels_align = []
|
206 |
+
_new_labels = new_labels
|
207 |
+
for cur_new_label in new_labels:
|
208 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
209 |
+
new_labels_align.append(cur_new_label)
|
210 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
211 |
+
|
212 |
+
if attention_mask is not None:
|
213 |
+
new_attention_mask = []
|
214 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
215 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
216 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
217 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
218 |
+
new_attention_mask.append(cur_new_attention_mask)
|
219 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
220 |
+
assert attention_mask.shape == new_labels.shape
|
221 |
+
else:
|
222 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
223 |
+
if labels is not None:
|
224 |
+
new_labels = torch.stack(new_labels, dim=0)
|
225 |
+
|
226 |
+
if attention_mask is not None:
|
227 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
228 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
229 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
230 |
+
|
231 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
232 |
+
|
233 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
234 |
+
if model_args.mm_use_im_patch_token:
|
235 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
236 |
+
self.resize_token_embeddings(len(tokenizer))
|
237 |
+
|
238 |
+
mask_tokens = ['<global>', '<pos>']
|
239 |
+
num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
|
240 |
+
|
241 |
+
if model_args.mm_use_im_start_end:
|
242 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
243 |
+
self.resize_token_embeddings(len(tokenizer))
|
244 |
+
|
245 |
+
if num_new_tokens > 0:
|
246 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
247 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
248 |
+
|
249 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
250 |
+
dim=0, keepdim=True)
|
251 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
252 |
+
dim=0, keepdim=True)
|
253 |
+
|
254 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
255 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
256 |
+
|
257 |
+
if model_args.tune_mm_mlp_adapter:
|
258 |
+
for p in self.get_input_embeddings().parameters():
|
259 |
+
p.requires_grad = True
|
260 |
+
for p in self.get_output_embeddings().parameters():
|
261 |
+
p.requires_grad = False
|
262 |
+
|
263 |
+
if model_args.pretrain_mm_mlp_adapter:
|
264 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
265 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
266 |
+
assert num_new_tokens == 2
|
267 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
268 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
269 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
270 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
271 |
+
else:
|
272 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
273 |
+
elif model_args.mm_use_im_patch_token:
|
274 |
+
if model_args.tune_mm_mlp_adapter:
|
275 |
+
for p in self.get_input_embeddings().parameters():
|
276 |
+
p.requires_grad = False
|
277 |
+
for p in self.get_output_embeddings().parameters():
|
278 |
+
p.requires_grad = False
|
279 |
+
|
280 |
+
for m in self.modules():
|
281 |
+
m.tokenizer = tokenizer
|
seagull/train/__pycache__/seagull_trainer.cpython-310.pyc
ADDED
Binary file (8.35 kB). View file
|
|