Spaces:
Running
on
Zero
Running
on
Zero
first commit
Browse files- .gitignore +59 -0
- app.py +497 -0
- images/college.png +0 -0
- images/demo.png +0 -0
- images/diseases.png +0 -0
- images/immigrants.png +0 -0
- images/market.png +0 -0
- images/sails.png +0 -0
- pyproject.toml +37 -0
- requirements.txt +23 -0
- scripts/calculate_metric.py +72 -0
- scripts/merge_jsonl_sort.py +26 -0
- scripts/split_jsonl_dataset.py +40 -0
- tinychart/__init__.py +1 -0
- tinychart/arguments.py +77 -0
- tinychart/constants.py +13 -0
- tinychart/conversation.py +491 -0
- tinychart/data/__init__.py +0 -0
- tinychart/data/dataset.py +185 -0
- tinychart/data/preprocess/__init__.py +0 -0
- tinychart/data/preprocess/default.py +104 -0
- tinychart/data/preprocess/phi.py +100 -0
- tinychart/data/preprocess/v1.py +120 -0
- tinychart/data/process.py +83 -0
- tinychart/eval/__init__.py +0 -0
- tinychart/eval/eval_metric.py +159 -0
- tinychart/eval/eval_model.py +139 -0
- tinychart/eval/run_eval.py +72 -0
- tinychart/eval/run_tiny_chart.py +127 -0
- tinychart/mm_utils.py +111 -0
- tinychart/model/__init__.py +1 -0
- tinychart/model/builder.py +127 -0
- tinychart/model/language_model/__init__.py +0 -0
- tinychart/model/language_model/llava_phi.py +164 -0
- tinychart/model/language_model/phi/cache_utils.py +322 -0
- tinychart/model/language_model/phi/configuration_phi.py +186 -0
- tinychart/model/language_model/phi/convert_phi_weights_to_hf.py +175 -0
- tinychart/model/language_model/phi/modeling_attn_mask_utils.py +497 -0
- tinychart/model/language_model/phi/modeling_phi.py +1345 -0
- tinychart/model/language_model/phi/utils.py +1428 -0
- tinychart/model/llava_arch.py +383 -0
- tinychart/model/model_factory.py +64 -0
- tinychart/model/multimodal_encoder/builder.py +7 -0
- tinychart/model/multimodal_encoder/merge.py +239 -0
- tinychart/model/multimodal_encoder/siglip_encoder.py +751 -0
- tinychart/model/multimodal_projector/builder.py +215 -0
- tinychart/utils.py +134 -0
.gitignore
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# These are some examples of commonly ignored file patterns.
|
2 |
+
# You should customize this list as applicable to your project.
|
3 |
+
# Learn more about .gitignore:
|
4 |
+
# https://www.atlassian.com/git/tutorials/saving-changes/gitignore
|
5 |
+
|
6 |
+
# Node artifact files
|
7 |
+
node_modules/
|
8 |
+
dist/
|
9 |
+
|
10 |
+
# Compiled Java class files
|
11 |
+
*.class
|
12 |
+
|
13 |
+
# Compiled Python bytecode
|
14 |
+
*.py[cod]
|
15 |
+
|
16 |
+
# Log files
|
17 |
+
*.log
|
18 |
+
|
19 |
+
# Package files
|
20 |
+
*.jar
|
21 |
+
|
22 |
+
# Maven
|
23 |
+
target/
|
24 |
+
|
25 |
+
# JetBrains IDE
|
26 |
+
.idea/
|
27 |
+
|
28 |
+
# Unit test reports
|
29 |
+
TEST*.xml
|
30 |
+
|
31 |
+
# Generated by MacOS
|
32 |
+
.DS_Store
|
33 |
+
|
34 |
+
Thumbs.db
|
35 |
+
|
36 |
+
# Applications
|
37 |
+
*.app
|
38 |
+
*.exe
|
39 |
+
*.war
|
40 |
+
|
41 |
+
# Large media files
|
42 |
+
*.mp4
|
43 |
+
*.tiff
|
44 |
+
*.avi
|
45 |
+
*.flv
|
46 |
+
*.mov
|
47 |
+
*.wmv
|
48 |
+
|
49 |
+
.ipynb_checkpoints
|
50 |
+
__pycache__
|
51 |
+
*.egg-info
|
52 |
+
.vscode/*
|
53 |
+
.idea/*
|
54 |
+
playground/
|
55 |
+
|
56 |
+
checkpoints
|
57 |
+
.logs
|
58 |
+
core-*
|
59 |
+
*/.nfs*
|
app.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
from threading import Thread
|
7 |
+
import logging
|
8 |
+
import gradio as gr
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from tinychart.model.builder import load_pretrained_model
|
12 |
+
from tinychart.mm_utils import (
|
13 |
+
KeywordsStoppingCriteria,
|
14 |
+
load_image_from_base64,
|
15 |
+
process_images,
|
16 |
+
tokenizer_image_token,
|
17 |
+
get_model_name_from_path,
|
18 |
+
)
|
19 |
+
from PIL import Image
|
20 |
+
from io import BytesIO
|
21 |
+
import base64
|
22 |
+
import torch
|
23 |
+
from transformers import StoppingCriteria
|
24 |
+
from tinychart.constants import (
|
25 |
+
DEFAULT_IM_END_TOKEN,
|
26 |
+
DEFAULT_IM_START_TOKEN,
|
27 |
+
DEFAULT_IMAGE_TOKEN,
|
28 |
+
IMAGE_TOKEN_INDEX,
|
29 |
+
)
|
30 |
+
from tinychart.conversation import SeparatorStyle, conv_templates, default_conversation
|
31 |
+
from tinychart.eval.eval_metric import parse_model_output, evaluate_cmds
|
32 |
+
|
33 |
+
from transformers import TextIteratorStreamer
|
34 |
+
from pathlib import Path
|
35 |
+
|
36 |
+
DEFAULT_MODEL_PATH = "mPLUG/TinyChart-3B-768"
|
37 |
+
DEFAULT_MODEL_NAME = "TinyChart-3B-768"
|
38 |
+
|
39 |
+
|
40 |
+
block_css = """
|
41 |
+
|
42 |
+
#buttons button {
|
43 |
+
min-width: min(120px,100%);
|
44 |
+
}
|
45 |
+
"""
|
46 |
+
title_markdown = """
|
47 |
+
# TinyChart: Efficient Chart Understanding with Visual Token Merging and Program-of-Thoughts Learning
|
48 |
+
🔗 [[Code](https://github.com/X-PLUG/mPLUG-DocOwl/tree/main/TinyChart)] | 📚 [[Paper](https://arxiv.org/abs/2404.16635)]
|
49 |
+
|
50 |
+
**Note:**
|
51 |
+
1. Currently, this demo only supports English chart understanding and may not work well with other languages.
|
52 |
+
2. To use Program-of-Thoughts answer, please append "Answer with detailed steps." to your question.
|
53 |
+
"""
|
54 |
+
tos_markdown = """
|
55 |
+
### Terms of use
|
56 |
+
By using this service, users are required to agree to the following terms:
|
57 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
|
58 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def regenerate(state, image_process_mode):
|
62 |
+
state.messages[-1][-1] = None
|
63 |
+
prev_human_msg = state.messages[-2]
|
64 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
65 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
66 |
+
state.skip_next = False
|
67 |
+
return (state, state.to_gradio_chatbot(), "", None)
|
68 |
+
|
69 |
+
|
70 |
+
def clear_history():
|
71 |
+
state = default_conversation.copy()
|
72 |
+
return (state, state.to_gradio_chatbot(), "", None)
|
73 |
+
|
74 |
+
|
75 |
+
def add_text(state, text, image, image_process_mode):
|
76 |
+
if len(text) <= 0 and image is None:
|
77 |
+
state.skip_next = True
|
78 |
+
return (state, state.to_gradio_chatbot(), "", None)
|
79 |
+
|
80 |
+
text = text[:1536] # Hard cut-off
|
81 |
+
if image is not None:
|
82 |
+
text = text[:1200] # Hard cut-off for images
|
83 |
+
if "<image>" not in text:
|
84 |
+
# text = '<Image><image></Image>' + text
|
85 |
+
# text = text + "\n<image>"
|
86 |
+
text = "<image>\n"+text
|
87 |
+
text = (text, image, image_process_mode)
|
88 |
+
if len(state.get_images(return_pil=True)) > 0:
|
89 |
+
state = default_conversation.copy()
|
90 |
+
state.append_message(state.roles[0], text)
|
91 |
+
state.append_message(state.roles[1], None)
|
92 |
+
state.skip_next = False
|
93 |
+
return (state, state.to_gradio_chatbot(), "", None)
|
94 |
+
|
95 |
+
|
96 |
+
def load_demo():
|
97 |
+
state = default_conversation.copy()
|
98 |
+
return state
|
99 |
+
|
100 |
+
def is_float(value):
|
101 |
+
try:
|
102 |
+
float(value)
|
103 |
+
return True
|
104 |
+
except ValueError:
|
105 |
+
return False
|
106 |
+
|
107 |
+
|
108 |
+
@torch.inference_mode()
|
109 |
+
def get_response(params):
|
110 |
+
prompt = params["prompt"]
|
111 |
+
ori_prompt = prompt
|
112 |
+
images = params.get("images", None)
|
113 |
+
num_image_tokens = 0
|
114 |
+
if images is not None and len(images) > 0:
|
115 |
+
if len(images) > 0:
|
116 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
117 |
+
raise ValueError(
|
118 |
+
"Number of images does not match number of <image> tokens in prompt"
|
119 |
+
)
|
120 |
+
|
121 |
+
images = [load_image_from_base64(image) for image in images]
|
122 |
+
images = process_images(images, image_processor, model.config)
|
123 |
+
|
124 |
+
if type(images) is list:
|
125 |
+
images = [
|
126 |
+
image.to(model.device, dtype=torch.float16) for image in images
|
127 |
+
]
|
128 |
+
else:
|
129 |
+
images = images.to(model.device, dtype=torch.float16)
|
130 |
+
|
131 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
132 |
+
if getattr(model.config, "mm_use_im_start_end", False):
|
133 |
+
replace_token = (
|
134 |
+
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
135 |
+
)
|
136 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
137 |
+
|
138 |
+
if hasattr(model.get_vision_tower().config, "tome_r"):
|
139 |
+
num_image_tokens = (
|
140 |
+
prompt.count(replace_token) * model.get_vision_tower().num_patches - 26 * model.get_vision_tower().config.tome_r
|
141 |
+
)
|
142 |
+
else:
|
143 |
+
num_image_tokens = (
|
144 |
+
prompt.count(replace_token) * model.get_vision_tower().num_patches
|
145 |
+
)
|
146 |
+
else:
|
147 |
+
images = None
|
148 |
+
image_args = {"images": images}
|
149 |
+
else:
|
150 |
+
images = None
|
151 |
+
image_args = {}
|
152 |
+
|
153 |
+
temperature = float(params.get("temperature", 1.0))
|
154 |
+
top_p = float(params.get("top_p", 1.0))
|
155 |
+
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
|
156 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
157 |
+
stop_str = params.get("stop", None)
|
158 |
+
do_sample = True if temperature > 0.001 else False
|
159 |
+
logger.info(prompt)
|
160 |
+
input_ids = (
|
161 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
162 |
+
.unsqueeze(0)
|
163 |
+
.to(model.device)
|
164 |
+
)
|
165 |
+
keywords = [stop_str]
|
166 |
+
|
167 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
168 |
+
streamer = TextIteratorStreamer(
|
169 |
+
tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
170 |
+
)
|
171 |
+
|
172 |
+
max_new_tokens = min(
|
173 |
+
max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
|
174 |
+
)
|
175 |
+
|
176 |
+
if max_new_tokens < 1:
|
177 |
+
yield json.dumps(
|
178 |
+
{
|
179 |
+
"text": ori_prompt
|
180 |
+
+ "Exceeds max token length. Please start a new conversation, thanks.",
|
181 |
+
"error_code": 0,
|
182 |
+
}
|
183 |
+
).encode() + b"\0"
|
184 |
+
return
|
185 |
+
|
186 |
+
# local inference
|
187 |
+
# BUG: If stopping_criteria is set, an error occur:
|
188 |
+
# RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
|
189 |
+
generate_kwargs = dict(
|
190 |
+
inputs=input_ids,
|
191 |
+
do_sample=do_sample,
|
192 |
+
temperature=temperature,
|
193 |
+
top_p=top_p,
|
194 |
+
max_new_tokens=max_new_tokens,
|
195 |
+
streamer=streamer,
|
196 |
+
# stopping_criteria=[stopping_criteria],
|
197 |
+
use_cache=True,
|
198 |
+
**image_args,
|
199 |
+
)
|
200 |
+
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
201 |
+
thread.start()
|
202 |
+
logger.debug(ori_prompt)
|
203 |
+
logger.debug(generate_kwargs)
|
204 |
+
generated_text = ori_prompt
|
205 |
+
for new_text in streamer:
|
206 |
+
generated_text += new_text
|
207 |
+
if generated_text.endswith(stop_str):
|
208 |
+
generated_text = generated_text[: -len(stop_str)]
|
209 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
210 |
+
|
211 |
+
if '<step>' in generated_text and '</step>' in generated_text and '<comment>' in generated_text and '</comment>' in generated_text:
|
212 |
+
program = generated_text
|
213 |
+
program = '<comment>#' + program.split('ASSISTANT: <comment>#')[-1]
|
214 |
+
print(program)
|
215 |
+
try:
|
216 |
+
execuate_result = evaluate_cmds(parse_model_output(program))
|
217 |
+
if is_float(execuate_result):
|
218 |
+
execuate_result = round(float(execuate_result), 4)
|
219 |
+
generated_text += f'\n\nExecute result: {execuate_result}'
|
220 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
221 |
+
except:
|
222 |
+
# execuate_result = 'Failed.'
|
223 |
+
generated_text += f'\n\nIt seems the execution of the above code encounters bugs. I\'m trying to answer this question directly...'
|
224 |
+
ori_generated_text = generated_text + '\nDirect Answer: '
|
225 |
+
|
226 |
+
direct_prompt = ori_prompt.replace(' Answer with detailed steps.', '')
|
227 |
+
direct_input_ids = (
|
228 |
+
tokenizer_image_token(direct_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
229 |
+
.unsqueeze(0)
|
230 |
+
.to(model.device)
|
231 |
+
)
|
232 |
+
|
233 |
+
generate_kwargs = dict(
|
234 |
+
inputs=direct_input_ids,
|
235 |
+
do_sample=do_sample,
|
236 |
+
temperature=temperature,
|
237 |
+
top_p=top_p,
|
238 |
+
max_new_tokens=max_new_tokens,
|
239 |
+
streamer=streamer,
|
240 |
+
use_cache=True,
|
241 |
+
**image_args,
|
242 |
+
)
|
243 |
+
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
244 |
+
thread.start()
|
245 |
+
generated_text = ori_generated_text
|
246 |
+
for new_text in streamer:
|
247 |
+
generated_text += new_text
|
248 |
+
if generated_text.endswith(stop_str):
|
249 |
+
generated_text = generated_text[: -len(stop_str)]
|
250 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
251 |
+
|
252 |
+
|
253 |
+
|
254 |
+
def http_bot(state, temperature, top_p, max_new_tokens):
|
255 |
+
if state.skip_next:
|
256 |
+
# This generate call is skipped due to invalid inputs
|
257 |
+
yield (state, state.to_gradio_chatbot())
|
258 |
+
return
|
259 |
+
|
260 |
+
if len(state.messages) == state.offset + 2:
|
261 |
+
# First round of conversation
|
262 |
+
|
263 |
+
template_name = 'phi'
|
264 |
+
|
265 |
+
new_state = conv_templates[template_name].copy()
|
266 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
267 |
+
new_state.append_message(new_state.roles[1], None)
|
268 |
+
state = new_state
|
269 |
+
|
270 |
+
# Construct prompt
|
271 |
+
prompt = state.get_prompt()
|
272 |
+
|
273 |
+
all_images = state.get_images(return_pil=True)
|
274 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
275 |
+
|
276 |
+
# Make requests
|
277 |
+
# pload = {"model": model_name, "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p),
|
278 |
+
# "max_new_tokens": min(int(max_new_tokens), 1536), "stop": (
|
279 |
+
# state.sep
|
280 |
+
# if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
|
281 |
+
# else state.sep2
|
282 |
+
# ), "images": state.get_images()}
|
283 |
+
|
284 |
+
pload = {
|
285 |
+
"model": model_name,
|
286 |
+
"prompt": prompt,
|
287 |
+
"temperature": float(temperature),
|
288 |
+
"top_p": float(top_p),
|
289 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
290 |
+
"stop": (
|
291 |
+
state.sep
|
292 |
+
if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
|
293 |
+
else state.sep2
|
294 |
+
), "images": state.get_images()}
|
295 |
+
|
296 |
+
state.messages[-1][-1] = "▌"
|
297 |
+
yield (state, state.to_gradio_chatbot())
|
298 |
+
|
299 |
+
# for stream
|
300 |
+
output = get_response(pload)
|
301 |
+
for chunk in output:
|
302 |
+
if chunk:
|
303 |
+
data = json.loads(chunk.decode().replace('\x00',''))
|
304 |
+
|
305 |
+
if data["error_code"] == 0:
|
306 |
+
output = data["text"][len(prompt) :].strip()
|
307 |
+
state.messages[-1][-1] = output + "▌"
|
308 |
+
yield (state, state.to_gradio_chatbot())
|
309 |
+
else:
|
310 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
311 |
+
state.messages[-1][-1] = output
|
312 |
+
yield (state, state.to_gradio_chatbot())
|
313 |
+
return
|
314 |
+
time.sleep(0.03)
|
315 |
+
|
316 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
317 |
+
yield (state, state.to_gradio_chatbot())
|
318 |
+
|
319 |
+
|
320 |
+
def build_demo():
|
321 |
+
textbox = gr.Textbox(
|
322 |
+
show_label=False, placeholder="Enter text and press ENTER", container=False
|
323 |
+
)
|
324 |
+
with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo:
|
325 |
+
state = gr.State()
|
326 |
+
gr.Markdown(title_markdown)
|
327 |
+
|
328 |
+
with gr.Row():
|
329 |
+
with gr.Column(scale=5):
|
330 |
+
with gr.Row(elem_id="Model ID"):
|
331 |
+
gr.Dropdown(
|
332 |
+
choices=[DEFAULT_MODEL_NAME],
|
333 |
+
value=DEFAULT_MODEL_NAME,
|
334 |
+
interactive=True,
|
335 |
+
label="Model ID",
|
336 |
+
container=False,
|
337 |
+
)
|
338 |
+
imagebox = gr.Image(type="pil")
|
339 |
+
image_process_mode = gr.Radio(
|
340 |
+
["Crop", "Resize", "Pad", "Default"],
|
341 |
+
value="Default",
|
342 |
+
label="Preprocess for non-square image",
|
343 |
+
visible=False,
|
344 |
+
)
|
345 |
+
|
346 |
+
# cur_dir = os.path.dirname(os.path.abspath(__file__))
|
347 |
+
cur_dir = Path(__file__).parent
|
348 |
+
gr.Examples(
|
349 |
+
examples=[
|
350 |
+
[
|
351 |
+
f"{cur_dir}/examples/market.png",
|
352 |
+
"What is the highest number of companies in the domestic market? Answer with detailed steps.",
|
353 |
+
],
|
354 |
+
[
|
355 |
+
f"{cur_dir}/examples/college.png",
|
356 |
+
"What is the difference between Asians and Whites degree distribution? Answer with detailed steps."
|
357 |
+
],
|
358 |
+
[
|
359 |
+
f"{cur_dir}/examples/immigrants.png",
|
360 |
+
"How many immigrants are there in 1931?",
|
361 |
+
],
|
362 |
+
[
|
363 |
+
f"{cur_dir}/examples/sails.png",
|
364 |
+
"By how much percentage wholesale is less than retail? Answer with detailed steps."
|
365 |
+
],
|
366 |
+
[
|
367 |
+
f"{cur_dir}/examples/diseases.png",
|
368 |
+
"Is the median value of all the bars greater than 30? Answer with detailed steps.",
|
369 |
+
],
|
370 |
+
[
|
371 |
+
f"{cur_dir}/examples/economy.png",
|
372 |
+
"Which team has higher economy in 28 min?"
|
373 |
+
],
|
374 |
+
[
|
375 |
+
f"{cur_dir}/examples/workers.png",
|
376 |
+
"Generate underlying data table for the chart."
|
377 |
+
],
|
378 |
+
[
|
379 |
+
f"{cur_dir}/examples/sports.png",
|
380 |
+
"Create a brief summarization or extract key insights based on the chart image."
|
381 |
+
],
|
382 |
+
[
|
383 |
+
f"{cur_dir}/examples/albums.png",
|
384 |
+
"Redraw the chart with Python code."
|
385 |
+
]
|
386 |
+
],
|
387 |
+
inputs=[imagebox, textbox],
|
388 |
+
)
|
389 |
+
|
390 |
+
with gr.Accordion("Parameters", open=False) as _:
|
391 |
+
temperature = gr.Slider(
|
392 |
+
minimum=0.0,
|
393 |
+
maximum=1.0,
|
394 |
+
value=0.1,
|
395 |
+
step=0.1,
|
396 |
+
interactive=True,
|
397 |
+
label="Temperature",
|
398 |
+
)
|
399 |
+
top_p = gr.Slider(
|
400 |
+
minimum=0.0,
|
401 |
+
maximum=1.0,
|
402 |
+
value=0.7,
|
403 |
+
step=0.1,
|
404 |
+
interactive=True,
|
405 |
+
label="Top P",
|
406 |
+
)
|
407 |
+
max_output_tokens = gr.Slider(
|
408 |
+
minimum=0,
|
409 |
+
maximum=1024,
|
410 |
+
value=1024,
|
411 |
+
step=64,
|
412 |
+
interactive=True,
|
413 |
+
label="Max output tokens",
|
414 |
+
)
|
415 |
+
|
416 |
+
with gr.Column(scale=8):
|
417 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550)
|
418 |
+
with gr.Row():
|
419 |
+
with gr.Column(scale=8):
|
420 |
+
textbox.render()
|
421 |
+
with gr.Column(scale=1, min_width=50):
|
422 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
423 |
+
with gr.Row(elem_id="buttons") as _:
|
424 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
425 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
|
426 |
+
|
427 |
+
gr.Markdown(tos_markdown)
|
428 |
+
|
429 |
+
regenerate_btn.click(
|
430 |
+
regenerate,
|
431 |
+
[state, image_process_mode],
|
432 |
+
[state, chatbot, textbox, imagebox],
|
433 |
+
queue=False,
|
434 |
+
).then(
|
435 |
+
http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
|
436 |
+
)
|
437 |
+
|
438 |
+
clear_btn.click(
|
439 |
+
clear_history, None, [state, chatbot, textbox, imagebox], queue=False
|
440 |
+
)
|
441 |
+
|
442 |
+
textbox.submit(
|
443 |
+
add_text,
|
444 |
+
[state, textbox, imagebox, image_process_mode],
|
445 |
+
[state, chatbot, textbox, imagebox],
|
446 |
+
queue=False,
|
447 |
+
).then(
|
448 |
+
http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
|
449 |
+
)
|
450 |
+
|
451 |
+
submit_btn.click(
|
452 |
+
add_text,
|
453 |
+
[state, textbox, imagebox, image_process_mode],
|
454 |
+
[state, chatbot, textbox, imagebox],
|
455 |
+
queue=False,
|
456 |
+
).then(
|
457 |
+
http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot]
|
458 |
+
)
|
459 |
+
|
460 |
+
demo.load(load_demo, None, [state], queue=False)
|
461 |
+
return demo
|
462 |
+
|
463 |
+
|
464 |
+
def parse_args():
|
465 |
+
parser = argparse.ArgumentParser()
|
466 |
+
parser.add_argument("--host", type=str, default=None)
|
467 |
+
parser.add_argument("--port", type=int, default=None)
|
468 |
+
parser.add_argument("--share", default=None)
|
469 |
+
parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH)
|
470 |
+
parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME)
|
471 |
+
parser.add_argument("--load-8bit", action="store_true")
|
472 |
+
parser.add_argument("--load-4bit", action="store_true")
|
473 |
+
args = parser.parse_args()
|
474 |
+
return args
|
475 |
+
|
476 |
+
|
477 |
+
if __name__ == "__main__":
|
478 |
+
logging.basicConfig(
|
479 |
+
level=logging.INFO,
|
480 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
481 |
+
)
|
482 |
+
logger = logging.getLogger(__name__)
|
483 |
+
logger.info(gr.__version__)
|
484 |
+
args = parse_args()
|
485 |
+
model_name = args.model_name
|
486 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
487 |
+
model_path=args.model_path,
|
488 |
+
model_base=None,
|
489 |
+
model_name=args.model_name,
|
490 |
+
device="cpu",
|
491 |
+
load_4bit=args.load_4bit,
|
492 |
+
load_8bit=args.load_8bit
|
493 |
+
)
|
494 |
+
|
495 |
+
demo = build_demo()
|
496 |
+
demo.queue()
|
497 |
+
demo.launch(server_name=args.host, server_port=args.port, share=args.share)
|
images/college.png
ADDED
images/demo.png
ADDED
images/diseases.png
ADDED
images/immigrants.png
ADDED
images/market.png
ADDED
images/sails.png
ADDED
pyproject.toml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "tinyllava"
|
7 |
+
version = "1.0.0"
|
8 |
+
description = "A Framework of Small-scale Large Multimodal Models."
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.9"
|
11 |
+
classifiers = [
|
12 |
+
"Programming Language :: Python :: 3",
|
13 |
+
"License :: OSI Approved :: Apache Software License",
|
14 |
+
]
|
15 |
+
dependencies = [
|
16 |
+
"torch==2.0.1", "torchvision==0.15.2", "tiktoken",
|
17 |
+
"transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
|
18 |
+
"accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
|
19 |
+
"pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
|
20 |
+
"gradio==3.35.2", "gradio_client==0.2.9",
|
21 |
+
"requests", "httpx==0.24.0", "uvicorn", "fastapi",
|
22 |
+
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
|
23 |
+
]
|
24 |
+
|
25 |
+
[project.optional-dependencies]
|
26 |
+
train = ["deepspeed==0.9.5", "ninja", "wandb"]
|
27 |
+
|
28 |
+
[project.urls]
|
29 |
+
"Homepage" = "https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/TinyChart"
|
30 |
+
"Bug Tracker" = "https://github.com/X-PLUG/mPLUG-DocOwl/issues"
|
31 |
+
|
32 |
+
[tool.setuptools.packages.find]
|
33 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
34 |
+
|
35 |
+
[tool.wheel]
|
36 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
37 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision==0.15.2
|
3 |
+
tiktoken==0.5.2
|
4 |
+
transformers==4.37.2
|
5 |
+
tokenizers==0.15.1
|
6 |
+
sentencepiece==0.1.99
|
7 |
+
shortuuid==1.0.11
|
8 |
+
accelerate==0.21.0
|
9 |
+
peft==0.4.0
|
10 |
+
bitsandbytes==0.41.0
|
11 |
+
pydantic<2,>=1
|
12 |
+
markdown2[all]
|
13 |
+
numpy
|
14 |
+
scikit-learn==1.2.2
|
15 |
+
gradio==3.35.2
|
16 |
+
gradio_client==0.2.9
|
17 |
+
requests
|
18 |
+
httpx==0.24.0
|
19 |
+
uvicorn
|
20 |
+
fastapi
|
21 |
+
einops==0.6.1
|
22 |
+
einops-exts==0.0.4
|
23 |
+
timm==0.6.13
|
scripts/calculate_metric.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import pandas as pd
|
5 |
+
from collections import defaultdict
|
6 |
+
from tinychart.eval.eval_chartqa_metric import chartqa_evaluator, chartqapot_evaluator
|
7 |
+
from tinychart.eval.eval_chartqa_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator
|
8 |
+
|
9 |
+
def read_jsonl(jsonl_path):
|
10 |
+
with open(jsonl_path, 'r') as f:
|
11 |
+
data = [json.loads(line) for line in f]
|
12 |
+
return data
|
13 |
+
|
14 |
+
def write_jsonl(data, jsonl_path):
|
15 |
+
with open(jsonl_path, 'w', encoding='utf-8') as f:
|
16 |
+
for item in data:
|
17 |
+
f.write(json.dumps(item) + '\n')
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('--input', default='./output/')
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
result_files = os.listdir(args.input)
|
26 |
+
result_files = [f for f in result_files if f.endswith('.jsonl')]
|
27 |
+
result_files.sort()
|
28 |
+
direct_result, cot_result = None, None
|
29 |
+
|
30 |
+
dataset2metric = defaultdict(float)
|
31 |
+
for result_file in result_files:
|
32 |
+
# print(result_file)
|
33 |
+
dataset_name = '.'.join(result_file.split('.')[:-1])
|
34 |
+
file = os.path.join(args.input, result_file)
|
35 |
+
result_data = read_jsonl(file)
|
36 |
+
if 'chartqa-' in dataset_name:
|
37 |
+
direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
|
38 |
+
write_jsonl(direct_result, file)
|
39 |
+
dataset2metric[dataset_name] = round(direct_acc * 100, 2)
|
40 |
+
print(f'Direct Accuracy: {direct_acc}')
|
41 |
+
elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
|
42 |
+
pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
|
43 |
+
write_jsonl(pot_result, file)
|
44 |
+
dataset2metric[dataset_name] = round(pot_acc * 100, 2)
|
45 |
+
print(f'PoT Accuracy: {pot_acc}')
|
46 |
+
print(f'PoT Error Rate: {error_rate}')
|
47 |
+
|
48 |
+
if direct_result is not None and pot_result is not None:
|
49 |
+
print("Calculate merging direct and pot results with simple divider")
|
50 |
+
oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
|
51 |
+
dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
|
52 |
+
print(f'Oracle Merged Accuracy: {oracle_acc}')
|
53 |
+
write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
|
54 |
+
rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
|
55 |
+
dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
|
56 |
+
print(f'Rule Merged Accuracy: {rule_acc}')
|
57 |
+
write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))
|
58 |
+
|
59 |
+
# save metrics into tsv with key as the first row
|
60 |
+
df = pd.DataFrame(dataset2metric, index=[0])
|
61 |
+
# if there is a metrics.tsv exists, add one in the name to avoid overwrite
|
62 |
+
tsv_name = os.path.join(args.input, 'metrics.tsv')
|
63 |
+
if os.path.exists(tsv_name):
|
64 |
+
# avoid overwrite. if there is metrics.1.tsv, name it metrics.2.tsv...
|
65 |
+
i = 1
|
66 |
+
tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
|
67 |
+
while os.path.exists(tsv_name):
|
68 |
+
i += 1
|
69 |
+
tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
|
70 |
+
df.to_csv(tsv_name, sep='\t', index=False)
|
71 |
+
print(f'Metrics saved at: {tsv_name}')
|
72 |
+
print(df)
|
scripts/merge_jsonl_sort.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def read_jsonl(jsonl_path):
|
6 |
+
with open(jsonl_path, 'r') as f:
|
7 |
+
data = [json.loads(line) for line in f]
|
8 |
+
return data
|
9 |
+
|
10 |
+
def write_jsonl(data, jsonl_path):
|
11 |
+
with open(jsonl_path, 'w', encoding='utf-8') as f:
|
12 |
+
for item in data:
|
13 |
+
f.write(json.dumps(item) + '\n')
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--input', default='temp/')
|
18 |
+
parser.add_argument('--output', default='chartqa_val.json')
|
19 |
+
|
20 |
+
args = parser.parse_args()
|
21 |
+
files = os.listdir(args.input)
|
22 |
+
files.sort()
|
23 |
+
data = []
|
24 |
+
for file in files:
|
25 |
+
data.extend(read_jsonl(os.path.join(args.input, file)))
|
26 |
+
write_jsonl(data, args.output)
|
scripts/split_jsonl_dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
def read_jsonl(jsonl_path):
|
7 |
+
with open(jsonl_path, 'r') as f:
|
8 |
+
data = [json.loads(line) for line in f]
|
9 |
+
return data
|
10 |
+
|
11 |
+
def write_jsonl(data, jsonl_path):
|
12 |
+
with open(jsonl_path, 'w', encoding='utf-8') as f:
|
13 |
+
for item in data:
|
14 |
+
f.write(json.dumps(item) + '\n')
|
15 |
+
|
16 |
+
if __name__ == '__main__':
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument('--input', default='all.json')
|
19 |
+
parser.add_argument('--output', default='./output/')
|
20 |
+
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
all_data = read_jsonl(args.input)
|
24 |
+
|
25 |
+
dataset2jsonl = defaultdict(list)
|
26 |
+
|
27 |
+
for item in all_data:
|
28 |
+
int_id = item['id'].split('_')[-1]
|
29 |
+
dataset_name_split = '_'.join(item['id'].split('_')[:-1])
|
30 |
+
|
31 |
+
if '-two_col-' in dataset_name_split:
|
32 |
+
dataset_name_split = dataset_name_split.replace('-two_col-', '-')
|
33 |
+
if '-multi_col-' in dataset_name_split:
|
34 |
+
dataset_name_split = dataset_name_split.replace('-multi_col-', '-')
|
35 |
+
|
36 |
+
dataset2jsonl[dataset_name_split].append(item)
|
37 |
+
|
38 |
+
for dataset_name_split, data in dataset2jsonl.items():
|
39 |
+
data.sort(key=lambda x: int(x['id'].split('_')[-1]))
|
40 |
+
write_jsonl(data, os.path.join(args.output, f'{dataset_name_split}.jsonl'))
|
tinychart/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from tinychart.model import *
|
tinychart/arguments.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Dict, Optional, Sequence, List
|
3 |
+
|
4 |
+
import transformers
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class ModelArguments:
|
8 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
9 |
+
version: Optional[str] = field(default="v0")
|
10 |
+
freeze_backbone: bool = field(default=False)
|
11 |
+
tune_mm_mlp_adapter: bool = field(default=False)
|
12 |
+
vision_tower: Optional[str] = field(default=None)
|
13 |
+
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
|
14 |
+
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
|
15 |
+
mm_projector_type: Optional[str] = field(default='linear')
|
16 |
+
mm_use_im_start_end: bool = field(default=False)
|
17 |
+
mm_use_im_patch_token: bool = field(default=True)
|
18 |
+
mm_patch_merge_type: Optional[str] = field(default='flat')
|
19 |
+
mm_vision_select_feature: Optional[str] = field(default="patch")
|
20 |
+
resampler_hidden_size: Optional[int] = field(default=768)
|
21 |
+
num_queries: Optional[int] = field(default=128)
|
22 |
+
num_resampler_layers: Optional[int] = field(default=3)
|
23 |
+
tune_vision_tower: bool = field(default=False)
|
24 |
+
tune_entire_model: bool = field(default=False)
|
25 |
+
tune_vit_from_layer: Optional[int] = field(default=100)
|
26 |
+
tune_embed_tokens: Optional[int] = field(default=False)
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class DataArguments:
|
31 |
+
data_path: str = field(default=None,
|
32 |
+
metadata={"help": "Path to the training data."})
|
33 |
+
eval_data_path: str = field(default=None,
|
34 |
+
metadata={"help": "Path to the evaluation data."})
|
35 |
+
lazy_preprocess: bool = False
|
36 |
+
is_multimodal: bool = False
|
37 |
+
image_folder: Optional[str] = field(default=None)
|
38 |
+
image_aspect_ratio: str = 'square'
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class TrainingArguments(transformers.TrainingArguments):
|
43 |
+
cache_dir: Optional[str] = field(default=None)
|
44 |
+
optim: str = field(default="adamw_torch")
|
45 |
+
remove_unused_columns: bool = field(default=False)
|
46 |
+
freeze_mm_mlp_adapter: bool = field(default=False)
|
47 |
+
mpt_attn_impl: Optional[str] = field(default="triton")
|
48 |
+
model_max_length: int = field(
|
49 |
+
default=512,
|
50 |
+
metadata={
|
51 |
+
"help":
|
52 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
53 |
+
},
|
54 |
+
)
|
55 |
+
double_quant: bool = field(
|
56 |
+
default=True,
|
57 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
58 |
+
)
|
59 |
+
quant_type: str = field(
|
60 |
+
default="nf4",
|
61 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
62 |
+
)
|
63 |
+
bits: int = field(
|
64 |
+
default=16,
|
65 |
+
metadata={"help": "How many bits to use."}
|
66 |
+
)
|
67 |
+
lora_enable: bool = False
|
68 |
+
lora_r: int = 64
|
69 |
+
lora_alpha: int = 16
|
70 |
+
lora_dropout: float = 0.05
|
71 |
+
lora_weight_path: str = ""
|
72 |
+
lora_bias: str = "none"
|
73 |
+
mm_projector_lr: Optional[float] = None
|
74 |
+
group_by_modality_length: bool = field(default=False)
|
75 |
+
vision_tower_lr: Optional[float] = None
|
76 |
+
tune_vit_posemb_only: bool = field(default=False)
|
77 |
+
tune_vit_only: bool = field(default=False)
|
tinychart/constants.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
tinychart/conversation.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
PLAIN = auto()
|
12 |
+
LLAMA_2 = auto()
|
13 |
+
TINY_LLAMA = auto()
|
14 |
+
QWEN_2 = auto()
|
15 |
+
|
16 |
+
|
17 |
+
@dataclasses.dataclass
|
18 |
+
class Conversation:
|
19 |
+
"""A class that keeps all conversation history."""
|
20 |
+
system: str
|
21 |
+
roles: List[str]
|
22 |
+
messages: List[List[str]]
|
23 |
+
offset: int
|
24 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
25 |
+
sep: str = "###"
|
26 |
+
sep2: str = None
|
27 |
+
version: str = "Unknown"
|
28 |
+
|
29 |
+
skip_next: bool = False
|
30 |
+
|
31 |
+
def get_prompt(self):
|
32 |
+
messages = self.messages
|
33 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
34 |
+
messages = self.messages.copy()
|
35 |
+
init_role, init_msg = messages[0].copy()
|
36 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
37 |
+
if 'mmtag' in self.version:
|
38 |
+
messages[0] = (init_role, init_msg)
|
39 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
40 |
+
messages.insert(1, (self.roles[1], "Received."))
|
41 |
+
else:
|
42 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
43 |
+
|
44 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
45 |
+
ret = self.system + self.sep
|
46 |
+
for role, message in messages:
|
47 |
+
if message:
|
48 |
+
if type(message) is tuple:
|
49 |
+
message, _, _ = message
|
50 |
+
ret += role + ": " + message + self.sep
|
51 |
+
else:
|
52 |
+
ret += role + ":"
|
53 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
54 |
+
seps = [self.sep, self.sep2]
|
55 |
+
ret = self.system + seps[0]
|
56 |
+
for i, (role, message) in enumerate(messages):
|
57 |
+
if message:
|
58 |
+
if type(message) is tuple:
|
59 |
+
message, _, _ = message
|
60 |
+
ret += role + ": " + message + seps[i % 2]
|
61 |
+
else:
|
62 |
+
ret += role + ":"
|
63 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
64 |
+
ret = self.system + self.sep
|
65 |
+
for role, message in messages:
|
66 |
+
if message:
|
67 |
+
if type(message) is tuple:
|
68 |
+
message, _, _ = message
|
69 |
+
ret += role + message + self.sep
|
70 |
+
else:
|
71 |
+
ret += role
|
72 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
73 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
74 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
75 |
+
ret = ""
|
76 |
+
|
77 |
+
for i, (role, message) in enumerate(messages):
|
78 |
+
if i == 0:
|
79 |
+
assert message, "first message should not be none"
|
80 |
+
assert role == self.roles[0], "first message should come from user"
|
81 |
+
if message:
|
82 |
+
if type(message) is tuple:
|
83 |
+
message, _, _ = message
|
84 |
+
if i == 0: message = wrap_sys(self.system) + message
|
85 |
+
if i % 2 == 0:
|
86 |
+
message = wrap_inst(message)
|
87 |
+
ret += self.sep + message
|
88 |
+
else:
|
89 |
+
ret += " " + message + " " + self.sep2
|
90 |
+
else:
|
91 |
+
ret += ""
|
92 |
+
ret = ret.lstrip(self.sep)
|
93 |
+
elif self.sep_style == SeparatorStyle.TINY_LLAMA:
|
94 |
+
sep = "</s>"
|
95 |
+
wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
|
96 |
+
wrap_user = lambda msg: f"<|user|>\n{msg}\n"
|
97 |
+
wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
|
98 |
+
ret = ""
|
99 |
+
|
100 |
+
for i, (role, message) in enumerate(messages):
|
101 |
+
if i == 0:
|
102 |
+
assert message, "first message should not be none"
|
103 |
+
assert role == self.roles[0], "first message should come from user"
|
104 |
+
if message:
|
105 |
+
if type(message) is tuple:
|
106 |
+
message, _, _ = message
|
107 |
+
if i % 2 == 0:
|
108 |
+
message = wrap_user(message)
|
109 |
+
if i == 0:
|
110 |
+
message = wrap_sys(self.system) + message
|
111 |
+
ret += self.sep + message
|
112 |
+
else:
|
113 |
+
message = wrap_assistant(message) + self.sep2
|
114 |
+
ret += message
|
115 |
+
else:
|
116 |
+
ret += "<|assistant|>\n"
|
117 |
+
ret = ret.lstrip(self.sep)
|
118 |
+
elif self.sep_style == SeparatorStyle.QWEN_2:
|
119 |
+
ret = self.system + self.sep
|
120 |
+
for role, message in messages:
|
121 |
+
if message:
|
122 |
+
if type(message) is tuple:
|
123 |
+
message, _, _ = message
|
124 |
+
ret += role + message + self.sep
|
125 |
+
else:
|
126 |
+
ret += role
|
127 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
128 |
+
seps = [self.sep, self.sep2]
|
129 |
+
ret = self.system
|
130 |
+
for i, (role, message) in enumerate(messages):
|
131 |
+
if message:
|
132 |
+
if type(message) is tuple:
|
133 |
+
message, _, _ = message
|
134 |
+
ret += message + seps[i % 2]
|
135 |
+
else:
|
136 |
+
ret += ""
|
137 |
+
else:
|
138 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
139 |
+
|
140 |
+
return ret
|
141 |
+
|
142 |
+
def append_message(self, role, message):
|
143 |
+
self.messages.append([role, message])
|
144 |
+
|
145 |
+
def get_images(self, return_pil=False):
|
146 |
+
images = []
|
147 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
148 |
+
if i % 2 == 0:
|
149 |
+
if type(msg) is tuple:
|
150 |
+
import base64
|
151 |
+
from io import BytesIO
|
152 |
+
from PIL import Image
|
153 |
+
msg, image, image_process_mode = msg
|
154 |
+
if image_process_mode == "Pad":
|
155 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
156 |
+
width, height = pil_img.size
|
157 |
+
if width == height:
|
158 |
+
return pil_img
|
159 |
+
elif width > height:
|
160 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
161 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
162 |
+
return result
|
163 |
+
else:
|
164 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
165 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
166 |
+
return result
|
167 |
+
image = expand2square(image)
|
168 |
+
elif image_process_mode in ["Default", "Crop"]:
|
169 |
+
pass
|
170 |
+
elif image_process_mode == "Resize":
|
171 |
+
image = image.resize((336, 336))
|
172 |
+
else:
|
173 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
174 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
175 |
+
aspect_ratio = max_hw / min_hw
|
176 |
+
max_len, min_len = 800, 400
|
177 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
178 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
179 |
+
W, H = image.size
|
180 |
+
if longest_edge != max(image.size):
|
181 |
+
if H > W:
|
182 |
+
H, W = longest_edge, shortest_edge
|
183 |
+
else:
|
184 |
+
H, W = shortest_edge, longest_edge
|
185 |
+
image = image.resize((W, H))
|
186 |
+
if return_pil:
|
187 |
+
images.append(image)
|
188 |
+
else:
|
189 |
+
buffered = BytesIO()
|
190 |
+
image.save(buffered, format="PNG")
|
191 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
192 |
+
images.append(img_b64_str)
|
193 |
+
return images
|
194 |
+
|
195 |
+
def to_gradio_chatbot(self):
|
196 |
+
ret = []
|
197 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
198 |
+
if i % 2 == 0:
|
199 |
+
if type(msg) is tuple:
|
200 |
+
import base64
|
201 |
+
from io import BytesIO
|
202 |
+
msg, image, image_process_mode = msg
|
203 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
204 |
+
aspect_ratio = max_hw / min_hw
|
205 |
+
max_len, min_len = 800, 400
|
206 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
207 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
208 |
+
W, H = image.size
|
209 |
+
if H > W:
|
210 |
+
H, W = longest_edge, shortest_edge
|
211 |
+
else:
|
212 |
+
H, W = shortest_edge, longest_edge
|
213 |
+
image = image.resize((W, H))
|
214 |
+
buffered = BytesIO()
|
215 |
+
image.save(buffered, format="JPEG")
|
216 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
217 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
218 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
219 |
+
ret.append([msg, None])
|
220 |
+
else:
|
221 |
+
ret.append([msg, None])
|
222 |
+
else:
|
223 |
+
ret[-1][-1] = msg
|
224 |
+
return ret
|
225 |
+
|
226 |
+
def copy(self):
|
227 |
+
return Conversation(
|
228 |
+
system=self.system,
|
229 |
+
roles=self.roles,
|
230 |
+
messages=[[x, y] for x, y in self.messages],
|
231 |
+
offset=self.offset,
|
232 |
+
sep_style=self.sep_style,
|
233 |
+
sep=self.sep,
|
234 |
+
sep2=self.sep2,
|
235 |
+
version=self.version)
|
236 |
+
|
237 |
+
def dict(self):
|
238 |
+
if len(self.get_images()) > 0:
|
239 |
+
return {
|
240 |
+
"system": self.system,
|
241 |
+
"roles": self.roles,
|
242 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
243 |
+
"offset": self.offset,
|
244 |
+
"sep": self.sep,
|
245 |
+
"sep2": self.sep2,
|
246 |
+
}
|
247 |
+
return {
|
248 |
+
"system": self.system,
|
249 |
+
"roles": self.roles,
|
250 |
+
"messages": self.messages,
|
251 |
+
"offset": self.offset,
|
252 |
+
"sep": self.sep,
|
253 |
+
"sep2": self.sep2,
|
254 |
+
}
|
255 |
+
|
256 |
+
|
257 |
+
conv_vicuna_v0 = Conversation(
|
258 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
259 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
260 |
+
roles=("Human", "Assistant"),
|
261 |
+
messages=(
|
262 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
263 |
+
("Assistant",
|
264 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
265 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
266 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
267 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
268 |
+
"renewable and non-renewable energy sources:\n"
|
269 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
270 |
+
"energy sources are finite and will eventually run out.\n"
|
271 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
272 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
273 |
+
"and other negative effects.\n"
|
274 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
275 |
+
"have lower operational costs than non-renewable sources.\n"
|
276 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
277 |
+
"locations than non-renewable sources.\n"
|
278 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
279 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
280 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
281 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
282 |
+
),
|
283 |
+
offset=2,
|
284 |
+
sep_style=SeparatorStyle.SINGLE,
|
285 |
+
sep="###",
|
286 |
+
)
|
287 |
+
|
288 |
+
conv_vicuna_v1 = Conversation(
|
289 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
290 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
291 |
+
roles=("USER", "ASSISTANT"),
|
292 |
+
version="v1",
|
293 |
+
messages=(),
|
294 |
+
offset=0,
|
295 |
+
sep_style=SeparatorStyle.TWO,
|
296 |
+
sep=" ",
|
297 |
+
sep2="</s>",
|
298 |
+
)
|
299 |
+
|
300 |
+
conv_llama_2 = Conversation(
|
301 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
302 |
+
|
303 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
304 |
+
roles=("USER", "ASSISTANT"),
|
305 |
+
version="llama_v2",
|
306 |
+
messages=(),
|
307 |
+
offset=0,
|
308 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
309 |
+
sep="<s>",
|
310 |
+
sep2="</s>",
|
311 |
+
)
|
312 |
+
|
313 |
+
conv_llava_llama_2 = Conversation(
|
314 |
+
system="You are a helpful language and vision assistant. "
|
315 |
+
"You are able to understand the visual content that the user provides, "
|
316 |
+
"and assist the user with a variety of tasks using natural language.",
|
317 |
+
roles=("USER", "ASSISTANT"),
|
318 |
+
version="llama_v2",
|
319 |
+
messages=(),
|
320 |
+
offset=0,
|
321 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
322 |
+
sep="<s>",
|
323 |
+
sep2="</s>",
|
324 |
+
)
|
325 |
+
|
326 |
+
conv_tiny_llava_tiny_llama = Conversation(
|
327 |
+
system="You are a helpful language and vision assistant. "
|
328 |
+
"You are able to understand the visual content that the user provides, "
|
329 |
+
"and assist the user with a variety of tasks using natural language.",
|
330 |
+
roles=("USER", "ASSISTANT"),
|
331 |
+
version="tiny_llama",
|
332 |
+
messages=(),
|
333 |
+
offset=0,
|
334 |
+
sep_style=SeparatorStyle.TINY_LLAMA,
|
335 |
+
sep="<s>",
|
336 |
+
sep2="</s>"
|
337 |
+
)
|
338 |
+
|
339 |
+
|
340 |
+
conv_mpt = Conversation(
|
341 |
+
system="""<|im_start|>system
|
342 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
343 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
344 |
+
version="mpt",
|
345 |
+
messages=(),
|
346 |
+
offset=0,
|
347 |
+
sep_style=SeparatorStyle.MPT,
|
348 |
+
sep="<|im_end|>",
|
349 |
+
)
|
350 |
+
|
351 |
+
conv_llava_plain = Conversation(
|
352 |
+
system="",
|
353 |
+
roles=("", ""),
|
354 |
+
messages=(
|
355 |
+
),
|
356 |
+
version='plain',
|
357 |
+
offset=0,
|
358 |
+
sep_style=SeparatorStyle.PLAIN,
|
359 |
+
sep="\n",
|
360 |
+
)
|
361 |
+
|
362 |
+
conv_llava_v0 = Conversation(
|
363 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
364 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
365 |
+
roles=("Human", "Assistant"),
|
366 |
+
messages=(
|
367 |
+
),
|
368 |
+
offset=0,
|
369 |
+
sep_style=SeparatorStyle.SINGLE,
|
370 |
+
sep="###",
|
371 |
+
)
|
372 |
+
|
373 |
+
conv_llava_v0_mmtag = Conversation(
|
374 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
375 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
376 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
377 |
+
roles=("Human", "Assistant"),
|
378 |
+
messages=(
|
379 |
+
),
|
380 |
+
offset=0,
|
381 |
+
sep_style=SeparatorStyle.SINGLE,
|
382 |
+
sep="###",
|
383 |
+
version="v0_mmtag",
|
384 |
+
)
|
385 |
+
|
386 |
+
conv_llava_v1 = Conversation(
|
387 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
388 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
389 |
+
roles=("USER", "ASSISTANT"),
|
390 |
+
version="v1",
|
391 |
+
messages=(),
|
392 |
+
offset=0,
|
393 |
+
sep_style=SeparatorStyle.TWO,
|
394 |
+
sep=" ",
|
395 |
+
sep2="</s>",
|
396 |
+
)
|
397 |
+
|
398 |
+
conv_llava_v1_mmtag = Conversation(
|
399 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
400 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
401 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
402 |
+
roles=("USER", "ASSISTANT"),
|
403 |
+
messages=(),
|
404 |
+
offset=0,
|
405 |
+
sep_style=SeparatorStyle.TWO,
|
406 |
+
sep=" ",
|
407 |
+
sep2="</s>",
|
408 |
+
version="v1_mmtag",
|
409 |
+
)
|
410 |
+
|
411 |
+
conv_phi_v0 = Conversation(
|
412 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
413 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
414 |
+
roles=("USER", "ASSISTANT"),
|
415 |
+
version="phi",
|
416 |
+
messages=(),
|
417 |
+
offset=0,
|
418 |
+
sep_style=SeparatorStyle.TWO,
|
419 |
+
sep=" ",
|
420 |
+
sep2="<|endoftext|>",
|
421 |
+
)
|
422 |
+
|
423 |
+
conv_stablelm = Conversation(
|
424 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
425 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
426 |
+
roles=("USER", "ASSISTANT"),
|
427 |
+
version="stablelm",
|
428 |
+
messages=(),
|
429 |
+
offset=0,
|
430 |
+
sep_style=SeparatorStyle.TWO,
|
431 |
+
sep=" ",
|
432 |
+
sep2="<|endoftext|>",
|
433 |
+
)
|
434 |
+
|
435 |
+
conv_mistral_instruct = Conversation(
|
436 |
+
system="",
|
437 |
+
roles=("USER", "ASSISTANT"),
|
438 |
+
version="llama_v2",
|
439 |
+
messages=(),
|
440 |
+
offset=0,
|
441 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
442 |
+
sep="",
|
443 |
+
sep2="</s>",
|
444 |
+
)
|
445 |
+
|
446 |
+
conv_chatml_direct = Conversation(
|
447 |
+
system="""<|im_start|>system
|
448 |
+
Answer the questions.""",
|
449 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
450 |
+
version="mpt",
|
451 |
+
messages=(),
|
452 |
+
offset=0,
|
453 |
+
sep_style=SeparatorStyle.MPT,
|
454 |
+
sep="<|im_end|>",
|
455 |
+
)
|
456 |
+
|
457 |
+
conv_qwen2 = Conversation(
|
458 |
+
system="<|im_start|>system\nYou are a helpful assistant",
|
459 |
+
roles=("<im_start>user\n", "<im_start>assistant\n"),
|
460 |
+
version="mpt",
|
461 |
+
messages=(),
|
462 |
+
offset=0,
|
463 |
+
sep_style=SeparatorStyle.MPT,
|
464 |
+
sep="<im_end>"
|
465 |
+
)
|
466 |
+
|
467 |
+
default_conversation = conv_vicuna_v1
|
468 |
+
conv_templates = {
|
469 |
+
"default": conv_vicuna_v0,
|
470 |
+
"v0": conv_vicuna_v0,
|
471 |
+
"v1": conv_vicuna_v1,
|
472 |
+
"vicuna_v1": conv_vicuna_v1,
|
473 |
+
"llama_2": conv_llama_2,
|
474 |
+
|
475 |
+
"plain": conv_llava_plain,
|
476 |
+
"v0_plain": conv_llava_plain,
|
477 |
+
"llava_v0": conv_llava_v0,
|
478 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
479 |
+
"llava_v1": conv_llava_v1,
|
480 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
481 |
+
"llava_llama_2": conv_llava_llama_2,
|
482 |
+
|
483 |
+
"mpt": conv_mpt,
|
484 |
+
|
485 |
+
"tiny_llama": conv_tiny_llava_tiny_llama,
|
486 |
+
"phi": conv_phi_v0,
|
487 |
+
}
|
488 |
+
|
489 |
+
|
490 |
+
if __name__ == "__main__":
|
491 |
+
print(default_conversation.get_prompt())
|
tinychart/data/__init__.py
ADDED
File without changes
|
tinychart/data/dataset.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import json
|
4 |
+
from typing import Dict, Sequence
|
5 |
+
|
6 |
+
|
7 |
+
import transformers
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
from PIL import Image, ImageFile
|
11 |
+
|
12 |
+
from tinychart.arguments import *
|
13 |
+
from tinychart.utils import *
|
14 |
+
from tinychart.data.process import *
|
15 |
+
from tinychart.constants import *
|
16 |
+
|
17 |
+
|
18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
19 |
+
|
20 |
+
class LazySupervisedDataset(Dataset):
|
21 |
+
"""Dataset for supervised fine-tuning."""
|
22 |
+
|
23 |
+
def __init__(self, data_path: str,
|
24 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
25 |
+
data_args: DataArguments):
|
26 |
+
super(LazySupervisedDataset, self).__init__()
|
27 |
+
list_data_dict = json.load(open(data_path, "r"))
|
28 |
+
|
29 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
30 |
+
self.tokenizer = tokenizer
|
31 |
+
self.list_data_dict = list_data_dict
|
32 |
+
self.data_args = data_args
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return len(self.list_data_dict)
|
36 |
+
|
37 |
+
@property
|
38 |
+
def lengths(self):
|
39 |
+
length_list = []
|
40 |
+
for sample in self.list_data_dict:
|
41 |
+
img_tokens = 128 if 'image' in sample else 0
|
42 |
+
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
|
43 |
+
return length_list
|
44 |
+
|
45 |
+
@property
|
46 |
+
def modality_lengths(self):
|
47 |
+
length_list = []
|
48 |
+
for sample in self.list_data_dict:
|
49 |
+
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
|
50 |
+
cur_len = cur_len if 'image' in sample else -cur_len
|
51 |
+
length_list.append(cur_len)
|
52 |
+
return length_list
|
53 |
+
|
54 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
55 |
+
sources = self.list_data_dict[i]
|
56 |
+
if isinstance(i, int):
|
57 |
+
sources = [sources]
|
58 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
59 |
+
if 'image' in sources[0]:
|
60 |
+
image_file = self.list_data_dict[i]['image']
|
61 |
+
image_folder = self.data_args.image_folder
|
62 |
+
processor = self.data_args.image_processor
|
63 |
+
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
|
64 |
+
if self.data_args.image_aspect_ratio == 'pad':
|
65 |
+
def expand2square(pil_img, background_color):
|
66 |
+
width, height = pil_img.size
|
67 |
+
if width == height:
|
68 |
+
return pil_img
|
69 |
+
elif width > height:
|
70 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
71 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
72 |
+
return result
|
73 |
+
else:
|
74 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
75 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
76 |
+
return result
|
77 |
+
|
78 |
+
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
79 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
80 |
+
else:
|
81 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
82 |
+
sources = preprocess_multimodal(
|
83 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
84 |
+
self.data_args)
|
85 |
+
else:
|
86 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
87 |
+
data_dict = preprocess(
|
88 |
+
sources,
|
89 |
+
self.tokenizer,
|
90 |
+
has_image=('image' in self.list_data_dict[i]))
|
91 |
+
if isinstance(i, int):
|
92 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
93 |
+
labels=data_dict["labels"][0])
|
94 |
+
|
95 |
+
# image exist in the data
|
96 |
+
if 'image' in self.list_data_dict[i]:
|
97 |
+
data_dict['image'] = image
|
98 |
+
elif self.data_args.is_multimodal:
|
99 |
+
# image does not exist in the data, but the model is multimodal
|
100 |
+
crop_size = self.data_args.image_processor.crop_size
|
101 |
+
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
|
102 |
+
return data_dict
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
|
106 |
+
class DataCollatorForSupervisedDataset(object):
|
107 |
+
"""Collate examples for supervised fine-tuning."""
|
108 |
+
|
109 |
+
tokenizer: transformers.PreTrainedTokenizer
|
110 |
+
|
111 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
112 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
113 |
+
for key in ("input_ids", "labels"))
|
114 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
115 |
+
for input_id in input_ids:
|
116 |
+
input_id[input_id == self.tokenizer.eos_token_id] = -300
|
117 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
118 |
+
input_ids,
|
119 |
+
batch_first=True,
|
120 |
+
padding_value=self.tokenizer.pad_token_id)
|
121 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
122 |
+
batch_first=True,
|
123 |
+
padding_value=IGNORE_INDEX)
|
124 |
+
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
125 |
+
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
|
126 |
+
labels = labels[:, :self.tokenizer.model_max_length]
|
127 |
+
# FIXME: This is a hack for handling phi and stablelm, as they have the same eos, pad and unk. We want the model
|
128 |
+
# FIXME: to predict the eos in the input ids, but we also use the id of eos to pad sequence, so we use a temp
|
129 |
+
# FIXME: eos id first, and convert them back.
|
130 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
131 |
+
for input_id in input_ids:
|
132 |
+
input_id[input_id == -300] = self.tokenizer.eos_token_id
|
133 |
+
|
134 |
+
batch = dict(
|
135 |
+
input_ids=input_ids,
|
136 |
+
labels=labels,
|
137 |
+
attention_mask=attention_mask,
|
138 |
+
)
|
139 |
+
|
140 |
+
if 'image' in instances[0]:
|
141 |
+
images = [instance['image'] for instance in instances]
|
142 |
+
if all(x is not None and x.shape == images[0].shape for x in images):
|
143 |
+
batch['images'] = torch.stack(images)
|
144 |
+
else:
|
145 |
+
batch['images'] = images
|
146 |
+
|
147 |
+
if 'question' in instances[0]:
|
148 |
+
questions = [instance['question'] for instance in instances]
|
149 |
+
batch['questions'] = questions
|
150 |
+
|
151 |
+
if 'answer' in instances[0]:
|
152 |
+
answers = [instance['answer'] for instance in instances]
|
153 |
+
batch['answers'] = answers
|
154 |
+
|
155 |
+
return batch
|
156 |
+
|
157 |
+
|
158 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
159 |
+
data_args) -> Dict:
|
160 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
161 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
162 |
+
data_path=data_args.data_path,
|
163 |
+
data_args=data_args)
|
164 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
165 |
+
return dict(train_dataset=train_dataset,
|
166 |
+
eval_dataset=None,
|
167 |
+
data_collator=data_collator)
|
168 |
+
|
169 |
+
def make_supervised_data_module_with_eval(tokenizer: transformers.PreTrainedTokenizer,
|
170 |
+
data_args) -> Dict:
|
171 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
172 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
173 |
+
data_path=data_args.data_path,
|
174 |
+
data_args=data_args)
|
175 |
+
if data_args.eval_data_path is None or data_args.eval_data_path == "":
|
176 |
+
print('Evaluation dataset not specified, skipping...')
|
177 |
+
eval_dataset = None
|
178 |
+
else:
|
179 |
+
eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
180 |
+
data_path=data_args.eval_data_path,
|
181 |
+
data_args=data_args)
|
182 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
183 |
+
return dict(train_dataset=train_dataset,
|
184 |
+
eval_dataset=eval_dataset,
|
185 |
+
data_collator=data_collator)
|
tinychart/data/preprocess/__init__.py
ADDED
File without changes
|
tinychart/data/preprocess/default.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Sequence, List
|
2 |
+
import copy
|
3 |
+
|
4 |
+
import transformers
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from tinychart.data.process import register_preprocess
|
8 |
+
from tinychart.mm_utils import tokenizer_image_token
|
9 |
+
from tinychart import conversation as conversation_lib
|
10 |
+
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
|
11 |
+
DEFAULT_IM_END_TOKEN
|
12 |
+
|
13 |
+
|
14 |
+
@register_preprocess('default')
|
15 |
+
def preprocess_default(
|
16 |
+
sources: Sequence[str],
|
17 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
18 |
+
has_image: bool = False
|
19 |
+
) -> Dict:
|
20 |
+
conversations = []
|
21 |
+
for source in sources:
|
22 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
23 |
+
conversation = _add_speaker_and_signal(header, source)
|
24 |
+
conversations.append(conversation)
|
25 |
+
|
26 |
+
# tokenize conversations
|
27 |
+
def get_tokenize_len(prompts):
|
28 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
29 |
+
|
30 |
+
if has_image:
|
31 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
32 |
+
else:
|
33 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
34 |
+
input_ids = conversations_tokenized["input_ids"]
|
35 |
+
|
36 |
+
targets = copy.deepcopy(input_ids)
|
37 |
+
for target, source in zip(targets, sources):
|
38 |
+
if has_image:
|
39 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
40 |
+
else:
|
41 |
+
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
|
42 |
+
speakers = [sentence["from"] for sentence in source]
|
43 |
+
_mask_targets(target, tokenized_lens, speakers)
|
44 |
+
|
45 |
+
return dict(input_ids=input_ids, labels=targets)
|
46 |
+
|
47 |
+
|
48 |
+
def _tokenize_fn(strings: Sequence[str],
|
49 |
+
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
50 |
+
"""Tokenize a list of strings."""
|
51 |
+
tokenized_list = [
|
52 |
+
tokenizer(
|
53 |
+
text,
|
54 |
+
return_tensors="pt",
|
55 |
+
padding="longest",
|
56 |
+
max_length=tokenizer.model_max_length,
|
57 |
+
truncation=True,
|
58 |
+
) for text in strings
|
59 |
+
]
|
60 |
+
input_ids = labels = [
|
61 |
+
tokenized.input_ids[0] for tokenized in tokenized_list
|
62 |
+
]
|
63 |
+
input_ids_lens = labels_lens = [
|
64 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
65 |
+
for tokenized in tokenized_list
|
66 |
+
]
|
67 |
+
return dict(
|
68 |
+
input_ids=input_ids,
|
69 |
+
labels=labels,
|
70 |
+
input_ids_lens=input_ids_lens,
|
71 |
+
labels_lens=labels_lens,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def _add_speaker_and_signal(header, source, get_conversation=True):
|
76 |
+
"""Add speaker and start/end signal on each round."""
|
77 |
+
BEGIN_SIGNAL = "### "
|
78 |
+
END_SIGNAL = "\n"
|
79 |
+
conversation = header
|
80 |
+
for sentence in source:
|
81 |
+
from_str = sentence["from"]
|
82 |
+
if from_str.lower() == "human":
|
83 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
84 |
+
elif from_str.lower() == "gpt":
|
85 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
86 |
+
else:
|
87 |
+
from_str = 'unknown'
|
88 |
+
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
89 |
+
sentence["value"] + END_SIGNAL)
|
90 |
+
if get_conversation:
|
91 |
+
conversation += sentence["value"]
|
92 |
+
conversation += BEGIN_SIGNAL
|
93 |
+
return conversation
|
94 |
+
|
95 |
+
|
96 |
+
def _mask_targets(target, tokenized_lens, speakers):
|
97 |
+
# cur_idx = 0
|
98 |
+
cur_idx = tokenized_lens[0]
|
99 |
+
tokenized_lens = tokenized_lens[1:]
|
100 |
+
target[:cur_idx] = IGNORE_INDEX
|
101 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
102 |
+
if speaker == "human":
|
103 |
+
target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
|
104 |
+
cur_idx += tokenized_len
|
tinychart/data/preprocess/phi.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Sequence, List
|
2 |
+
import copy
|
3 |
+
|
4 |
+
import transformers
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from tinychart.data.process import register_preprocess
|
8 |
+
from tinychart.mm_utils import tokenizer_image_token
|
9 |
+
from tinychart import conversation as conversation_lib
|
10 |
+
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
|
11 |
+
DEFAULT_IM_END_TOKEN
|
12 |
+
|
13 |
+
|
14 |
+
@register_preprocess('phi')
|
15 |
+
def preprocess_phi(
|
16 |
+
sources,
|
17 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
18 |
+
has_image: bool = False
|
19 |
+
) -> Dict:
|
20 |
+
conv = conversation_lib.default_conversation.copy()
|
21 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
22 |
+
|
23 |
+
# print('00000000000', sources)
|
24 |
+
# Apply prompt templates
|
25 |
+
conversations = []
|
26 |
+
|
27 |
+
for i, source in enumerate(sources):
|
28 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
29 |
+
# Skip the first one if it is not from human
|
30 |
+
source = source[1:]
|
31 |
+
|
32 |
+
conv.messages = []
|
33 |
+
for j, sentence in enumerate(source):
|
34 |
+
role = roles[sentence["from"]]
|
35 |
+
assert role == conv.roles[j % 2], f"{i}"
|
36 |
+
conv.append_message(role, sentence["value"])
|
37 |
+
conversations.append(conv.get_prompt())
|
38 |
+
# Tokenize conversations
|
39 |
+
if has_image:
|
40 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
41 |
+
else:
|
42 |
+
input_ids = tokenizer(
|
43 |
+
conversations,
|
44 |
+
return_tensors="pt",
|
45 |
+
padding="longest",
|
46 |
+
max_length=tokenizer.model_max_length,
|
47 |
+
truncation=True,
|
48 |
+
).input_ids
|
49 |
+
|
50 |
+
targets = input_ids.clone()
|
51 |
+
|
52 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
53 |
+
# print(tokenizer)
|
54 |
+
# Mask targets
|
55 |
+
sep = conv.sep + conv.roles[1] + ": "
|
56 |
+
for conversation, target in zip(conversations, targets):
|
57 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum()) + conversation.count(conv.sep2)
|
58 |
+
|
59 |
+
rounds = conversation.split(conv.sep2)
|
60 |
+
cur_len = 0
|
61 |
+
# target[:cur_len] = IGNORE_INDEX
|
62 |
+
for i, rou in enumerate(rounds):
|
63 |
+
if rou == "":
|
64 |
+
break
|
65 |
+
|
66 |
+
parts = rou.split(sep)
|
67 |
+
if len(parts) != 2:
|
68 |
+
break
|
69 |
+
parts[0] += sep
|
70 |
+
|
71 |
+
if has_image:
|
72 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
|
73 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
74 |
+
else:
|
75 |
+
round_len = len(tokenizer(rou).input_ids) + 1
|
76 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
77 |
+
|
78 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
79 |
+
|
80 |
+
cur_len += round_len
|
81 |
+
target[cur_len:] = IGNORE_INDEX
|
82 |
+
|
83 |
+
if cur_len < tokenizer.model_max_length:
|
84 |
+
if cur_len != total_len:
|
85 |
+
print(
|
86 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
87 |
+
f" (ignored)"
|
88 |
+
)
|
89 |
+
print("number of rounds: ", len(rounds) - 1)
|
90 |
+
print("rounds: ", rounds[:-1])
|
91 |
+
print("conversation: ", conversations)
|
92 |
+
print(target)
|
93 |
+
print(input_ids)
|
94 |
+
time.sleep(5)
|
95 |
+
target[:] = IGNORE_INDEX
|
96 |
+
|
97 |
+
return dict(
|
98 |
+
input_ids=input_ids,
|
99 |
+
labels=targets,
|
100 |
+
)
|
tinychart/data/preprocess/v1.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Dict, Optional, Sequence, List
|
3 |
+
import copy
|
4 |
+
|
5 |
+
import transformers
|
6 |
+
import tokenizers
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from tinychart.data.process import register_preprocess
|
10 |
+
from tinychart.mm_utils import tokenizer_image_token
|
11 |
+
from tinychart import conversation as conversation_lib
|
12 |
+
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
|
13 |
+
DEFAULT_IM_END_TOKEN
|
14 |
+
|
15 |
+
from packaging import version
|
16 |
+
|
17 |
+
# IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
|
18 |
+
|
19 |
+
|
20 |
+
@register_preprocess('v1')
|
21 |
+
def preprocess_v1(
|
22 |
+
sources,
|
23 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
24 |
+
has_image: bool = False
|
25 |
+
) -> Dict:
|
26 |
+
# conv = conversation_lib.default_conversation.copy()
|
27 |
+
conv = conversation_lib.conv_phi_v0.copy()
|
28 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
29 |
+
|
30 |
+
# Apply prompt templates
|
31 |
+
conversations = []
|
32 |
+
for i, source in enumerate(sources):
|
33 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
34 |
+
# Skip the first one if it is not from human
|
35 |
+
source = source[1:]
|
36 |
+
|
37 |
+
conv.messages = []
|
38 |
+
for j, sentence in enumerate(source):
|
39 |
+
role = roles[sentence["from"]]
|
40 |
+
assert role == conv.roles[j % 2], f"{i}"
|
41 |
+
conv.append_message(role, sentence["value"])
|
42 |
+
conversations.append(conv.get_prompt())
|
43 |
+
|
44 |
+
# Tokenize conversations
|
45 |
+
|
46 |
+
if has_image:
|
47 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
48 |
+
else:
|
49 |
+
input_ids = tokenizer(
|
50 |
+
conversations,
|
51 |
+
return_tensors="pt",
|
52 |
+
padding="longest",
|
53 |
+
max_length=tokenizer.model_max_length,
|
54 |
+
truncation=True,
|
55 |
+
).input_ids
|
56 |
+
|
57 |
+
targets = input_ids.clone()
|
58 |
+
|
59 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
60 |
+
|
61 |
+
# Mask targets
|
62 |
+
sep = conv.sep + conv.roles[1] + ": "
|
63 |
+
for conversation, target in zip(conversations, targets):
|
64 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
65 |
+
# total_len = len(target)
|
66 |
+
|
67 |
+
rounds = conversation.split(conv.sep2)
|
68 |
+
cur_len = 0
|
69 |
+
# cur_len = 1
|
70 |
+
# cur_len = 1 + 1
|
71 |
+
target[:cur_len] = IGNORE_INDEX
|
72 |
+
for i, rou in enumerate(rounds):
|
73 |
+
if rou == "":
|
74 |
+
break
|
75 |
+
|
76 |
+
parts = rou.split(sep)
|
77 |
+
if len(parts) != 2:
|
78 |
+
break
|
79 |
+
parts[0] += sep
|
80 |
+
|
81 |
+
if has_image:
|
82 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
83 |
+
# round_len = len(tokenizer_image_token(rou, tokenizer)) - 2 + 1
|
84 |
+
# instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
85 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
86 |
+
else:
|
87 |
+
round_len = len(tokenizer(rou).input_ids)
|
88 |
+
# round_len = len(tokenizer(rou).input_ids) - 2 + 1
|
89 |
+
# instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
90 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
91 |
+
|
92 |
+
# if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
93 |
+
# round_len -= 1
|
94 |
+
# instruction_len -= 1
|
95 |
+
instruction_len -= 1
|
96 |
+
|
97 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
98 |
+
|
99 |
+
cur_len += round_len
|
100 |
+
# target[cur_len:] = IGNORE_INDEX
|
101 |
+
# import pdb;pdb.set_trace()
|
102 |
+
|
103 |
+
if cur_len < tokenizer.model_max_length:
|
104 |
+
if cur_len != total_len:
|
105 |
+
|
106 |
+
print(
|
107 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
108 |
+
f" (ignored)"
|
109 |
+
)
|
110 |
+
print("number of rounds: ", len(rounds) - 1)
|
111 |
+
print("rounds: ", rounds[:-1])
|
112 |
+
print("conversation: ", conversations)
|
113 |
+
print(target)
|
114 |
+
print(input_ids)
|
115 |
+
time.sleep(5)
|
116 |
+
target[:] = IGNORE_INDEX
|
117 |
+
return dict(
|
118 |
+
input_ids=input_ids,
|
119 |
+
labels=targets,
|
120 |
+
)
|
tinychart/data/process.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
from typing import Dict, Optional, Sequence, List
|
4 |
+
|
5 |
+
import transformers
|
6 |
+
|
7 |
+
from tinychart.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
8 |
+
from tinychart import conversation as conversation_lib
|
9 |
+
from tinychart.arguments import *
|
10 |
+
|
11 |
+
PREPROCESS_REGISTRY = {}
|
12 |
+
|
13 |
+
def register_preprocess(name):
|
14 |
+
def register_preprocess_cls(cls):
|
15 |
+
if name in PREPROCESS_REGISTRY:
|
16 |
+
return PREPROCESS_REGISTRY[name]
|
17 |
+
|
18 |
+
PREPROCESS_REGISTRY[name] = cls
|
19 |
+
return cls
|
20 |
+
|
21 |
+
return register_preprocess_cls
|
22 |
+
|
23 |
+
|
24 |
+
def import_modules(modules_dir, namespace):
|
25 |
+
for file in os.listdir(modules_dir):
|
26 |
+
path = os.path.join(modules_dir, file)
|
27 |
+
|
28 |
+
if (
|
29 |
+
not file.startswith("_")
|
30 |
+
and not file.startswith(".")
|
31 |
+
and (file.endswith(".py") or os.path.isdir(path))
|
32 |
+
):
|
33 |
+
module_name = file[: file.find(".py")] if file.endswith(".py") else file
|
34 |
+
importlib.import_module(namespace + "." + module_name)
|
35 |
+
|
36 |
+
models_dir = os.path.join(os.path.dirname(__file__), 'preprocess')
|
37 |
+
import_modules(models_dir, "tinychart.data.preprocess")
|
38 |
+
|
39 |
+
|
40 |
+
def PreprocessSelect(version):
|
41 |
+
result = PREPROCESS_REGISTRY.get(version, None)
|
42 |
+
if result is None:
|
43 |
+
for name in PREPROCESS_REGISTRY.keys():
|
44 |
+
if version in name:
|
45 |
+
result = PREPROCESS_REGISTRY[name]
|
46 |
+
break
|
47 |
+
if result is None:
|
48 |
+
result = PREPROCESS_REGISTRY['default']
|
49 |
+
return result
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
def preprocess_multimodal(
|
54 |
+
sources: Sequence[str],
|
55 |
+
data_args: DataArguments
|
56 |
+
) -> Dict:
|
57 |
+
is_multimodal = data_args.is_multimodal
|
58 |
+
if not is_multimodal:
|
59 |
+
return sources
|
60 |
+
|
61 |
+
for source in sources:
|
62 |
+
for sentence in source:
|
63 |
+
if DEFAULT_IMAGE_TOKEN in sentence['value']:
|
64 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
65 |
+
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
|
66 |
+
sentence['value'] = sentence['value'].strip()
|
67 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
68 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
|
69 |
+
'<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
70 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
71 |
+
if data_args.mm_use_im_start_end:
|
72 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
73 |
+
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
74 |
+
|
75 |
+
return sources
|
76 |
+
|
77 |
+
|
78 |
+
def preprocess(
|
79 |
+
sources: Sequence[str],
|
80 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
81 |
+
has_image: bool = False
|
82 |
+
) -> Dict:
|
83 |
+
return PreprocessSelect(conversation_lib.default_conversation.version)(sources, tokenizer, has_image)
|
tinychart/eval/__init__.py
ADDED
File without changes
|
tinychart/eval/eval_metric.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
import copy
|
6 |
+
import argparse
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
def write_jsonl(data, filename):
|
10 |
+
with open(filename, 'w') as f:
|
11 |
+
for item in data:
|
12 |
+
f.write(json.dumps(item) + '\n')
|
13 |
+
|
14 |
+
def RelaxedAccuracy(pred, gt):
|
15 |
+
try:
|
16 |
+
gt = float(gt)
|
17 |
+
pred = float(pred)
|
18 |
+
if gt == 0.0:
|
19 |
+
if pred == gt:
|
20 |
+
return 1.0
|
21 |
+
else:
|
22 |
+
return 0.0
|
23 |
+
else:
|
24 |
+
if abs(pred-gt) / gt <= 0.05:
|
25 |
+
return 1.0
|
26 |
+
else:
|
27 |
+
return 0.0
|
28 |
+
except:
|
29 |
+
if str(gt) == str(pred):
|
30 |
+
return 1.0
|
31 |
+
else:
|
32 |
+
return 0.0
|
33 |
+
|
34 |
+
def evaluate_cmds(cmds):
|
35 |
+
for cmd in cmds:
|
36 |
+
exec(cmd)
|
37 |
+
answer = eval('Answer')
|
38 |
+
if (isinstance(answer, list) or isinstance(answer, np.ndarray)) and len(answer) == 1:
|
39 |
+
answer = answer[0]
|
40 |
+
if isinstance(answer, list) or isinstance(answer, np.ndarray):
|
41 |
+
new_answer = answer[0]
|
42 |
+
for i in range(1, len(answer)-1):
|
43 |
+
new_answer = new_answer + ', ' + answer[i]
|
44 |
+
new_answer += ' and ' + answer[-1]
|
45 |
+
answer = new_answer
|
46 |
+
if isinstance(answer, bool) or isinstance(answer, np.bool_):
|
47 |
+
if answer == True:
|
48 |
+
answer = 'Yes'
|
49 |
+
elif answer == False:
|
50 |
+
answer = 'No'
|
51 |
+
return answer
|
52 |
+
|
53 |
+
def parse_model_output(cmdstr):
|
54 |
+
lines = cmdstr.split('\n')
|
55 |
+
new_lines = []
|
56 |
+
for line in lines:
|
57 |
+
if '<step>' in line or '</step>' in line:
|
58 |
+
line = line.replace('<step>', '').replace('</step>', '')
|
59 |
+
new_lines.append(line)
|
60 |
+
return new_lines
|
61 |
+
|
62 |
+
def chartqa_evaluator(data, key='final_model_answer'):
|
63 |
+
acc = 0
|
64 |
+
for item in data:
|
65 |
+
item['relaxed_acc'] = RelaxedAccuracy(item[key], item['gt_answer'].split('<pot_note>')[0])
|
66 |
+
if item['relaxed_acc'] == 1.0:
|
67 |
+
acc += 1
|
68 |
+
accuracy = acc/len(data)
|
69 |
+
return data, accuracy
|
70 |
+
|
71 |
+
def chartqapot_evaluator(output_data):
|
72 |
+
correct_items = []
|
73 |
+
wrong_items = []
|
74 |
+
error_items = []
|
75 |
+
output_data = copy.deepcopy(output_data)
|
76 |
+
acc = 0
|
77 |
+
for item in output_data:
|
78 |
+
# cmds = parse_gpt_cmd(gpt_item['eval_cmd'])
|
79 |
+
eval_cmds = parse_model_output(item['model_answer'])
|
80 |
+
try:
|
81 |
+
answer = evaluate_cmds(eval_cmds)
|
82 |
+
item['final_model_answer'] = str(answer)
|
83 |
+
except:
|
84 |
+
error_items.append(item)
|
85 |
+
item['final_model_answer'] = 'Execute <error>'
|
86 |
+
item['relaxed_acc'] = 0.0
|
87 |
+
continue
|
88 |
+
item['gt_answer'] = item['gt_answer'].split('<cot_note>')[0]
|
89 |
+
item['relaxed_acc'] = RelaxedAccuracy(str(answer), item['gt_answer'])
|
90 |
+
|
91 |
+
if item['relaxed_acc'] == 1.0:
|
92 |
+
correct_items.append(item)
|
93 |
+
else:
|
94 |
+
wrong_items.append(item)
|
95 |
+
total = len(output_data)
|
96 |
+
accuracy = len(correct_items)/total
|
97 |
+
error_rate = len(error_items)/total
|
98 |
+
|
99 |
+
return output_data, accuracy, error_rate
|
100 |
+
|
101 |
+
def rule_based_divider(question):
|
102 |
+
calculate_words = [
|
103 |
+
'sum', 'difference', 'times', 'summation', 'exceed',
|
104 |
+
'below', 'addition', 'fewer', 'subtract', ' mode ',
|
105 |
+
'ratio', 'division', 'average', 'mean', 'bigger',
|
106 |
+
'greater', ' less ', 'tallest', 'number', 'divide',
|
107 |
+
' add ', 'absolute', 'dividing', 'differ', ' minus ',
|
108 |
+
'how many colors', 'lowest', 'what is the value', 'higher',
|
109 |
+
'longer', ' biggest ', 'lowest'
|
110 |
+
]
|
111 |
+
|
112 |
+
for w in calculate_words:
|
113 |
+
if w in question.lower():
|
114 |
+
return 'pot'
|
115 |
+
return 'direct'
|
116 |
+
|
117 |
+
def chartqa_rule_merger_evaluator(direct_data, pot_data):
|
118 |
+
direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
|
119 |
+
assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
|
120 |
+
acc_count = 0
|
121 |
+
merged_data = []
|
122 |
+
for datum1, datum2 in zip(direct_data, pot_data):
|
123 |
+
if rule_based_divider(datum1['question']) == 'pot' and '<error>' not in datum2['final_model_answer'] and datum2['final_model_answer'] not in ['inf', '-inf', 'nan', 'np.nan', 'np.inf', '-np.inf']:
|
124 |
+
acc_count += datum2['relaxed_acc']
|
125 |
+
merged_data.append(datum2)
|
126 |
+
else:
|
127 |
+
acc_count += datum1['relaxed_acc']
|
128 |
+
merged_data.append(datum1)
|
129 |
+
accuracy = acc_count/len(direct_data)
|
130 |
+
return merged_data, accuracy
|
131 |
+
|
132 |
+
def chartqa_oracle_merger_evaluator(direct_data, pot_data):
|
133 |
+
direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
|
134 |
+
assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
|
135 |
+
acc_count = 0
|
136 |
+
merged_data = []
|
137 |
+
for datum1, datum2 in zip(direct_data, pot_data):
|
138 |
+
if datum1['relaxed_acc'] != 1.0:
|
139 |
+
acc_count += datum2['relaxed_acc']
|
140 |
+
merged_data.append(datum2)
|
141 |
+
else:
|
142 |
+
acc_count += datum1['relaxed_acc']
|
143 |
+
merged_data.append(datum1)
|
144 |
+
accuracy = acc_count/len(direct_data)
|
145 |
+
return merged_data, accuracy
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == '__main__':
|
149 |
+
parser = argparse.ArgumentParser()
|
150 |
+
parser.add_argument('--direct', default='../eval_iter12000_0226/ChartQA_test_12000_pred.jsonl')
|
151 |
+
parser.add_argument('--pot', default='../eval_iter12000_0226/ChartQA_test_pot_12000_eval.jsonl')
|
152 |
+
parser.add_argument('--output', default='../eval_iter12000_0226/ChartQA_test_pot_12000_merged.jsonl')
|
153 |
+
|
154 |
+
args = parser.parse_args()
|
155 |
+
|
156 |
+
merged = oracle_merger(args.direct, args.pot)
|
157 |
+
merged = rule_based_merger(args.direct, args.pot)
|
158 |
+
|
159 |
+
write_jsonl(merged, args.output)
|
tinychart/eval/eval_model.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
from tqdm import tqdm
|
8 |
+
import shortuuid
|
9 |
+
|
10 |
+
from tinychart.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
11 |
+
from tinychart.conversation import conv_templates, SeparatorStyle
|
12 |
+
from tinychart.model.builder import load_pretrained_model
|
13 |
+
from tinychart.utils import disable_torch_init
|
14 |
+
from tinychart.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria
|
15 |
+
from torch.utils.data import Dataset, DataLoader
|
16 |
+
|
17 |
+
from PIL import Image
|
18 |
+
import math
|
19 |
+
|
20 |
+
|
21 |
+
def split_list(lst, n):
|
22 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
23 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
24 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
25 |
+
|
26 |
+
|
27 |
+
def get_chunk(lst, n, k):
|
28 |
+
chunks = split_list(lst, n)
|
29 |
+
return chunks[k]
|
30 |
+
|
31 |
+
class EvalDataset(Dataset):
|
32 |
+
def __init__(self, data_items, image_folder, tokenizer, image_processor, model_config):
|
33 |
+
self.data_items = data_items
|
34 |
+
self.image_folder = image_folder
|
35 |
+
self.tokenizer = tokenizer
|
36 |
+
self.image_processor = image_processor
|
37 |
+
self.model_config = model_config
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
line = self.data_items[index]
|
41 |
+
image_file = line["image"]
|
42 |
+
qs = line["conversations"][0]["value"]
|
43 |
+
# if self.model_config.mm_use_im_start_end:
|
44 |
+
# qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
45 |
+
# else:
|
46 |
+
# qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
47 |
+
|
48 |
+
conv = conv_templates[args.conv_mode].copy()
|
49 |
+
conv.append_message(conv.roles[0], qs)
|
50 |
+
conv.append_message(conv.roles[1], None)
|
51 |
+
prompt = conv.get_prompt()
|
52 |
+
|
53 |
+
image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
|
54 |
+
image_tensor = process_images([image], self.image_processor, self.model_config)[0]
|
55 |
+
|
56 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
57 |
+
|
58 |
+
return input_ids, image_tensor, image.size
|
59 |
+
|
60 |
+
def __len__(self):
|
61 |
+
return len(self.data_items)
|
62 |
+
|
63 |
+
|
64 |
+
def collate_fn(batch):
|
65 |
+
input_ids, image_tensors, image_sizes = zip(*batch)
|
66 |
+
input_ids = torch.stack(input_ids, dim=0)
|
67 |
+
image_tensors = torch.stack(image_tensors, dim=0)
|
68 |
+
return input_ids, image_tensors, image_sizes
|
69 |
+
|
70 |
+
|
71 |
+
# DataLoader
|
72 |
+
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
|
73 |
+
assert batch_size == 1, "batch_size must be 1"
|
74 |
+
dataset = EvalDataset(questions, image_folder, tokenizer, image_processor, model_config)
|
75 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
|
76 |
+
return data_loader
|
77 |
+
|
78 |
+
|
79 |
+
def eval_model(args):
|
80 |
+
disable_torch_init()
|
81 |
+
model_path = os.path.expanduser(args.model_path)
|
82 |
+
model_name = get_model_name_from_path(model_path)
|
83 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
84 |
+
|
85 |
+
all_data = json.load(open(args.data_path, "r"))
|
86 |
+
all_data = get_chunk(all_data, args.num_chunks, args.chunk_idx)
|
87 |
+
answers_file = os.path.expanduser(args.output_path)
|
88 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
89 |
+
ans_file = open(answers_file, "w")
|
90 |
+
|
91 |
+
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
|
92 |
+
args.conv_mode = args.conv_mode + '_mmtag'
|
93 |
+
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
|
94 |
+
|
95 |
+
data_loader = create_data_loader(all_data, args.image_folder, tokenizer, image_processor, model.config)
|
96 |
+
for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, all_data), total=len(all_data)):
|
97 |
+
idx = line["id"]
|
98 |
+
cur_prompt = line["conversations"][0]["value"]
|
99 |
+
input_ids = input_ids.to(device='cuda', non_blocking=True)
|
100 |
+
with torch.inference_mode():
|
101 |
+
output_ids = model.generate(
|
102 |
+
input_ids,
|
103 |
+
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
|
104 |
+
pad_token_id=tokenizer.pad_token_id,
|
105 |
+
do_sample=True if args.temperature > 0 else False,
|
106 |
+
temperature=args.temperature,
|
107 |
+
top_p=args.top_p,
|
108 |
+
num_beams=args.num_beams,
|
109 |
+
max_new_tokens=args.max_new_tokens,
|
110 |
+
min_new_tokens=args.min_new_tokens,
|
111 |
+
use_cache=True)
|
112 |
+
|
113 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
114 |
+
ans_id = shortuuid.uuid()
|
115 |
+
ans_file.write(json.dumps({"id": idx,
|
116 |
+
"question": cur_prompt,
|
117 |
+
"gt_answer": line["conversations"][1]["value"],
|
118 |
+
"model_answer": outputs}) + "\n")
|
119 |
+
ans_file.flush()
|
120 |
+
ans_file.close()
|
121 |
+
|
122 |
+
if __name__ == "__main__":
|
123 |
+
parser = argparse.ArgumentParser()
|
124 |
+
parser.add_argument("--model_path", type=str, default="facebook/opt-350m")
|
125 |
+
parser.add_argument("--model_base", type=str, default=None)
|
126 |
+
parser.add_argument("--image_folder", type=str, default="")
|
127 |
+
parser.add_argument("--data_path", type=str, default="./data/test_chartqa+cot_shuffle.json")
|
128 |
+
parser.add_argument("--output_path", type=str, default="./output/")
|
129 |
+
parser.add_argument("--conv_mode", type=str, default="phi")
|
130 |
+
parser.add_argument("--num_chunks", type=int, default=1)
|
131 |
+
parser.add_argument("--chunk_idx", type=int, default=0)
|
132 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
133 |
+
parser.add_argument("--top_p", type=float, default=None)
|
134 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
135 |
+
parser.add_argument("--max_new_tokens", type=int, default=1024)
|
136 |
+
parser.add_argument("--min_new_tokens", type=int, default=0)
|
137 |
+
args = parser.parse_args()
|
138 |
+
|
139 |
+
eval_model(args)
|
tinychart/eval/run_eval.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import pandas as pd
|
5 |
+
from collections import defaultdict
|
6 |
+
from tinychart.eval.eval_metric import chartqa_evaluator, chartqapot_evaluator
|
7 |
+
from tinychart.eval.eval_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator
|
8 |
+
|
9 |
+
def read_jsonl(jsonl_path):
|
10 |
+
with open(jsonl_path, 'r') as f:
|
11 |
+
data = [json.loads(line) for line in f]
|
12 |
+
return data
|
13 |
+
|
14 |
+
def write_jsonl(data, jsonl_path):
|
15 |
+
with open(jsonl_path, 'w', encoding='utf-8') as f:
|
16 |
+
for item in data:
|
17 |
+
f.write(json.dumps(item) + '\n')
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('--input', default='./output/')
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
result_files = os.listdir(args.input)
|
26 |
+
result_files = [f for f in result_files if f.endswith('.jsonl')]
|
27 |
+
result_files.sort()
|
28 |
+
direct_result, pot_result = None, None
|
29 |
+
|
30 |
+
dataset2metric = defaultdict(float)
|
31 |
+
for result_file in result_files:
|
32 |
+
# print(result_file)
|
33 |
+
dataset_name = '.'.join(result_file.split('.')[:-1])
|
34 |
+
file = os.path.join(args.input, result_file)
|
35 |
+
result_data = read_jsonl(file)
|
36 |
+
if 'chartqa-' in dataset_name:
|
37 |
+
direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
|
38 |
+
write_jsonl(direct_result, file)
|
39 |
+
dataset2metric[dataset_name] = round(direct_acc * 100, 2)
|
40 |
+
print(f'Direct Accuracy: {direct_acc}')
|
41 |
+
elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
|
42 |
+
pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
|
43 |
+
write_jsonl(pot_result, file)
|
44 |
+
dataset2metric[dataset_name] = round(pot_acc * 100, 2)
|
45 |
+
print(f'PoT Accuracy: {pot_acc}')
|
46 |
+
print(f'PoT Error Rate: {error_rate}')
|
47 |
+
|
48 |
+
if direct_result is not None and pot_result is not None:
|
49 |
+
print("Calculate merging direct and pot results with simple divider")
|
50 |
+
oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
|
51 |
+
dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
|
52 |
+
print(f'Oracle Merged Accuracy: {oracle_acc}')
|
53 |
+
write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
|
54 |
+
rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
|
55 |
+
dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
|
56 |
+
print(f'Rule Merged Accuracy: {rule_acc}')
|
57 |
+
write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))
|
58 |
+
|
59 |
+
# save metrics into tsv with key as the first row
|
60 |
+
df = pd.DataFrame(dataset2metric, index=[0])
|
61 |
+
# if there is a metrics.tsv exists, add one in the name to avoid overwrite
|
62 |
+
tsv_name = os.path.join(args.input, 'metrics.tsv')
|
63 |
+
if os.path.exists(tsv_name):
|
64 |
+
# avoid overwrite. if there is metrics.1.tsv, name it metrics.2.tsv...
|
65 |
+
i = 1
|
66 |
+
tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
|
67 |
+
while os.path.exists(tsv_name):
|
68 |
+
i += 1
|
69 |
+
tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
|
70 |
+
df.to_csv(tsv_name, sep='\t', index=False)
|
71 |
+
print(f'Metrics saved at: {tsv_name}')
|
72 |
+
print(df)
|
tinychart/eval/run_tiny_chart.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from tinychart.constants import (
|
5 |
+
IMAGE_TOKEN_INDEX,
|
6 |
+
DEFAULT_IMAGE_TOKEN,
|
7 |
+
DEFAULT_IM_START_TOKEN,
|
8 |
+
DEFAULT_IM_END_TOKEN,
|
9 |
+
IMAGE_PLACEHOLDER,
|
10 |
+
)
|
11 |
+
from tinychart.conversation import conv_templates, SeparatorStyle
|
12 |
+
from tinychart.model.builder import load_pretrained_model
|
13 |
+
from tinychart.utils import disable_torch_init
|
14 |
+
from tinychart.mm_utils import (
|
15 |
+
process_images,
|
16 |
+
tokenizer_image_token,
|
17 |
+
get_model_name_from_path,
|
18 |
+
KeywordsStoppingCriteria,
|
19 |
+
)
|
20 |
+
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
import requests
|
24 |
+
from PIL import Image
|
25 |
+
from io import BytesIO
|
26 |
+
import re
|
27 |
+
|
28 |
+
|
29 |
+
def image_parser(args):
|
30 |
+
out = args.image_file.split(args.sep)
|
31 |
+
return out
|
32 |
+
|
33 |
+
|
34 |
+
def load_image(image_file):
|
35 |
+
if image_file.startswith("http") or image_file.startswith("https"):
|
36 |
+
response = requests.get(image_file)
|
37 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
38 |
+
else:
|
39 |
+
image = Image.open(image_file).convert("RGB")
|
40 |
+
return image
|
41 |
+
|
42 |
+
|
43 |
+
def load_images(image_files):
|
44 |
+
out = []
|
45 |
+
for image_file in image_files:
|
46 |
+
image = load_image(image_file)
|
47 |
+
out.append(image)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def inference_model(image_files, query, model, tokenizer, image_processor, context_len, conv_mode, temperature=0, max_new_tokens=100):
|
52 |
+
qs = query
|
53 |
+
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
54 |
+
if IMAGE_PLACEHOLDER in qs:
|
55 |
+
if model.config.mm_use_im_start_end:
|
56 |
+
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
|
57 |
+
else:
|
58 |
+
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
|
59 |
+
else:
|
60 |
+
if model.config.mm_use_im_start_end:
|
61 |
+
qs = image_token_se + "\n" + qs
|
62 |
+
else:
|
63 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
64 |
+
|
65 |
+
conv = conv_templates[conv_mode].copy()
|
66 |
+
conv.append_message(conv.roles[0], qs)
|
67 |
+
conv.append_message(conv.roles[1], None)
|
68 |
+
prompt = conv.get_prompt()
|
69 |
+
|
70 |
+
images = load_images(image_files)
|
71 |
+
images_tensor = process_images(
|
72 |
+
images,
|
73 |
+
image_processor,
|
74 |
+
model.config
|
75 |
+
).to(model.device, dtype=torch.float16)
|
76 |
+
|
77 |
+
input_ids = (
|
78 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
79 |
+
.unsqueeze(0)
|
80 |
+
.cuda()
|
81 |
+
)
|
82 |
+
|
83 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
84 |
+
keywords = [stop_str]
|
85 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
86 |
+
|
87 |
+
with torch.inference_mode():
|
88 |
+
output_ids = model.generate(
|
89 |
+
input_ids,
|
90 |
+
images=images_tensor,
|
91 |
+
do_sample=True if temperature > 0 else False,
|
92 |
+
temperature=temperature,
|
93 |
+
# top_p=top_p,
|
94 |
+
# num_beams=args.num_beams,
|
95 |
+
pad_token_id=tokenizer.pad_token_id,
|
96 |
+
max_new_tokens=max_new_tokens,
|
97 |
+
use_cache=True,
|
98 |
+
stopping_criteria=[stopping_criteria],
|
99 |
+
)
|
100 |
+
|
101 |
+
outputs = tokenizer.batch_decode(
|
102 |
+
output_ids, skip_special_tokens=True
|
103 |
+
)[0]
|
104 |
+
outputs = outputs.strip()
|
105 |
+
if outputs.endswith(stop_str):
|
106 |
+
outputs = outputs[: -len(stop_str)]
|
107 |
+
outputs = outputs.strip()
|
108 |
+
print(outputs)
|
109 |
+
return outputs
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
parser = argparse.ArgumentParser()
|
115 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
116 |
+
parser.add_argument("--model-base", type=str, default=None)
|
117 |
+
parser.add_argument("--image-file", type=str, required=True)
|
118 |
+
parser.add_argument("--query", type=str, required=True)
|
119 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
120 |
+
parser.add_argument("--sep", type=str, default=",")
|
121 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
122 |
+
parser.add_argument("--top_p", type=float, default=None)
|
123 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
124 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
125 |
+
args = parser.parse_args()
|
126 |
+
|
127 |
+
inference_model(args)
|
tinychart/mm_utils.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
from tinychart.constants import IMAGE_TOKEN_INDEX
|
8 |
+
import math
|
9 |
+
import ast
|
10 |
+
|
11 |
+
|
12 |
+
def load_image_from_base64(image):
|
13 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
14 |
+
|
15 |
+
|
16 |
+
def expand2square(pil_img, background_color):
|
17 |
+
width, height = pil_img.size
|
18 |
+
if width == height:
|
19 |
+
return pil_img
|
20 |
+
elif width > height:
|
21 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
22 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
23 |
+
return result
|
24 |
+
else:
|
25 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
26 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
27 |
+
return result
|
28 |
+
|
29 |
+
|
30 |
+
def process_images(images, image_processor, model_cfg):
|
31 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
32 |
+
new_images = []
|
33 |
+
if image_aspect_ratio == 'pad':
|
34 |
+
for image in images:
|
35 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
36 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
37 |
+
new_images.append(image)
|
38 |
+
elif image_aspect_ratio == "anyres":
|
39 |
+
for image in images:
|
40 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
41 |
+
new_images.append(image)
|
42 |
+
else:
|
43 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
44 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
45 |
+
new_images = torch.stack(new_images, dim=0)
|
46 |
+
return new_images
|
47 |
+
|
48 |
+
|
49 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
50 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
51 |
+
|
52 |
+
def insert_separator(X, sep):
|
53 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
54 |
+
|
55 |
+
input_ids = []
|
56 |
+
offset = 0
|
57 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
58 |
+
offset = 1
|
59 |
+
input_ids.append(prompt_chunks[0][0])
|
60 |
+
|
61 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
62 |
+
input_ids.extend(x[offset:])
|
63 |
+
|
64 |
+
if return_tensors is not None:
|
65 |
+
if return_tensors == 'pt':
|
66 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
67 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
68 |
+
return input_ids
|
69 |
+
|
70 |
+
|
71 |
+
def get_model_name_from_path(model_path):
|
72 |
+
model_path = model_path.strip("/")
|
73 |
+
model_paths = model_path.split("/")
|
74 |
+
if model_paths[-1].startswith('checkpoint-'):
|
75 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
76 |
+
else:
|
77 |
+
return model_paths[-1]
|
78 |
+
|
79 |
+
|
80 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
81 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
82 |
+
self.keywords = keywords
|
83 |
+
self.keyword_ids = []
|
84 |
+
self.max_keyword_len = 0
|
85 |
+
for keyword in keywords:
|
86 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
87 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
88 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
89 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
90 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
91 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
92 |
+
self.tokenizer = tokenizer
|
93 |
+
self.start_len = input_ids.shape[1]
|
94 |
+
|
95 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
96 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
97 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
98 |
+
for keyword_id in self.keyword_ids:
|
99 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
100 |
+
return True
|
101 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
102 |
+
for keyword in self.keywords:
|
103 |
+
if keyword in outputs:
|
104 |
+
return True
|
105 |
+
return False
|
106 |
+
|
107 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
108 |
+
outputs = []
|
109 |
+
for i in range(output_ids.shape[0]):
|
110 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
111 |
+
return all(outputs)
|
tinychart/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from tinychart.model.language_model.llava_phi import TinyChartPhiForCausalLM, TinyChartPhiConfig
|
tinychart/model/builder.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import os
|
17 |
+
import warnings
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
+
import torch
|
22 |
+
from tinychart.model import *
|
23 |
+
from tinychart.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
|
25 |
+
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto",
|
27 |
+
device="cuda", **kwargs):
|
28 |
+
kwargs = {"device_map": device_map, **kwargs}
|
29 |
+
|
30 |
+
if device != "cuda":
|
31 |
+
kwargs['device_map'] = {"": device}
|
32 |
+
|
33 |
+
if load_8bit:
|
34 |
+
kwargs['load_in_8bit'] = True
|
35 |
+
elif load_4bit:
|
36 |
+
kwargs['load_in_4bit'] = True
|
37 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
38 |
+
load_in_4bit=True,
|
39 |
+
bnb_4bit_compute_dtype=torch.float16,
|
40 |
+
bnb_4bit_use_double_quant=True,
|
41 |
+
bnb_4bit_quant_type='nf4'
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
kwargs['torch_dtype'] = torch.float16
|
45 |
+
|
46 |
+
# Load LLaVA model
|
47 |
+
if 'lora' in model_name.lower() and model_base is None:
|
48 |
+
warnings.warn(
|
49 |
+
'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
50 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
51 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
52 |
+
|
53 |
+
print('Loading LLaVA from base model...')
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, padding_side="right")
|
55 |
+
model = TinyChartPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
|
56 |
+
config=lora_cfg_pretrained, **kwargs)
|
57 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
58 |
+
if model.lm_head.weight.shape[0] != token_num:
|
59 |
+
model.lm_head.weight = torch.nn.Parameter(
|
60 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
61 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(
|
62 |
+
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
63 |
+
|
64 |
+
print('Loading additional LLaVA weights...')
|
65 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
66 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
67 |
+
else:
|
68 |
+
# this is probably from HF Hub
|
69 |
+
from huggingface_hub import hf_hub_download
|
70 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
71 |
+
cache_file = hf_hub_download(
|
72 |
+
repo_id=repo_id,
|
73 |
+
filename=filename,
|
74 |
+
subfolder=subfolder)
|
75 |
+
return torch.load(cache_file, map_location='cpu')
|
76 |
+
|
77 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
78 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
|
79 |
+
non_lora_trainables.items()}
|
80 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
81 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
82 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
83 |
+
|
84 |
+
from peft import PeftModel
|
85 |
+
print('Loading LoRA weights...')
|
86 |
+
model = PeftModel.from_pretrained(model, model_path)
|
87 |
+
print('Merging LoRA weights...')
|
88 |
+
model = model.merge_and_unload()
|
89 |
+
print('Model is loaded...')
|
90 |
+
elif model_base is not None:
|
91 |
+
# this may be mm projector only
|
92 |
+
print('Loading LLaVA from base model...')
|
93 |
+
|
94 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, padding_side="right")
|
95 |
+
cfg_pretrained = TinyChartPhiConfig.from_pretrained(model_path)
|
96 |
+
model = TinyChartPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
|
97 |
+
**kwargs)
|
98 |
+
|
99 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
100 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
101 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
102 |
+
else:
|
103 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
|
104 |
+
model = TinyChartPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
105 |
+
|
106 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
107 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
108 |
+
if mm_use_im_patch_token:
|
109 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
110 |
+
if mm_use_im_start_end:
|
111 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
112 |
+
model.resize_token_embeddings(len(tokenizer))
|
113 |
+
vision_tower = model.get_vision_tower()
|
114 |
+
if not vision_tower.is_loaded:
|
115 |
+
vision_tower.load_model()
|
116 |
+
|
117 |
+
if device != "auto":
|
118 |
+
vision_tower.to(device=device, dtype=torch.float16)
|
119 |
+
|
120 |
+
image_processor = vision_tower.image_processor
|
121 |
+
|
122 |
+
if hasattr(model.config, "max_sequence_length"):
|
123 |
+
context_len = model.config.max_sequence_length
|
124 |
+
else:
|
125 |
+
context_len = 2048
|
126 |
+
|
127 |
+
return tokenizer, model, image_processor, context_len
|
tinychart/model/language_model/__init__.py
ADDED
File without changes
|
tinychart/model/language_model/llava_phi.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
|
22 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
23 |
+
|
24 |
+
from transformers import PhiConfig, PhiModel, PhiForCausalLM
|
25 |
+
|
26 |
+
from transformers.generation.utils import GenerateOutput
|
27 |
+
|
28 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
29 |
+
|
30 |
+
from tinychart.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
31 |
+
from tinychart.model.model_factory import *
|
32 |
+
|
33 |
+
class TinyChartPhiConfig(PhiConfig):
|
34 |
+
model_type = "tiny_chart_phi"
|
35 |
+
|
36 |
+
|
37 |
+
class TinyChartPhiModel(LlavaMetaModel, PhiModel):
|
38 |
+
config_class = TinyChartPhiConfig
|
39 |
+
|
40 |
+
def __init__(self, config: PhiConfig):
|
41 |
+
super(TinyChartPhiModel, self).__init__(config)
|
42 |
+
self.gradient_checkpointing = False
|
43 |
+
|
44 |
+
@register_model('tinychart-3b')
|
45 |
+
class TinyChartPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM):
|
46 |
+
config_class = TinyChartPhiConfig
|
47 |
+
|
48 |
+
def __init__(self, config):
|
49 |
+
super(PhiForCausalLM, self).__init__(config)
|
50 |
+
self.model = TinyChartPhiModel(config)
|
51 |
+
self.vocab_size = config.vocab_size
|
52 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
53 |
+
|
54 |
+
# Initialize weights and apply final processing
|
55 |
+
self.post_init()
|
56 |
+
|
57 |
+
def get_model(self):
|
58 |
+
return self.model
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids: torch.LongTensor = None,
|
63 |
+
attention_mask: Optional[torch.Tensor] = None,
|
64 |
+
position_ids: Optional[torch.LongTensor] = None,
|
65 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
66 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
67 |
+
labels: Optional[torch.LongTensor] = None,
|
68 |
+
use_cache: Optional[bool] = None,
|
69 |
+
output_attentions: Optional[bool] = None,
|
70 |
+
output_hidden_states: Optional[bool] = None,
|
71 |
+
images: Optional[torch.FloatTensor] = None,
|
72 |
+
return_dict: Optional[bool] = None,
|
73 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
74 |
+
|
75 |
+
if inputs_embeds is None:
|
76 |
+
(
|
77 |
+
input_ids,
|
78 |
+
position_ids,
|
79 |
+
attention_mask,
|
80 |
+
past_key_values,
|
81 |
+
inputs_embeds,
|
82 |
+
labels
|
83 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
84 |
+
input_ids,
|
85 |
+
position_ids,
|
86 |
+
attention_mask,
|
87 |
+
past_key_values,
|
88 |
+
labels,
|
89 |
+
images,
|
90 |
+
)
|
91 |
+
|
92 |
+
return super().forward(
|
93 |
+
input_ids=input_ids,
|
94 |
+
attention_mask=attention_mask,
|
95 |
+
position_ids=position_ids,
|
96 |
+
past_key_values=past_key_values,
|
97 |
+
inputs_embeds=inputs_embeds,
|
98 |
+
labels=labels,
|
99 |
+
use_cache=use_cache,
|
100 |
+
output_attentions=output_attentions,
|
101 |
+
output_hidden_states=output_hidden_states,
|
102 |
+
return_dict=return_dict
|
103 |
+
)
|
104 |
+
|
105 |
+
@torch.no_grad()
|
106 |
+
def generate(
|
107 |
+
self,
|
108 |
+
inputs: Optional[torch.Tensor] = None,
|
109 |
+
images: Optional[torch.Tensor] = None,
|
110 |
+
**kwargs,
|
111 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
112 |
+
position_ids = kwargs.pop("position_ids", None)
|
113 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
114 |
+
if "inputs_embeds" in kwargs:
|
115 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
116 |
+
|
117 |
+
if images is not None:
|
118 |
+
(
|
119 |
+
inputs,
|
120 |
+
position_ids,
|
121 |
+
attention_mask,
|
122 |
+
_,
|
123 |
+
inputs_embeds,
|
124 |
+
_
|
125 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
126 |
+
inputs,
|
127 |
+
position_ids,
|
128 |
+
attention_mask,
|
129 |
+
None,
|
130 |
+
None,
|
131 |
+
images,
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
135 |
+
|
136 |
+
return super().generate(
|
137 |
+
position_ids=position_ids,
|
138 |
+
attention_mask=attention_mask,
|
139 |
+
inputs_embeds=inputs_embeds,
|
140 |
+
**kwargs
|
141 |
+
)
|
142 |
+
|
143 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
144 |
+
inputs_embeds=None, **kwargs):
|
145 |
+
images = kwargs.pop("images", None)
|
146 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
147 |
+
inputs = super().prepare_inputs_for_generation(
|
148 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
149 |
+
)
|
150 |
+
if images is not None:
|
151 |
+
inputs['images'] = images
|
152 |
+
if image_sizes is not None:
|
153 |
+
inputs['image_sizes'] = image_sizes
|
154 |
+
return inputs
|
155 |
+
|
156 |
+
@register_tokenizer('phi')
|
157 |
+
def get_tokenizer():
|
158 |
+
from transformers import AutoTokenizer
|
159 |
+
def post_init(tokenizer):
|
160 |
+
return tokenizer
|
161 |
+
return AutoTokenizer, post_init
|
162 |
+
|
163 |
+
AutoConfig.register("tiny_chart_phi", TinyChartPhiConfig)
|
164 |
+
AutoModelForCausalLM.register(TinyChartPhiConfig, TinyChartPhiForCausalLM)
|
tinychart/model/language_model/phi/cache_utils.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class Cache:
|
7 |
+
"""
|
8 |
+
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def update(
|
12 |
+
self,
|
13 |
+
key_states: torch.Tensor,
|
14 |
+
value_states: torch.Tensor,
|
15 |
+
layer_idx: int,
|
16 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
17 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
18 |
+
"""
|
19 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
key_states (`torch.Tensor`):
|
23 |
+
The new key states to cache.
|
24 |
+
value_states (`torch.Tensor`):
|
25 |
+
The new value states to cache.
|
26 |
+
layer_idx (`int`):
|
27 |
+
The index of the layer to cache the states for.
|
28 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
29 |
+
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
|
30 |
+
cache to be created.
|
31 |
+
|
32 |
+
Return:
|
33 |
+
A tuple containing the updated key and value states.
|
34 |
+
"""
|
35 |
+
raise NotImplementedError("Make sure to implement `update` in a subclass.")
|
36 |
+
|
37 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
38 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
39 |
+
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
40 |
+
|
41 |
+
def get_max_length(self) -> Optional[int]:
|
42 |
+
"""Returns the maximum sequence length of the cached states, if there is any."""
|
43 |
+
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
|
44 |
+
|
45 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
46 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
47 |
+
# Cache without size limit -> all cache is usable
|
48 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
49 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
50 |
+
max_length = self.get_max_length()
|
51 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
52 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
53 |
+
return max_length - new_seq_length
|
54 |
+
return previous_seq_length
|
55 |
+
|
56 |
+
|
57 |
+
class DynamicCache(Cache):
|
58 |
+
"""
|
59 |
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
60 |
+
|
61 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
62 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self) -> None:
|
66 |
+
self.key_cache: List[torch.Tensor] = []
|
67 |
+
self.value_cache: List[torch.Tensor] = []
|
68 |
+
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
69 |
+
|
70 |
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
71 |
+
"""
|
72 |
+
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
73 |
+
sequence length.
|
74 |
+
"""
|
75 |
+
if layer_idx < len(self):
|
76 |
+
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
77 |
+
else:
|
78 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
79 |
+
|
80 |
+
def __iter__(self):
|
81 |
+
"""
|
82 |
+
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
|
83 |
+
keys and values
|
84 |
+
"""
|
85 |
+
for layer_idx in range(len(self)):
|
86 |
+
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
"""
|
90 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
91 |
+
to the number of layers in the model.
|
92 |
+
"""
|
93 |
+
return len(self.key_cache)
|
94 |
+
|
95 |
+
def update(
|
96 |
+
self,
|
97 |
+
key_states: torch.Tensor,
|
98 |
+
value_states: torch.Tensor,
|
99 |
+
layer_idx: int,
|
100 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
101 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
102 |
+
"""
|
103 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
104 |
+
|
105 |
+
Parameters:
|
106 |
+
key_states (`torch.Tensor`):
|
107 |
+
The new key states to cache.
|
108 |
+
value_states (`torch.Tensor`):
|
109 |
+
The new value states to cache.
|
110 |
+
layer_idx (`int`):
|
111 |
+
The index of the layer to cache the states for.
|
112 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
113 |
+
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
114 |
+
|
115 |
+
Return:
|
116 |
+
A tuple containing the updated key and value states.
|
117 |
+
"""
|
118 |
+
# Update the number of seen tokens
|
119 |
+
if layer_idx == 0:
|
120 |
+
self.seen_tokens += key_states.shape[-2]
|
121 |
+
|
122 |
+
# Update the cache
|
123 |
+
if len(self.key_cache) <= layer_idx:
|
124 |
+
self.key_cache.append(key_states)
|
125 |
+
self.value_cache.append(value_states)
|
126 |
+
else:
|
127 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
128 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
129 |
+
|
130 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
131 |
+
|
132 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
133 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
134 |
+
if len(self.key_cache) <= layer_idx:
|
135 |
+
return 0
|
136 |
+
return self.key_cache[layer_idx].shape[-2]
|
137 |
+
|
138 |
+
def get_max_length(self) -> Optional[int]:
|
139 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
140 |
+
return None
|
141 |
+
|
142 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
143 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
144 |
+
for layer_idx in range(len(self.key_cache)):
|
145 |
+
device = self.key_cache[layer_idx].device
|
146 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
147 |
+
device = self.value_cache[layer_idx].device
|
148 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
149 |
+
|
150 |
+
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
151 |
+
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
152 |
+
legacy_cache = ()
|
153 |
+
for layer_idx in range(len(self)):
|
154 |
+
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
|
155 |
+
return legacy_cache
|
156 |
+
|
157 |
+
@classmethod
|
158 |
+
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
|
159 |
+
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
|
160 |
+
cache = cls()
|
161 |
+
if past_key_values is not None:
|
162 |
+
for layer_idx in range(len(past_key_values)):
|
163 |
+
key_states, value_states = past_key_values[layer_idx]
|
164 |
+
cache.update(key_states, value_states, layer_idx)
|
165 |
+
return cache
|
166 |
+
|
167 |
+
|
168 |
+
class SinkCache(Cache):
|
169 |
+
"""
|
170 |
+
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
171 |
+
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
|
172 |
+
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
|
173 |
+
|
174 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
175 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
176 |
+
|
177 |
+
Parameters:
|
178 |
+
window_length (`int`):
|
179 |
+
The length of the context window.
|
180 |
+
num_sink_tokens (`int`):
|
181 |
+
The number of sink tokens. See the original paper for more information.
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
185 |
+
self.key_cache: List[torch.Tensor] = []
|
186 |
+
self.value_cache: List[torch.Tensor] = []
|
187 |
+
self.window_length = window_length
|
188 |
+
self.num_sink_tokens = num_sink_tokens
|
189 |
+
self.cos_sin_cache = {}
|
190 |
+
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def _rotate_half(x):
|
194 |
+
x1 = x[..., : x.shape[-1] // 2]
|
195 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
196 |
+
return torch.cat((-x2, x1), dim=-1)
|
197 |
+
|
198 |
+
def _apply_key_rotary_pos_emb(
|
199 |
+
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
200 |
+
) -> torch.Tensor:
|
201 |
+
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
|
202 |
+
return rotated_key_states
|
203 |
+
|
204 |
+
def _get_rerotation_cos_sin(
|
205 |
+
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
206 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
207 |
+
if key_states.shape[-2] not in self.cos_sin_cache:
|
208 |
+
# Upcast to float32 temporarily for better accuracy
|
209 |
+
cos = cos.to(torch.float32)
|
210 |
+
sin = sin.to(torch.float32)
|
211 |
+
|
212 |
+
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
|
213 |
+
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
|
214 |
+
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
|
215 |
+
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
|
216 |
+
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
|
217 |
+
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
|
218 |
+
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
|
219 |
+
|
220 |
+
self.cos_sin_cache[key_states.shape[-2]] = (
|
221 |
+
rerotation_cos.to(key_states.dtype).unsqueeze(0),
|
222 |
+
rerotation_sin.to(key_states.dtype).unsqueeze(0),
|
223 |
+
)
|
224 |
+
return self.cos_sin_cache[key_states.shape[-2]]
|
225 |
+
|
226 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
227 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
228 |
+
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
229 |
+
if len(self.key_cache) <= layer_idx:
|
230 |
+
return 0
|
231 |
+
return self.key_cache[layer_idx].shape[-2]
|
232 |
+
|
233 |
+
def get_max_length(self) -> Optional[int]:
|
234 |
+
"""Returns the maximum sequence length of the cached states."""
|
235 |
+
return self.window_length
|
236 |
+
|
237 |
+
def update(
|
238 |
+
self,
|
239 |
+
key_states: torch.Tensor,
|
240 |
+
value_states: torch.Tensor,
|
241 |
+
layer_idx: int,
|
242 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
243 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
244 |
+
"""
|
245 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
246 |
+
|
247 |
+
Parameters:
|
248 |
+
key_states (`torch.Tensor`):
|
249 |
+
The new key states to cache.
|
250 |
+
value_states (`torch.Tensor`):
|
251 |
+
The new value states to cache.
|
252 |
+
layer_idx (`int`):
|
253 |
+
The index of the layer to cache the states for.
|
254 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
255 |
+
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
|
256 |
+
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
|
257 |
+
rotation as the tokens are shifted.
|
258 |
+
|
259 |
+
Return:
|
260 |
+
A tuple containing the updated key and value states.
|
261 |
+
"""
|
262 |
+
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
263 |
+
# with partially rotated position embeddings, like Phi or Persimmon.
|
264 |
+
sin = cache_kwargs.get("sin")
|
265 |
+
cos = cache_kwargs.get("cos")
|
266 |
+
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
267 |
+
using_rope = cos is not None and sin is not None
|
268 |
+
|
269 |
+
# Update the number of seen tokens
|
270 |
+
if layer_idx == 0:
|
271 |
+
self.seen_tokens += key_states.shape[-2]
|
272 |
+
|
273 |
+
# [bsz, num_heads, seq_len, head_dim]
|
274 |
+
if len(self.key_cache) <= layer_idx:
|
275 |
+
# Empty cache
|
276 |
+
self.key_cache.append(key_states)
|
277 |
+
self.value_cache.append(value_states)
|
278 |
+
|
279 |
+
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
280 |
+
# Growing cache
|
281 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
282 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
283 |
+
|
284 |
+
else:
|
285 |
+
# Shifting cache
|
286 |
+
keys_to_keep = self.key_cache[layer_idx][
|
287 |
+
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
|
288 |
+
]
|
289 |
+
|
290 |
+
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
291 |
+
if using_rope:
|
292 |
+
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
293 |
+
key_states, cos[: self.window_length], sin[: self.window_length]
|
294 |
+
)
|
295 |
+
if partial_rotation_size is not None:
|
296 |
+
keys_to_keep, keys_pass = (
|
297 |
+
keys_to_keep[..., :partial_rotation_size],
|
298 |
+
keys_to_keep[..., partial_rotation_size:],
|
299 |
+
)
|
300 |
+
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
|
301 |
+
if partial_rotation_size is not None:
|
302 |
+
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
|
303 |
+
|
304 |
+
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
|
305 |
+
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
|
306 |
+
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
|
307 |
+
|
308 |
+
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
|
309 |
+
values_to_keep = self.value_cache[layer_idx][
|
310 |
+
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
|
311 |
+
]
|
312 |
+
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
|
313 |
+
|
314 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
315 |
+
|
316 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
317 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
318 |
+
for layer_idx in range(len(self.key_cache)):
|
319 |
+
device = self.key_cache[layer_idx].device
|
320 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
321 |
+
device = self.value_cache[layer_idx].device
|
322 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
tinychart/model/language_model/phi/configuration_phi.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Phi model configuration"""
|
17 |
+
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
"microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
class PhiConfig(PretrainedConfig):
|
31 |
+
r"""
|
32 |
+
This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
|
33 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
34 |
+
defaults will yield a similar configuration to that of the Phi
|
35 |
+
[microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
|
36 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
37 |
+
documentation from [`PretrainedConfig`] for more information.
|
38 |
+
Args:
|
39 |
+
vocab_size (`int`, *optional*, defaults to 51200):
|
40 |
+
Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
|
41 |
+
`inputs_ids` passed when calling [`PhiModel`].
|
42 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
43 |
+
Dimension of the hidden representations.
|
44 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
45 |
+
Dimension of the MLP representations.
|
46 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
47 |
+
Number of hidden layers in the Transformer decoder.
|
48 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
49 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
50 |
+
num_key_value_heads (`int`, *optional*):
|
51 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
52 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
53 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
54 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
55 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
56 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
57 |
+
`num_attention_heads`.
|
58 |
+
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
59 |
+
Dropout probability for mlp outputs.
|
60 |
+
embd_pdrop (`int`, *optional*, defaults to 0.0):
|
61 |
+
The dropout ratio for the embeddings.
|
62 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
63 |
+
The dropout ratio after computing the attention scores.
|
64 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
|
65 |
+
The non-linear activation function (function or string) in the decoder.
|
66 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
67 |
+
The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
|
68 |
+
tokens.
|
69 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
70 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
71 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
72 |
+
The epsilon used by the rms normalization layers.
|
73 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
74 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
75 |
+
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
|
76 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
77 |
+
Whether to tie weight embeddings
|
78 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
79 |
+
The base period of the RoPE embeddings.
|
80 |
+
rope_scaling (`Dict`, *optional*):
|
81 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
82 |
+
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
83 |
+
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
84 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
85 |
+
these scaling strategies behave:
|
86 |
+
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
|
87 |
+
is an experimental feature, subject to breaking API changes in future versions.
|
88 |
+
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
|
89 |
+
Percentage of the query and keys which will have rotary embedding.
|
90 |
+
qk_layernorm (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether or not to normalize the Queries and Keys after projecting the hidden states.
|
92 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
93 |
+
Denotes beginning of sequences token id.
|
94 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
95 |
+
Denotes end of sequences token id.
|
96 |
+
Example:
|
97 |
+
```python
|
98 |
+
>>> from transformers import PhiModel, PhiConfig
|
99 |
+
>>> # Initializing a Phi-1 style configuration
|
100 |
+
>>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
|
101 |
+
>>> # Initializing a model from the configuration
|
102 |
+
>>> model = PhiModel(configuration)
|
103 |
+
>>> # Accessing the model configuration
|
104 |
+
>>> configuration = model.config
|
105 |
+
```"""
|
106 |
+
|
107 |
+
model_type = "phi"
|
108 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
vocab_size=51200,
|
113 |
+
hidden_size=2048,
|
114 |
+
intermediate_size=8192,
|
115 |
+
num_hidden_layers=24,
|
116 |
+
num_attention_heads=32,
|
117 |
+
num_key_value_heads=None,
|
118 |
+
resid_pdrop=0.0,
|
119 |
+
embd_pdrop=0.0,
|
120 |
+
attention_dropout=0.0,
|
121 |
+
hidden_act="gelu_new",
|
122 |
+
max_position_embeddings=2048,
|
123 |
+
initializer_range=0.02,
|
124 |
+
layer_norm_eps=1e-5,
|
125 |
+
use_cache=True,
|
126 |
+
tie_word_embeddings=False,
|
127 |
+
rope_theta=10000.0,
|
128 |
+
rope_scaling=None,
|
129 |
+
partial_rotary_factor=0.5,
|
130 |
+
qk_layernorm=False,
|
131 |
+
bos_token_id=1,
|
132 |
+
eos_token_id=2,
|
133 |
+
**kwargs,
|
134 |
+
):
|
135 |
+
self.vocab_size = vocab_size
|
136 |
+
self.hidden_size = hidden_size
|
137 |
+
self.intermediate_size = intermediate_size
|
138 |
+
self.num_hidden_layers = num_hidden_layers
|
139 |
+
self.num_attention_heads = num_attention_heads
|
140 |
+
|
141 |
+
if num_key_value_heads is None:
|
142 |
+
num_key_value_heads = num_attention_heads
|
143 |
+
|
144 |
+
self.num_key_value_heads = num_key_value_heads
|
145 |
+
self.resid_pdrop = resid_pdrop
|
146 |
+
self.embd_pdrop = embd_pdrop
|
147 |
+
self.attention_dropout = attention_dropout
|
148 |
+
self.hidden_act = hidden_act
|
149 |
+
self.max_position_embeddings = max_position_embeddings
|
150 |
+
self.initializer_range = initializer_range
|
151 |
+
self.layer_norm_eps = layer_norm_eps
|
152 |
+
self.use_cache = use_cache
|
153 |
+
self.rope_theta = rope_theta
|
154 |
+
self.rope_scaling = rope_scaling
|
155 |
+
self.partial_rotary_factor = partial_rotary_factor
|
156 |
+
self.qk_layernorm = qk_layernorm
|
157 |
+
self._rope_scaling_validation()
|
158 |
+
|
159 |
+
super().__init__(
|
160 |
+
bos_token_id=bos_token_id,
|
161 |
+
eos_token_id=eos_token_id,
|
162 |
+
tie_word_embeddings=tie_word_embeddings,
|
163 |
+
**kwargs,
|
164 |
+
)
|
165 |
+
|
166 |
+
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
167 |
+
def _rope_scaling_validation(self):
|
168 |
+
"""
|
169 |
+
Validate the `rope_scaling` configuration.
|
170 |
+
"""
|
171 |
+
if self.rope_scaling is None:
|
172 |
+
return
|
173 |
+
|
174 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
175 |
+
raise ValueError(
|
176 |
+
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
177 |
+
f"got {self.rope_scaling}"
|
178 |
+
)
|
179 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
180 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
181 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
182 |
+
raise ValueError(
|
183 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
184 |
+
)
|
185 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
186 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
tinychart/model/language_model/phi/convert_phi_weights_to_hf.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
Weights conversion script for Phi
|
18 |
+
|
19 |
+
This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to
|
20 |
+
HugfgingFace model's format and saves them in "pytorch_dump_folder_path".
|
21 |
+
"""
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
import gc
|
25 |
+
import os
|
26 |
+
|
27 |
+
import torch
|
28 |
+
from huggingface_hub import hf_hub_download
|
29 |
+
|
30 |
+
from modeling_phi import PhiConfig, PhiForCausalLM
|
31 |
+
|
32 |
+
|
33 |
+
_MODELS = {
|
34 |
+
"microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin",
|
35 |
+
"microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin",
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
PHI_MAPPING = {
|
40 |
+
"layers.0.wte.weight": "model.embed_tokens.weight",
|
41 |
+
"layers.25.linear.bias": "lm_head.bias",
|
42 |
+
"layers.25.linear.weight": "lm_head.weight",
|
43 |
+
"layers.25.ln.bias": "model.final_layernorm.bias",
|
44 |
+
"layers.25.ln.weight": "model.final_layernorm.weight",
|
45 |
+
"layers": "model.layers",
|
46 |
+
"ln": "input_layernorm",
|
47 |
+
"mixer": "self_attn",
|
48 |
+
"Wqkv": "query_key_value",
|
49 |
+
"out_proj": "dense",
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
def convert_weights(original_weights, mapping, config):
|
54 |
+
converted_weights = {}
|
55 |
+
original_weights_keys = sorted(original_weights.keys())
|
56 |
+
|
57 |
+
# we change names (1-24) -> layers(0-23) for Phi model layers
|
58 |
+
range_change = {
|
59 |
+
f"layers.{k}.": f"layers.{v}."
|
60 |
+
for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers))
|
61 |
+
}
|
62 |
+
|
63 |
+
mapping.update(**range_change)
|
64 |
+
|
65 |
+
for original_weights_key in original_weights_keys:
|
66 |
+
new_key = original_weights_key
|
67 |
+
|
68 |
+
if "rotary_emb" in new_key:
|
69 |
+
continue
|
70 |
+
|
71 |
+
if "Wqkv" in new_key:
|
72 |
+
if "weight" in new_key:
|
73 |
+
weight = original_weights[new_key]
|
74 |
+
weights_shape = weight.shape
|
75 |
+
weight = (
|
76 |
+
weight.view(3, config.num_attention_heads, -1, config.hidden_size)
|
77 |
+
.transpose(0, 1)
|
78 |
+
.reshape(*weights_shape)
|
79 |
+
)
|
80 |
+
original_weights[new_key] = weight
|
81 |
+
elif "bias" in new_key:
|
82 |
+
bias = original_weights[new_key]
|
83 |
+
bias_shape = bias.shape
|
84 |
+
bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
|
85 |
+
original_weights[new_key] = bias
|
86 |
+
|
87 |
+
for k, v in mapping.items():
|
88 |
+
if k in new_key:
|
89 |
+
new_key = new_key.replace(k, v)
|
90 |
+
|
91 |
+
converted_weights[new_key] = original_weights.pop(original_weights_key)
|
92 |
+
|
93 |
+
return converted_weights
|
94 |
+
|
95 |
+
|
96 |
+
def _download(url: str, root: str):
|
97 |
+
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
|
98 |
+
filename = f"{url.split('/')[-1]}"
|
99 |
+
hf_hub_download(
|
100 |
+
repo_id=repo_id,
|
101 |
+
filename=filename,
|
102 |
+
force_filename=root,
|
103 |
+
local_dir_use_symlinks=False,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly):
|
108 |
+
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
|
109 |
+
for each_model_name, each_model_url in _MODELS.items():
|
110 |
+
converted_checkpoint = {}
|
111 |
+
|
112 |
+
model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1])
|
113 |
+
if not os.path.exists(model_path):
|
114 |
+
print(f"\n{each_model_name} was not found! Downloading it to {model_path}")
|
115 |
+
_download(url=each_model_url, root=model_path)
|
116 |
+
model_checkpoint = torch.load(model_path, map_location=device)
|
117 |
+
model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5
|
118 |
+
config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev")
|
119 |
+
|
120 |
+
# Converting the weights
|
121 |
+
converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
|
122 |
+
|
123 |
+
# Save either the whole model or the converted weights
|
124 |
+
if save_weights_directly:
|
125 |
+
save_weights_path = os.path.join(
|
126 |
+
pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1]
|
127 |
+
)
|
128 |
+
torch.save(converted_checkpoint, save_weights_path)
|
129 |
+
print(f"Model weights saved at {save_weights_path}!")
|
130 |
+
|
131 |
+
else:
|
132 |
+
model = PhiForCausalLM(config).to(device)
|
133 |
+
model.load_state_dict(converted_checkpoint, strict=True)
|
134 |
+
save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
|
135 |
+
model.save_pretrained(save_model_path)
|
136 |
+
print(f"Model saved at {save_model_path}!")
|
137 |
+
|
138 |
+
# release GPU memory for the 2nd model if cuda was used.
|
139 |
+
del config, model
|
140 |
+
|
141 |
+
# release GPU memory for the 2nd model if cuda was used.
|
142 |
+
del model_checkpoint, converted_checkpoint
|
143 |
+
if use_cuda:
|
144 |
+
torch.cuda.empty_cache()
|
145 |
+
gc.collect()
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == "__main__":
|
149 |
+
parser = argparse.ArgumentParser()
|
150 |
+
# # Required parameters
|
151 |
+
parser.add_argument(
|
152 |
+
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--pytorch_dump_folder_path",
|
156 |
+
default=None,
|
157 |
+
type=str,
|
158 |
+
help="Path to the output PyTorch model. (Please enter full path)",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--use_cuda",
|
162 |
+
default=False,
|
163 |
+
type=bool,
|
164 |
+
help="Whether to load the weights on GPU during conversion or not, False by default",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--save_weights_directly",
|
168 |
+
default=True,
|
169 |
+
type=bool,
|
170 |
+
help="Whether to save the weights directly after conversion or load the weight to the Phi model and then save "
|
171 |
+
"the Phi model along with weights. True by default",
|
172 |
+
)
|
173 |
+
|
174 |
+
args = parser.parse_args()
|
175 |
+
convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly)
|
tinychart/model/language_model/phi/modeling_attn_mask_utils.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class AttentionMaskConverter:
|
22 |
+
"""
|
23 |
+
A utility attention mask class that allows one to:
|
24 |
+
- Create a causal 4d mask
|
25 |
+
- Create a causal 4d mask with slided window
|
26 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
27 |
+
key_value_length) that can be multiplied with attention scores
|
28 |
+
|
29 |
+
Examples:
|
30 |
+
|
31 |
+
```python
|
32 |
+
>>> import torch
|
33 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
34 |
+
|
35 |
+
>>> converter = AttentionMaskConverter(True)
|
36 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
37 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
38 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
39 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
40 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
41 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
42 |
+
```
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
is_causal (`bool`):
|
46 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
47 |
+
|
48 |
+
sliding_window (`int`, *optional*):
|
49 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
50 |
+
"""
|
51 |
+
|
52 |
+
is_causal: bool
|
53 |
+
sliding_window: int
|
54 |
+
|
55 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
56 |
+
self.is_causal = is_causal
|
57 |
+
self.sliding_window = sliding_window
|
58 |
+
|
59 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
60 |
+
raise ValueError(
|
61 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
62 |
+
)
|
63 |
+
|
64 |
+
def to_causal_4d(
|
65 |
+
self,
|
66 |
+
batch_size: int,
|
67 |
+
query_length: int,
|
68 |
+
key_value_length: int,
|
69 |
+
dtype: torch.dtype,
|
70 |
+
device: Union[torch.device, "str"] = "cpu",
|
71 |
+
) -> Optional[torch.Tensor]:
|
72 |
+
"""
|
73 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
74 |
+
bias to upper right hand triangular matrix (causal mask).
|
75 |
+
"""
|
76 |
+
if not self.is_causal:
|
77 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
78 |
+
|
79 |
+
# If shape is not cached, create a new causal mask and cache it
|
80 |
+
input_shape = (batch_size, query_length)
|
81 |
+
past_key_values_length = key_value_length - query_length
|
82 |
+
|
83 |
+
# create causal mask
|
84 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
85 |
+
causal_4d_mask = None
|
86 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
87 |
+
causal_4d_mask = self._make_causal_mask(
|
88 |
+
input_shape,
|
89 |
+
dtype,
|
90 |
+
device=device,
|
91 |
+
past_key_values_length=past_key_values_length,
|
92 |
+
sliding_window=self.sliding_window,
|
93 |
+
)
|
94 |
+
|
95 |
+
return causal_4d_mask
|
96 |
+
|
97 |
+
def to_4d(
|
98 |
+
self,
|
99 |
+
attention_mask_2d: torch.Tensor,
|
100 |
+
query_length: int,
|
101 |
+
dtype: torch.dtype,
|
102 |
+
key_value_length: Optional[int] = None,
|
103 |
+
) -> torch.Tensor:
|
104 |
+
"""
|
105 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
106 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
107 |
+
causal, a causal mask will be added.
|
108 |
+
"""
|
109 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
110 |
+
|
111 |
+
# create causal mask
|
112 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
113 |
+
causal_4d_mask = None
|
114 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
115 |
+
if key_value_length is None:
|
116 |
+
raise ValueError(
|
117 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
118 |
+
)
|
119 |
+
|
120 |
+
past_key_values_length = key_value_length - query_length
|
121 |
+
causal_4d_mask = self._make_causal_mask(
|
122 |
+
input_shape,
|
123 |
+
dtype,
|
124 |
+
device=attention_mask_2d.device,
|
125 |
+
past_key_values_length=past_key_values_length,
|
126 |
+
sliding_window=self.sliding_window,
|
127 |
+
)
|
128 |
+
elif self.sliding_window is not None:
|
129 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
130 |
+
|
131 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
132 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
133 |
+
attention_mask_2d.device
|
134 |
+
)
|
135 |
+
if causal_4d_mask is not None:
|
136 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
|
137 |
+
|
138 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
139 |
+
expanded_4d_mask = expanded_attn_mask
|
140 |
+
|
141 |
+
return expanded_4d_mask
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def _make_causal_mask(
|
145 |
+
input_ids_shape: torch.Size,
|
146 |
+
dtype: torch.dtype,
|
147 |
+
device: torch.device,
|
148 |
+
past_key_values_length: int = 0,
|
149 |
+
sliding_window: Optional[int] = None,
|
150 |
+
):
|
151 |
+
"""
|
152 |
+
Make causal mask used for bi-directional self-attention.
|
153 |
+
"""
|
154 |
+
bsz, tgt_len = input_ids_shape
|
155 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
156 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
157 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
158 |
+
|
159 |
+
mask = mask.to(dtype)
|
160 |
+
|
161 |
+
if past_key_values_length > 0:
|
162 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
163 |
+
|
164 |
+
# add lower triangular sliding window mask if necessary
|
165 |
+
if sliding_window is not None:
|
166 |
+
diagonal = past_key_values_length - sliding_window + 1
|
167 |
+
|
168 |
+
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
|
169 |
+
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
|
170 |
+
|
171 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
172 |
+
|
173 |
+
@staticmethod
|
174 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
175 |
+
"""
|
176 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
177 |
+
"""
|
178 |
+
bsz, src_len = mask.size()
|
179 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
180 |
+
|
181 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
182 |
+
|
183 |
+
inverted_mask = 1.0 - expanded_mask
|
184 |
+
|
185 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
def _unmask_unattended(
|
189 |
+
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
|
190 |
+
):
|
191 |
+
# fmt: off
|
192 |
+
"""
|
193 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
194 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
195 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
196 |
+
|
197 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
198 |
+
`attention_mask` is [bsz, src_seq_len].
|
199 |
+
|
200 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
|
201 |
+
|
202 |
+
For example, if `attention_mask` is
|
203 |
+
```
|
204 |
+
[[0, 0, 1],
|
205 |
+
[1, 1, 1],
|
206 |
+
[0, 1, 1]]
|
207 |
+
```
|
208 |
+
and `expanded_mask` is (e.g. here left-padding case)
|
209 |
+
```
|
210 |
+
[[[[0, 0, 0],
|
211 |
+
[0, 0, 0],
|
212 |
+
[0, 0, 1]]],
|
213 |
+
[[[1, 0, 0],
|
214 |
+
[1, 1, 0],
|
215 |
+
[1, 1, 1]]],
|
216 |
+
[[[0, 0, 0],
|
217 |
+
[0, 1, 0],
|
218 |
+
[0, 1, 1]]]]
|
219 |
+
```
|
220 |
+
then the modified `expanded_mask` will be
|
221 |
+
```
|
222 |
+
[[[[1, 1, 1], <-- modified
|
223 |
+
[1, 1, 1], <-- modified
|
224 |
+
[0, 0, 1]]],
|
225 |
+
[[[1, 0, 0],
|
226 |
+
[1, 1, 0],
|
227 |
+
[1, 1, 1]]],
|
228 |
+
[[[1, 1, 1], <-- modified
|
229 |
+
[0, 1, 0],
|
230 |
+
[0, 1, 1]]]]
|
231 |
+
```
|
232 |
+
"""
|
233 |
+
# fmt: on
|
234 |
+
|
235 |
+
# Get the index of the first non-zero value for every sample in the batch.
|
236 |
+
# In the above example, indices = [[2], [0], [1]]]
|
237 |
+
tmp = torch.arange(attention_mask.shape[1], 0, -1)
|
238 |
+
indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
|
239 |
+
|
240 |
+
# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
|
241 |
+
# expanded mask will be completely unattended.
|
242 |
+
left_masked_rows = torch.where(indices > 0)[0]
|
243 |
+
|
244 |
+
if left_masked_rows.shape[0] == 0:
|
245 |
+
return expanded_mask
|
246 |
+
indices = indices[left_masked_rows]
|
247 |
+
|
248 |
+
max_len = torch.max(indices)
|
249 |
+
range_tensor = torch.arange(max_len).unsqueeze(0)
|
250 |
+
range_tensor = range_tensor.repeat(indices.size(0), 1)
|
251 |
+
|
252 |
+
# Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
|
253 |
+
range_tensor[range_tensor >= indices] = 0
|
254 |
+
|
255 |
+
# TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
|
256 |
+
if expanded_mask.dim() == 4:
|
257 |
+
num_masks = expanded_mask.shape[1]
|
258 |
+
if num_masks == 1:
|
259 |
+
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
|
260 |
+
mask_slice = (left_masked_rows[:, None], 0, range_tensor)
|
261 |
+
else:
|
262 |
+
# Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
|
263 |
+
mask_slice = (
|
264 |
+
left_masked_rows[:, None, None],
|
265 |
+
torch.arange(num_masks)[None, :, None],
|
266 |
+
range_tensor[:, None, :],
|
267 |
+
)
|
268 |
+
else:
|
269 |
+
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
|
270 |
+
mask_slice = (left_masked_rows[:, None], range_tensor)
|
271 |
+
|
272 |
+
expanded_mask[mask_slice] = unmasked_value
|
273 |
+
|
274 |
+
return expanded_mask
|
275 |
+
|
276 |
+
|
277 |
+
def _prepare_4d_causal_attention_mask(
|
278 |
+
attention_mask: Optional[torch.Tensor],
|
279 |
+
input_shape: Union[torch.Size, Tuple, List],
|
280 |
+
inputs_embeds: torch.Tensor,
|
281 |
+
past_key_values_length: int,
|
282 |
+
sliding_window: Optional[int] = None,
|
283 |
+
):
|
284 |
+
"""
|
285 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
286 |
+
`(batch_size, key_value_length)`
|
287 |
+
|
288 |
+
Args:
|
289 |
+
attention_mask (`torch.Tensor` or `None`):
|
290 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
291 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
292 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
293 |
+
inputs_embeds (`torch.Tensor`):
|
294 |
+
The embedded inputs as a torch Tensor.
|
295 |
+
past_key_values_length (`int`):
|
296 |
+
The length of the key value cache.
|
297 |
+
sliding_window (`int`, *optional*):
|
298 |
+
If the model uses windowed attention, a sliding window should be passed.
|
299 |
+
"""
|
300 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
301 |
+
|
302 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
303 |
+
|
304 |
+
# 4d mask is passed through the layers
|
305 |
+
if attention_mask is not None and len(attention_mask.shape) == 2:
|
306 |
+
attention_mask = attn_mask_converter.to_4d(
|
307 |
+
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
308 |
+
)
|
309 |
+
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
310 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
311 |
+
if tuple(attention_mask.shape) != expected_shape:
|
312 |
+
raise ValueError(
|
313 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
317 |
+
inverted_mask = 1.0 - attention_mask
|
318 |
+
attention_mask = inverted_mask.masked_fill(
|
319 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
320 |
+
)
|
321 |
+
else:
|
322 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
323 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
324 |
+
)
|
325 |
+
|
326 |
+
return attention_mask
|
327 |
+
|
328 |
+
|
329 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
330 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
331 |
+
attention_mask: Optional[torch.Tensor],
|
332 |
+
input_shape: Union[torch.Size, Tuple, List],
|
333 |
+
inputs_embeds: torch.Tensor,
|
334 |
+
past_key_values_length: int,
|
335 |
+
sliding_window: Optional[int] = None,
|
336 |
+
):
|
337 |
+
"""
|
338 |
+
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
|
339 |
+
|
340 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
341 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
342 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
343 |
+
"""
|
344 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
345 |
+
|
346 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
347 |
+
batch_size, query_length = input_shape
|
348 |
+
|
349 |
+
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
350 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
351 |
+
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
|
352 |
+
is_tracing = torch.jit.is_tracing()
|
353 |
+
|
354 |
+
if attention_mask is not None:
|
355 |
+
# 4d mask is passed through
|
356 |
+
if len(attention_mask.shape) == 4:
|
357 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
358 |
+
if tuple(attention_mask.shape) != expected_shape:
|
359 |
+
raise ValueError(
|
360 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
364 |
+
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
|
365 |
+
attention_mask = inverted_mask.masked_fill(
|
366 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
367 |
+
)
|
368 |
+
return attention_mask
|
369 |
+
|
370 |
+
elif torch.all(attention_mask == 1):
|
371 |
+
if is_tracing:
|
372 |
+
pass
|
373 |
+
elif query_length == 1:
|
374 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
375 |
+
attention_mask = None
|
376 |
+
elif key_value_length == query_length:
|
377 |
+
attention_mask = None
|
378 |
+
else:
|
379 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
380 |
+
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
381 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
382 |
+
pass
|
383 |
+
elif query_length > 1 and key_value_length != query_length:
|
384 |
+
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
|
385 |
+
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
|
386 |
+
attention_mask = True
|
387 |
+
elif is_tracing:
|
388 |
+
raise ValueError(
|
389 |
+
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
|
390 |
+
)
|
391 |
+
|
392 |
+
if attention_mask is None:
|
393 |
+
expanded_4d_mask = None
|
394 |
+
elif attention_mask is True:
|
395 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
396 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
397 |
+
)
|
398 |
+
else:
|
399 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
400 |
+
attention_mask,
|
401 |
+
input_shape[-1],
|
402 |
+
dtype=inputs_embeds.dtype,
|
403 |
+
key_value_length=key_value_length,
|
404 |
+
)
|
405 |
+
|
406 |
+
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
407 |
+
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
408 |
+
if query_length > 1:
|
409 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
410 |
+
expanded_4d_mask, attention_mask, unmasked_value=0.0
|
411 |
+
)
|
412 |
+
|
413 |
+
return expanded_4d_mask
|
414 |
+
|
415 |
+
|
416 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
417 |
+
"""
|
418 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
419 |
+
`(batch_size, key_value_length)`
|
420 |
+
|
421 |
+
Args:
|
422 |
+
mask (`torch.Tensor` or `None`):
|
423 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
424 |
+
dtype (`torch.dtype`):
|
425 |
+
The torch dtype the created mask shall have.
|
426 |
+
tgt_len (`int`):
|
427 |
+
The target length or query length the created mask shall have.
|
428 |
+
"""
|
429 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
430 |
+
|
431 |
+
|
432 |
+
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
433 |
+
"""
|
434 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
435 |
+
`(batch_size, key_value_length)`
|
436 |
+
|
437 |
+
Args:
|
438 |
+
mask (`torch.Tensor` or `None`):
|
439 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
440 |
+
dtype (`torch.dtype`):
|
441 |
+
The torch dtype the created mask shall have.
|
442 |
+
tgt_len (`int`):
|
443 |
+
The target length or query length the created mask shall have.
|
444 |
+
"""
|
445 |
+
batch_size, key_value_length = mask.shape
|
446 |
+
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
447 |
+
|
448 |
+
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
449 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
450 |
+
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
|
451 |
+
is_tracing = torch.jit.is_tracing()
|
452 |
+
|
453 |
+
if torch.all(mask == 1):
|
454 |
+
if is_tracing:
|
455 |
+
pass
|
456 |
+
elif tgt_len == 1:
|
457 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
458 |
+
return None
|
459 |
+
elif key_value_length == tgt_len:
|
460 |
+
return None
|
461 |
+
else:
|
462 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
|
463 |
+
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
464 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
465 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
466 |
+
else:
|
467 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
468 |
+
|
469 |
+
|
470 |
+
def _create_4d_causal_attention_mask(
|
471 |
+
input_shape: Union[torch.Size, Tuple, List],
|
472 |
+
dtype: torch.dtype,
|
473 |
+
device: torch.device,
|
474 |
+
past_key_values_length: int = 0,
|
475 |
+
sliding_window: Optional[int] = None,
|
476 |
+
) -> Optional[torch.Tensor]:
|
477 |
+
"""
|
478 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
|
479 |
+
|
480 |
+
Args:
|
481 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
482 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
483 |
+
dtype (`torch.dtype`):
|
484 |
+
The torch dtype the created mask shall have.
|
485 |
+
device (`int`):
|
486 |
+
The torch device the created mask shall have.
|
487 |
+
sliding_window (`int`, *optional*):
|
488 |
+
If the model uses windowed attention, a sliding window should be passed.
|
489 |
+
"""
|
490 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
491 |
+
|
492 |
+
key_value_length = past_key_values_length + input_shape[-1]
|
493 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
494 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
|
495 |
+
)
|
496 |
+
|
497 |
+
return attention_mask
|
tinychart/model/language_model/phi/modeling_phi.py
ADDED
@@ -0,0 +1,1345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" PyTorch Phi model."""
|
17 |
+
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.cache_utils import Cache, DynamicCache
|
30 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
31 |
+
from transformers.modeling_outputs import (
|
32 |
+
BaseModelOutputWithPast,
|
33 |
+
CausalLMOutputWithPast,
|
34 |
+
SequenceClassifierOutputWithPast,
|
35 |
+
TokenClassifierOutput,
|
36 |
+
)
|
37 |
+
from transformers.modeling_utils import PreTrainedModel
|
38 |
+
from transformers.utils import (
|
39 |
+
add_code_sample_docstrings,
|
40 |
+
add_start_docstrings,
|
41 |
+
add_start_docstrings_to_model_forward,
|
42 |
+
is_flash_attn_2_available,
|
43 |
+
is_flash_attn_greater_or_equal_2_10,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from .configuration_phi import PhiConfig
|
48 |
+
|
49 |
+
|
50 |
+
try:
|
51 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
52 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
53 |
+
except:
|
54 |
+
pass
|
55 |
+
|
56 |
+
|
57 |
+
logger = logging.get_logger(__name__)
|
58 |
+
|
59 |
+
_CHECKPOINT_FOR_DOC = "microsoft/phi-2"
|
60 |
+
_CONFIG_FOR_DOC = "PhiConfig"
|
61 |
+
|
62 |
+
PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
63 |
+
"microsoft/phi-2",
|
64 |
+
# See all Phi models at https://huggingface.co/models?filter=phi
|
65 |
+
]
|
66 |
+
|
67 |
+
|
68 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
69 |
+
def _get_unpad_data(attention_mask):
|
70 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
71 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
72 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
73 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
74 |
+
return (
|
75 |
+
indices,
|
76 |
+
cu_seqlens,
|
77 |
+
max_seqlen_in_batch,
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
|
82 |
+
class PhiRotaryEmbedding(nn.Module):
|
83 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
self.dim = dim
|
87 |
+
self.max_position_embeddings = max_position_embeddings
|
88 |
+
self.base = base
|
89 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
90 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
91 |
+
|
92 |
+
# Build here to make `torch.jit.trace` work.
|
93 |
+
self._set_cos_sin_cache(
|
94 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
95 |
+
)
|
96 |
+
|
97 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
98 |
+
self.max_seq_len_cached = seq_len
|
99 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
100 |
+
|
101 |
+
freqs = torch.outer(t, self.inv_freq)
|
102 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
103 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
104 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
105 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
106 |
+
|
107 |
+
def forward(self, x, seq_len=None):
|
108 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
109 |
+
if seq_len > self.max_seq_len_cached:
|
110 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
111 |
+
|
112 |
+
return (
|
113 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
114 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
|
119 |
+
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
120 |
+
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
121 |
+
|
122 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
123 |
+
self.scaling_factor = scaling_factor
|
124 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
125 |
+
|
126 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
127 |
+
self.max_seq_len_cached = seq_len
|
128 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
129 |
+
t = t / self.scaling_factor
|
130 |
+
|
131 |
+
freqs = torch.outer(t, self.inv_freq)
|
132 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
133 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
134 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
135 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
136 |
+
|
137 |
+
|
138 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
|
139 |
+
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
140 |
+
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
141 |
+
|
142 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
143 |
+
self.scaling_factor = scaling_factor
|
144 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
145 |
+
|
146 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
147 |
+
self.max_seq_len_cached = seq_len
|
148 |
+
|
149 |
+
if seq_len > self.max_position_embeddings:
|
150 |
+
base = self.base * (
|
151 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
152 |
+
) ** (self.dim / (self.dim - 2))
|
153 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
154 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
155 |
+
|
156 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
157 |
+
|
158 |
+
freqs = torch.outer(t, self.inv_freq)
|
159 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
160 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
161 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
162 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
163 |
+
|
164 |
+
|
165 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
166 |
+
def rotate_half(x):
|
167 |
+
"""Rotates half the hidden dims of the input."""
|
168 |
+
x1 = x[..., : x.shape[-1] // 2]
|
169 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
170 |
+
return torch.cat((-x2, x1), dim=-1)
|
171 |
+
|
172 |
+
|
173 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
174 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
175 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
176 |
+
Args:
|
177 |
+
q (`torch.Tensor`): The query tensor.
|
178 |
+
k (`torch.Tensor`): The key tensor.
|
179 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
180 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
181 |
+
position_ids (`torch.Tensor`):
|
182 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
183 |
+
used to pass offsetted position ids when working with a KV-cache.
|
184 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
185 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
186 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
187 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
188 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
189 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
190 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
191 |
+
Returns:
|
192 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
193 |
+
"""
|
194 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
195 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
196 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
197 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
198 |
+
return q_embed, k_embed
|
199 |
+
|
200 |
+
|
201 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
|
202 |
+
class PhiMLP(nn.Module):
|
203 |
+
def __init__(self, config):
|
204 |
+
super().__init__()
|
205 |
+
self.config = config
|
206 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
207 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
208 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
209 |
+
|
210 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
211 |
+
hidden_states = self.fc1(hidden_states)
|
212 |
+
hidden_states = self.activation_fn(hidden_states)
|
213 |
+
hidden_states = self.fc2(hidden_states)
|
214 |
+
return hidden_states
|
215 |
+
|
216 |
+
|
217 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
218 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
219 |
+
"""
|
220 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
221 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
222 |
+
"""
|
223 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
224 |
+
if n_rep == 1:
|
225 |
+
return hidden_states
|
226 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
227 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
228 |
+
|
229 |
+
|
230 |
+
class PhiAttention(nn.Module):
|
231 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
232 |
+
|
233 |
+
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
|
234 |
+
super().__init__()
|
235 |
+
self.config = config
|
236 |
+
self.layer_idx = layer_idx
|
237 |
+
if layer_idx is None:
|
238 |
+
logger.warning_once(
|
239 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
240 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
241 |
+
"when creating this class."
|
242 |
+
)
|
243 |
+
|
244 |
+
self.attention_dropout = config.attention_dropout
|
245 |
+
self.hidden_size = config.hidden_size
|
246 |
+
self.num_heads = config.num_attention_heads
|
247 |
+
self.head_dim = self.hidden_size // self.num_heads
|
248 |
+
self.num_key_value_heads = config.num_key_value_heads
|
249 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
250 |
+
self.max_position_embeddings = config.max_position_embeddings
|
251 |
+
self.rope_theta = config.rope_theta
|
252 |
+
self.partial_rotary_factor = config.partial_rotary_factor
|
253 |
+
self.is_causal = True
|
254 |
+
|
255 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
256 |
+
raise ValueError(
|
257 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
258 |
+
f" and `num_heads`: {self.num_heads})."
|
259 |
+
)
|
260 |
+
|
261 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
262 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
263 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
264 |
+
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
265 |
+
|
266 |
+
self.qk_layernorm = config.qk_layernorm
|
267 |
+
if self.qk_layernorm:
|
268 |
+
self.q_layernorm = nn.LayerNorm(
|
269 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
270 |
+
)
|
271 |
+
self.k_layernorm = nn.LayerNorm(
|
272 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
273 |
+
)
|
274 |
+
|
275 |
+
self._init_rope()
|
276 |
+
|
277 |
+
def _init_rope(self):
|
278 |
+
if self.config.rope_scaling is None:
|
279 |
+
self.rotary_emb = PhiRotaryEmbedding(
|
280 |
+
int(self.partial_rotary_factor * self.head_dim),
|
281 |
+
max_position_embeddings=self.max_position_embeddings,
|
282 |
+
base=self.rope_theta,
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
scaling_type = self.config.rope_scaling["type"]
|
286 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
287 |
+
if scaling_type == "linear":
|
288 |
+
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
|
289 |
+
int(self.partial_rotary_factor * self.head_dim),
|
290 |
+
max_position_embeddings=self.max_position_embeddings,
|
291 |
+
scaling_factor=scaling_factor,
|
292 |
+
base=self.rope_theta,
|
293 |
+
)
|
294 |
+
elif scaling_type == "dynamic":
|
295 |
+
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
|
296 |
+
int(self.partial_rotary_factor * self.head_dim),
|
297 |
+
max_position_embeddings=self.max_position_embeddings,
|
298 |
+
scaling_factor=scaling_factor,
|
299 |
+
base=self.rope_theta,
|
300 |
+
)
|
301 |
+
else:
|
302 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
303 |
+
|
304 |
+
# Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
|
305 |
+
@torch.autocast("cpu", enabled=False)
|
306 |
+
@torch.autocast("cuda", enabled=False)
|
307 |
+
def forward(
|
308 |
+
self,
|
309 |
+
hidden_states: torch.Tensor,
|
310 |
+
attention_mask: Optional[torch.Tensor] = None,
|
311 |
+
position_ids: Optional[torch.LongTensor] = None,
|
312 |
+
past_key_value: Optional[Cache] = None,
|
313 |
+
output_attentions: bool = False,
|
314 |
+
use_cache: bool = False,
|
315 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
316 |
+
bsz, q_len, _ = hidden_states.size()
|
317 |
+
|
318 |
+
query_states = self.q_proj(hidden_states)
|
319 |
+
key_states = self.k_proj(hidden_states)
|
320 |
+
value_states = self.v_proj(hidden_states)
|
321 |
+
|
322 |
+
if self.qk_layernorm:
|
323 |
+
query_states = self.q_layernorm(query_states)
|
324 |
+
key_states = self.k_layernorm(key_states)
|
325 |
+
|
326 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
327 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
328 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
329 |
+
|
330 |
+
kv_seq_len = key_states.shape[-2]
|
331 |
+
if past_key_value is not None:
|
332 |
+
if self.layer_idx is None:
|
333 |
+
raise ValueError(
|
334 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
335 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
336 |
+
"with a layer index."
|
337 |
+
)
|
338 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
339 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
340 |
+
|
341 |
+
# Partial rotary embedding
|
342 |
+
query_rot, query_pass = (
|
343 |
+
query_states[..., : self.rotary_emb.dim],
|
344 |
+
query_states[..., self.rotary_emb.dim :],
|
345 |
+
)
|
346 |
+
key_rot, key_pass = (
|
347 |
+
key_states[..., : self.rotary_emb.dim],
|
348 |
+
key_states[..., self.rotary_emb.dim :],
|
349 |
+
)
|
350 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
351 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
352 |
+
|
353 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
354 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
355 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
356 |
+
|
357 |
+
if past_key_value is not None:
|
358 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
359 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
360 |
+
|
361 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
362 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
363 |
+
|
364 |
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
365 |
+
attn_weights = torch.matmul(
|
366 |
+
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
367 |
+
) / math.sqrt(self.head_dim)
|
368 |
+
|
369 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
370 |
+
raise ValueError(
|
371 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
372 |
+
f" {attn_weights.size()}"
|
373 |
+
)
|
374 |
+
|
375 |
+
if attention_mask is not None:
|
376 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
377 |
+
raise ValueError(
|
378 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
379 |
+
)
|
380 |
+
attn_weights = attn_weights + attention_mask
|
381 |
+
|
382 |
+
# upcast attention to fp32
|
383 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
|
384 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
385 |
+
|
386 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
387 |
+
|
388 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
389 |
+
raise ValueError(
|
390 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
391 |
+
f" {attn_output.size()}"
|
392 |
+
)
|
393 |
+
|
394 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
395 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
396 |
+
|
397 |
+
attn_output = self.dense(attn_output)
|
398 |
+
|
399 |
+
if not output_attentions:
|
400 |
+
attn_weights = None
|
401 |
+
|
402 |
+
return attn_output, attn_weights, past_key_value
|
403 |
+
|
404 |
+
|
405 |
+
class PhiFlashAttention2(PhiAttention):
|
406 |
+
"""
|
407 |
+
Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
|
408 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
409 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
410 |
+
"""
|
411 |
+
|
412 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
413 |
+
def __init__(self, *args, **kwargs):
|
414 |
+
super().__init__(*args, **kwargs)
|
415 |
+
|
416 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
417 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
418 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
419 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
420 |
+
|
421 |
+
def forward(
|
422 |
+
self,
|
423 |
+
hidden_states: torch.Tensor,
|
424 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
425 |
+
position_ids: Optional[torch.LongTensor] = None,
|
426 |
+
past_key_value: Optional[Cache] = None,
|
427 |
+
output_attentions: bool = False,
|
428 |
+
use_cache: bool = False,
|
429 |
+
**kwargs,
|
430 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
431 |
+
# PhiFlashAttention2 attention does not support output_attentions
|
432 |
+
|
433 |
+
output_attentions = False
|
434 |
+
|
435 |
+
bsz, q_len, _ = hidden_states.size()
|
436 |
+
|
437 |
+
query_states = self.q_proj(hidden_states)
|
438 |
+
key_states = self.k_proj(hidden_states)
|
439 |
+
value_states = self.v_proj(hidden_states)
|
440 |
+
|
441 |
+
if self.qk_layernorm:
|
442 |
+
query_states = self.q_layernorm(query_states)
|
443 |
+
key_states = self.k_layernorm(key_states)
|
444 |
+
|
445 |
+
# Flash attention requires the input to have the shape
|
446 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
447 |
+
# therefore we just need to keep the original shape
|
448 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
449 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
450 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
451 |
+
|
452 |
+
kv_seq_len = key_states.shape[-2]
|
453 |
+
if past_key_value is not None:
|
454 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
455 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
456 |
+
|
457 |
+
# Partial rotary embedding
|
458 |
+
query_rot, query_pass = (
|
459 |
+
query_states[..., : self.rotary_emb.dim],
|
460 |
+
query_states[..., self.rotary_emb.dim :],
|
461 |
+
)
|
462 |
+
key_rot, key_pass = (
|
463 |
+
key_states[..., : self.rotary_emb.dim],
|
464 |
+
key_states[..., self.rotary_emb.dim :],
|
465 |
+
)
|
466 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
467 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
468 |
+
|
469 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
470 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
471 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
472 |
+
|
473 |
+
if past_key_value is not None:
|
474 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
475 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
476 |
+
|
477 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
478 |
+
# to be able to avoid many of these transpose/reshape/view.
|
479 |
+
query_states = query_states.transpose(1, 2)
|
480 |
+
key_states = key_states.transpose(1, 2)
|
481 |
+
value_states = value_states.transpose(1, 2)
|
482 |
+
|
483 |
+
attn_dropout = self.attention_dropout if self.training else 0.0
|
484 |
+
|
485 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
486 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
487 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
488 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
489 |
+
# in fp32.
|
490 |
+
|
491 |
+
if query_states.dtype == torch.float32:
|
492 |
+
if torch.is_autocast_enabled():
|
493 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
494 |
+
# Handle the case where the model is quantized
|
495 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
496 |
+
target_dtype = self.config._pre_quantization_dtype
|
497 |
+
else:
|
498 |
+
target_dtype = self.q_proj.weight.dtype
|
499 |
+
|
500 |
+
logger.warning_once(
|
501 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
502 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
503 |
+
f" {target_dtype}."
|
504 |
+
)
|
505 |
+
|
506 |
+
query_states = query_states.to(target_dtype)
|
507 |
+
key_states = key_states.to(target_dtype)
|
508 |
+
value_states = value_states.to(target_dtype)
|
509 |
+
|
510 |
+
attn_output = self._flash_attention_forward(
|
511 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
|
512 |
+
)
|
513 |
+
|
514 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
515 |
+
attn_output = self.dense(attn_output)
|
516 |
+
|
517 |
+
if not output_attentions:
|
518 |
+
attn_weights = None
|
519 |
+
|
520 |
+
return attn_output, attn_weights, past_key_value
|
521 |
+
|
522 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
523 |
+
def _flash_attention_forward(
|
524 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
525 |
+
):
|
526 |
+
"""
|
527 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
528 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
529 |
+
Args:
|
530 |
+
query_states (`torch.Tensor`):
|
531 |
+
Input query states to be passed to Flash Attention API
|
532 |
+
key_states (`torch.Tensor`):
|
533 |
+
Input key states to be passed to Flash Attention API
|
534 |
+
value_states (`torch.Tensor`):
|
535 |
+
Input value states to be passed to Flash Attention API
|
536 |
+
attention_mask (`torch.Tensor`):
|
537 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
538 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
539 |
+
dropout (`int`, *optional*):
|
540 |
+
Attention dropout
|
541 |
+
softmax_scale (`float`, *optional*):
|
542 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
543 |
+
"""
|
544 |
+
if not self._flash_attn_uses_top_left_mask:
|
545 |
+
causal = self.is_causal
|
546 |
+
else:
|
547 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
548 |
+
causal = self.is_causal and query_length != 1
|
549 |
+
|
550 |
+
# Contains at least one padding token in the sequence
|
551 |
+
if attention_mask is not None:
|
552 |
+
batch_size = query_states.shape[0]
|
553 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
554 |
+
query_states, key_states, value_states, attention_mask, query_length
|
555 |
+
)
|
556 |
+
|
557 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
558 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
559 |
+
|
560 |
+
attn_output_unpad = flash_attn_varlen_func(
|
561 |
+
query_states,
|
562 |
+
key_states,
|
563 |
+
value_states,
|
564 |
+
cu_seqlens_q=cu_seqlens_q,
|
565 |
+
cu_seqlens_k=cu_seqlens_k,
|
566 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
567 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
568 |
+
dropout_p=dropout,
|
569 |
+
softmax_scale=softmax_scale,
|
570 |
+
causal=causal,
|
571 |
+
)
|
572 |
+
|
573 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
574 |
+
else:
|
575 |
+
attn_output = flash_attn_func(
|
576 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
577 |
+
)
|
578 |
+
|
579 |
+
return attn_output
|
580 |
+
|
581 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
582 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
583 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
584 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
585 |
+
|
586 |
+
key_layer = index_first_axis(
|
587 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
588 |
+
)
|
589 |
+
value_layer = index_first_axis(
|
590 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
591 |
+
)
|
592 |
+
if query_length == kv_seq_len:
|
593 |
+
query_layer = index_first_axis(
|
594 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
595 |
+
)
|
596 |
+
cu_seqlens_q = cu_seqlens_k
|
597 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
598 |
+
indices_q = indices_k
|
599 |
+
elif query_length == 1:
|
600 |
+
max_seqlen_in_batch_q = 1
|
601 |
+
cu_seqlens_q = torch.arange(
|
602 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
603 |
+
) # There is a memcpy here, that is very bad.
|
604 |
+
indices_q = cu_seqlens_q[:-1]
|
605 |
+
query_layer = query_layer.squeeze(1)
|
606 |
+
else:
|
607 |
+
# The -q_len: slice assumes left padding.
|
608 |
+
attention_mask = attention_mask[:, -query_length:]
|
609 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
610 |
+
|
611 |
+
return (
|
612 |
+
query_layer,
|
613 |
+
key_layer,
|
614 |
+
value_layer,
|
615 |
+
indices_q,
|
616 |
+
(cu_seqlens_q, cu_seqlens_k),
|
617 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
618 |
+
)
|
619 |
+
|
620 |
+
|
621 |
+
PHI_ATTENTION_CLASSES = {
|
622 |
+
"eager": PhiAttention,
|
623 |
+
"flash_attention_2": PhiFlashAttention2,
|
624 |
+
}
|
625 |
+
|
626 |
+
|
627 |
+
class PhiDecoderLayer(nn.Module):
|
628 |
+
def __init__(self, config: PhiConfig, layer_idx: int):
|
629 |
+
super().__init__()
|
630 |
+
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
631 |
+
self.mlp = PhiMLP(config)
|
632 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
633 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
634 |
+
|
635 |
+
def forward(
|
636 |
+
self,
|
637 |
+
hidden_states: torch.Tensor,
|
638 |
+
attention_mask: Optional[torch.Tensor] = None,
|
639 |
+
position_ids: Optional[torch.LongTensor] = None,
|
640 |
+
output_attentions: Optional[bool] = False,
|
641 |
+
use_cache: Optional[bool] = False,
|
642 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
643 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
644 |
+
"""
|
645 |
+
Args:
|
646 |
+
hidden_states (`torch.FloatTensor`):
|
647 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
648 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
649 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
650 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
651 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
652 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
653 |
+
output_attentions (`bool`, *optional*):
|
654 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
655 |
+
returned tensors for more detail.
|
656 |
+
use_cache (`bool`, *optional*):
|
657 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
658 |
+
(see `past_key_values`).
|
659 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
660 |
+
"""
|
661 |
+
|
662 |
+
residual = hidden_states
|
663 |
+
|
664 |
+
hidden_states = self.input_layernorm(hidden_states)
|
665 |
+
|
666 |
+
# Self Attention
|
667 |
+
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
|
668 |
+
hidden_states=hidden_states,
|
669 |
+
attention_mask=attention_mask,
|
670 |
+
position_ids=position_ids,
|
671 |
+
past_key_value=past_key_value,
|
672 |
+
output_attentions=output_attentions,
|
673 |
+
use_cache=use_cache,
|
674 |
+
)
|
675 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
676 |
+
|
677 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
678 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
679 |
+
outputs = (hidden_states,)
|
680 |
+
|
681 |
+
if output_attentions:
|
682 |
+
outputs += (self_attn_weights,)
|
683 |
+
|
684 |
+
if use_cache:
|
685 |
+
outputs += (present_key_value,)
|
686 |
+
|
687 |
+
return outputs
|
688 |
+
|
689 |
+
|
690 |
+
PHI_START_DOCSTRING = r"""
|
691 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
692 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
693 |
+
etc.)
|
694 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
695 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
696 |
+
and behavior.
|
697 |
+
Parameters:
|
698 |
+
config ([`PhiConfig`]):
|
699 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
700 |
+
load the weights associated with the model, only the configuration. Check out the
|
701 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
702 |
+
"""
|
703 |
+
|
704 |
+
|
705 |
+
@add_start_docstrings(
|
706 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
707 |
+
PHI_START_DOCSTRING,
|
708 |
+
)
|
709 |
+
class PhiPreTrainedModel(PreTrainedModel):
|
710 |
+
config_class = PhiConfig
|
711 |
+
base_model_prefix = "model"
|
712 |
+
supports_gradient_checkpointing = True
|
713 |
+
_no_split_modules = ["PhiDecoderLayer"]
|
714 |
+
_skip_keys_device_placement = "past_key_values"
|
715 |
+
_supports_flash_attn_2 = True
|
716 |
+
_supports_cache_class = True
|
717 |
+
|
718 |
+
def _init_weights(self, module):
|
719 |
+
std = self.config.initializer_range
|
720 |
+
if isinstance(module, nn.Linear):
|
721 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
722 |
+
if module.bias is not None:
|
723 |
+
module.bias.data.zero_()
|
724 |
+
elif isinstance(module, nn.Embedding):
|
725 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
726 |
+
if module.padding_idx is not None:
|
727 |
+
module.weight.data[module.padding_idx].zero_()
|
728 |
+
|
729 |
+
|
730 |
+
PHI_INPUTS_DOCSTRING = r"""
|
731 |
+
Args:
|
732 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
733 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
734 |
+
it.
|
735 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
736 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
737 |
+
[What are input IDs?](../glossary#input-ids)
|
738 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
739 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
740 |
+
- 1 for tokens that are **not masked**,
|
741 |
+
- 0 for tokens that are **masked**.
|
742 |
+
[What are attention masks?](../glossary#attention-mask)
|
743 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
744 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
745 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
746 |
+
`past_key_values`).
|
747 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
748 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
749 |
+
information on the default strategy.
|
750 |
+
- 1 indicates the head is **not masked**,
|
751 |
+
- 0 indicates the head is **masked**.
|
752 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
753 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
754 |
+
config.n_positions - 1]`.
|
755 |
+
[What are position IDs?](../glossary#position-ids)
|
756 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
757 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
758 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
759 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
760 |
+
Two formats are allowed:
|
761 |
+
- a [`~cache_utils.Cache`] instance;
|
762 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
763 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
764 |
+
cache format.
|
765 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
766 |
+
legacy cache format will be returned.
|
767 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
768 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
769 |
+
of shape `(batch_size, sequence_length)`.
|
770 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
771 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
772 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
773 |
+
model's internal embedding lookup matrix.
|
774 |
+
use_cache (`bool`, *optional*):
|
775 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
776 |
+
`past_key_values`).
|
777 |
+
output_attentions (`bool`, *optional*):
|
778 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
779 |
+
tensors for more detail.
|
780 |
+
output_hidden_states (`bool`, *optional*):
|
781 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
782 |
+
more detail.
|
783 |
+
return_dict (`bool`, *optional*):
|
784 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
785 |
+
"""
|
786 |
+
|
787 |
+
|
788 |
+
@add_start_docstrings(
|
789 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
790 |
+
PHI_START_DOCSTRING,
|
791 |
+
)
|
792 |
+
class PhiModel(PhiPreTrainedModel):
|
793 |
+
"""
|
794 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
|
795 |
+
Args:
|
796 |
+
config: PhiConfig
|
797 |
+
"""
|
798 |
+
|
799 |
+
def __init__(self, config: PhiConfig):
|
800 |
+
super().__init__(config)
|
801 |
+
self.padding_idx = config.pad_token_id
|
802 |
+
self.vocab_size = config.vocab_size
|
803 |
+
|
804 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
805 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
806 |
+
self.layers = nn.ModuleList(
|
807 |
+
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
808 |
+
)
|
809 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
810 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
811 |
+
|
812 |
+
self.gradient_checkpointing = False
|
813 |
+
# Initialize weights and apply final processing
|
814 |
+
self.post_init()
|
815 |
+
|
816 |
+
def get_input_embeddings(self):
|
817 |
+
return self.embed_tokens
|
818 |
+
|
819 |
+
def set_input_embeddings(self, value):
|
820 |
+
self.embed_tokens = value
|
821 |
+
|
822 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
823 |
+
def forward(
|
824 |
+
self,
|
825 |
+
input_ids: torch.LongTensor = None,
|
826 |
+
attention_mask: Optional[torch.Tensor] = None,
|
827 |
+
position_ids: Optional[torch.LongTensor] = None,
|
828 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
829 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
830 |
+
use_cache: Optional[bool] = None,
|
831 |
+
output_attentions: Optional[bool] = None,
|
832 |
+
output_hidden_states: Optional[bool] = None,
|
833 |
+
return_dict: Optional[bool] = None,
|
834 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
835 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
836 |
+
output_hidden_states = (
|
837 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
838 |
+
)
|
839 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
840 |
+
|
841 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
842 |
+
|
843 |
+
# retrieve input_ids and inputs_embeds
|
844 |
+
if input_ids is not None and inputs_embeds is not None:
|
845 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
846 |
+
elif input_ids is not None:
|
847 |
+
batch_size, seq_length = input_ids.shape[:2]
|
848 |
+
elif inputs_embeds is not None:
|
849 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
850 |
+
else:
|
851 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
852 |
+
|
853 |
+
past_key_values_length = 0
|
854 |
+
|
855 |
+
if self.gradient_checkpointing and self.training:
|
856 |
+
if use_cache:
|
857 |
+
logger.warning_once(
|
858 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
859 |
+
)
|
860 |
+
use_cache = False
|
861 |
+
|
862 |
+
if use_cache:
|
863 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
864 |
+
if use_legacy_cache:
|
865 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
866 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
867 |
+
|
868 |
+
if position_ids is None:
|
869 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
870 |
+
position_ids = torch.arange(
|
871 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
872 |
+
)
|
873 |
+
position_ids = position_ids.unsqueeze(0)
|
874 |
+
|
875 |
+
if inputs_embeds is None:
|
876 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
877 |
+
|
878 |
+
inputs_embeds = self.embed_dropout(inputs_embeds)
|
879 |
+
|
880 |
+
# Attention mask.
|
881 |
+
if self._use_flash_attention_2:
|
882 |
+
# 2d mask is passed through the layers
|
883 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
884 |
+
else:
|
885 |
+
# 4d mask is passed through the layers
|
886 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
887 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
888 |
+
)
|
889 |
+
|
890 |
+
hidden_states = inputs_embeds
|
891 |
+
|
892 |
+
# decoder layers
|
893 |
+
all_hidden_states = () if output_hidden_states else None
|
894 |
+
all_self_attns = () if output_attentions else None
|
895 |
+
next_decoder_cache = None
|
896 |
+
|
897 |
+
for decoder_layer in self.layers:
|
898 |
+
if output_hidden_states:
|
899 |
+
all_hidden_states += (hidden_states,)
|
900 |
+
|
901 |
+
if self.gradient_checkpointing and self.training:
|
902 |
+
layer_outputs = self._gradient_checkpointing_func(
|
903 |
+
decoder_layer.__call__,
|
904 |
+
hidden_states,
|
905 |
+
attention_mask,
|
906 |
+
position_ids,
|
907 |
+
past_key_values,
|
908 |
+
output_attentions,
|
909 |
+
)
|
910 |
+
else:
|
911 |
+
layer_outputs = decoder_layer(
|
912 |
+
hidden_states,
|
913 |
+
attention_mask=attention_mask,
|
914 |
+
position_ids=position_ids,
|
915 |
+
past_key_value=past_key_values,
|
916 |
+
output_attentions=output_attentions,
|
917 |
+
use_cache=use_cache,
|
918 |
+
)
|
919 |
+
|
920 |
+
hidden_states = layer_outputs[0]
|
921 |
+
|
922 |
+
if use_cache:
|
923 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
924 |
+
|
925 |
+
if output_attentions:
|
926 |
+
all_self_attns += (layer_outputs[1],)
|
927 |
+
|
928 |
+
hidden_states = self.final_layernorm(hidden_states)
|
929 |
+
|
930 |
+
# add hidden states from the last decoder layer
|
931 |
+
if output_hidden_states:
|
932 |
+
all_hidden_states += (hidden_states,)
|
933 |
+
|
934 |
+
next_cache = None
|
935 |
+
if use_cache:
|
936 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
937 |
+
if not return_dict:
|
938 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
939 |
+
return BaseModelOutputWithPast(
|
940 |
+
last_hidden_state=hidden_states,
|
941 |
+
past_key_values=next_cache,
|
942 |
+
hidden_states=all_hidden_states,
|
943 |
+
attentions=all_self_attns,
|
944 |
+
)
|
945 |
+
|
946 |
+
|
947 |
+
class PhiForCausalLM(PhiPreTrainedModel):
|
948 |
+
_tied_weights_keys = ["lm_head.weight"]
|
949 |
+
|
950 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
|
951 |
+
def __init__(self, config):
|
952 |
+
super().__init__(config)
|
953 |
+
self.model = PhiModel(config)
|
954 |
+
self.vocab_size = config.vocab_size
|
955 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
956 |
+
|
957 |
+
# Initialize weights and apply final processing
|
958 |
+
self.post_init()
|
959 |
+
|
960 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
961 |
+
def get_input_embeddings(self):
|
962 |
+
return self.model.embed_tokens
|
963 |
+
|
964 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
965 |
+
def set_input_embeddings(self, value):
|
966 |
+
self.model.embed_tokens = value
|
967 |
+
|
968 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
969 |
+
def get_output_embeddings(self):
|
970 |
+
return self.lm_head
|
971 |
+
|
972 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
973 |
+
def set_output_embeddings(self, new_embeddings):
|
974 |
+
self.lm_head = new_embeddings
|
975 |
+
|
976 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
977 |
+
def set_decoder(self, decoder):
|
978 |
+
self.model = decoder
|
979 |
+
|
980 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
981 |
+
def get_decoder(self):
|
982 |
+
return self.model
|
983 |
+
|
984 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
985 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
986 |
+
def forward(
|
987 |
+
self,
|
988 |
+
input_ids: torch.LongTensor = None,
|
989 |
+
attention_mask: Optional[torch.Tensor] = None,
|
990 |
+
position_ids: Optional[torch.LongTensor] = None,
|
991 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
992 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
993 |
+
labels: Optional[torch.LongTensor] = None,
|
994 |
+
use_cache: Optional[bool] = None,
|
995 |
+
output_attentions: Optional[bool] = None,
|
996 |
+
output_hidden_states: Optional[bool] = None,
|
997 |
+
return_dict: Optional[bool] = None,
|
998 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
999 |
+
r"""
|
1000 |
+
Args:
|
1001 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1002 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1003 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1004 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1005 |
+
Returns:
|
1006 |
+
Example:
|
1007 |
+
```python
|
1008 |
+
>>> from transformers import AutoTokenizer, PhiForCausalLM
|
1009 |
+
>>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
|
1010 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
|
1011 |
+
>>> prompt = "This is an example script ."
|
1012 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1013 |
+
>>> # Generate
|
1014 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1015 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1016 |
+
'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
|
1017 |
+
```"""
|
1018 |
+
|
1019 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1020 |
+
output_hidden_states = (
|
1021 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1022 |
+
)
|
1023 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1024 |
+
|
1025 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1026 |
+
outputs = self.model(
|
1027 |
+
input_ids=input_ids,
|
1028 |
+
attention_mask=attention_mask,
|
1029 |
+
position_ids=position_ids,
|
1030 |
+
past_key_values=past_key_values,
|
1031 |
+
inputs_embeds=inputs_embeds,
|
1032 |
+
use_cache=use_cache,
|
1033 |
+
output_attentions=output_attentions,
|
1034 |
+
output_hidden_states=output_hidden_states,
|
1035 |
+
return_dict=return_dict,
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
hidden_states = outputs[0]
|
1039 |
+
logits = self.lm_head(hidden_states)
|
1040 |
+
logits = logits.float()
|
1041 |
+
|
1042 |
+
loss = None
|
1043 |
+
if labels is not None:
|
1044 |
+
# Shift so that tokens < n predict n
|
1045 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1046 |
+
shift_labels = labels[..., 1:].contiguous()
|
1047 |
+
# import pdb;pdb.set_trace()
|
1048 |
+
# Flatten the tokens
|
1049 |
+
loss_fct = CrossEntropyLoss()
|
1050 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1051 |
+
shift_labels = shift_labels.view(-1)
|
1052 |
+
# Enable model parallelism
|
1053 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1054 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1055 |
+
|
1056 |
+
if not return_dict:
|
1057 |
+
output = (logits,) + outputs[1:]
|
1058 |
+
return (loss,) + output if loss is not None else output
|
1059 |
+
|
1060 |
+
return CausalLMOutputWithPast(
|
1061 |
+
loss=loss,
|
1062 |
+
logits=logits,
|
1063 |
+
past_key_values=outputs.past_key_values,
|
1064 |
+
hidden_states=outputs.hidden_states,
|
1065 |
+
attentions=outputs.attentions,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
1069 |
+
def prepare_inputs_for_generation(
|
1070 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1071 |
+
):
|
1072 |
+
if past_key_values is not None:
|
1073 |
+
if isinstance(past_key_values, Cache):
|
1074 |
+
cache_length = past_key_values.get_seq_length()
|
1075 |
+
past_length = past_key_values.seen_tokens
|
1076 |
+
max_cache_length = past_key_values.get_max_length()
|
1077 |
+
else:
|
1078 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1079 |
+
max_cache_length = None
|
1080 |
+
|
1081 |
+
# Keep only the unprocessed tokens:
|
1082 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1083 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
1084 |
+
# input)
|
1085 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
1086 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1087 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1088 |
+
# input_ids based on the past_length.
|
1089 |
+
elif past_length < input_ids.shape[1]:
|
1090 |
+
input_ids = input_ids[:, past_length:]
|
1091 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1092 |
+
|
1093 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
1094 |
+
if (
|
1095 |
+
max_cache_length is not None
|
1096 |
+
and attention_mask is not None
|
1097 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
1098 |
+
):
|
1099 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
1100 |
+
|
1101 |
+
position_ids = kwargs.get("position_ids", None)
|
1102 |
+
if attention_mask is not None and position_ids is None:
|
1103 |
+
# create position_ids on the fly for batch generation
|
1104 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1105 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1106 |
+
if past_key_values:
|
1107 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1108 |
+
|
1109 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1110 |
+
if inputs_embeds is not None and past_key_values is None:
|
1111 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1112 |
+
else:
|
1113 |
+
model_inputs = {"input_ids": input_ids}
|
1114 |
+
|
1115 |
+
model_inputs.update(
|
1116 |
+
{
|
1117 |
+
"position_ids": position_ids,
|
1118 |
+
"past_key_values": past_key_values,
|
1119 |
+
"use_cache": kwargs.get("use_cache"),
|
1120 |
+
"attention_mask": attention_mask,
|
1121 |
+
}
|
1122 |
+
)
|
1123 |
+
return model_inputs
|
1124 |
+
|
1125 |
+
@staticmethod
|
1126 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
1127 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1128 |
+
reordered_past = ()
|
1129 |
+
for layer_past in past_key_values:
|
1130 |
+
reordered_past += (
|
1131 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1132 |
+
)
|
1133 |
+
return reordered_past
|
1134 |
+
|
1135 |
+
|
1136 |
+
@add_start_docstrings(
|
1137 |
+
"""
|
1138 |
+
The PhiModel with a sequence classification head on top (linear layer).
|
1139 |
+
[`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1140 |
+
(e.g. GPT-2) do.
|
1141 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1142 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1143 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1144 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1145 |
+
each row of the batch).
|
1146 |
+
""",
|
1147 |
+
PHI_START_DOCSTRING,
|
1148 |
+
)
|
1149 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
|
1150 |
+
class PhiForSequenceClassification(PhiPreTrainedModel):
|
1151 |
+
def __init__(self, config):
|
1152 |
+
super().__init__(config)
|
1153 |
+
self.num_labels = config.num_labels
|
1154 |
+
self.model = PhiModel(config)
|
1155 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1156 |
+
|
1157 |
+
# Initialize weights and apply final processing
|
1158 |
+
self.post_init()
|
1159 |
+
|
1160 |
+
def get_input_embeddings(self):
|
1161 |
+
return self.model.embed_tokens
|
1162 |
+
|
1163 |
+
def set_input_embeddings(self, value):
|
1164 |
+
self.model.embed_tokens = value
|
1165 |
+
|
1166 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1167 |
+
def forward(
|
1168 |
+
self,
|
1169 |
+
input_ids: torch.LongTensor = None,
|
1170 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1171 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1172 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1173 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1174 |
+
labels: Optional[torch.LongTensor] = None,
|
1175 |
+
use_cache: Optional[bool] = None,
|
1176 |
+
output_attentions: Optional[bool] = None,
|
1177 |
+
output_hidden_states: Optional[bool] = None,
|
1178 |
+
return_dict: Optional[bool] = None,
|
1179 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1180 |
+
r"""
|
1181 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1182 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1183 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1184 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1185 |
+
"""
|
1186 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1187 |
+
|
1188 |
+
model_outputs = self.model(
|
1189 |
+
input_ids,
|
1190 |
+
attention_mask=attention_mask,
|
1191 |
+
position_ids=position_ids,
|
1192 |
+
past_key_values=past_key_values,
|
1193 |
+
inputs_embeds=inputs_embeds,
|
1194 |
+
use_cache=use_cache,
|
1195 |
+
output_attentions=output_attentions,
|
1196 |
+
output_hidden_states=output_hidden_states,
|
1197 |
+
return_dict=return_dict,
|
1198 |
+
)
|
1199 |
+
hidden_states = model_outputs[0]
|
1200 |
+
logits = self.score(hidden_states)
|
1201 |
+
|
1202 |
+
if input_ids is not None:
|
1203 |
+
batch_size = input_ids.shape[0]
|
1204 |
+
else:
|
1205 |
+
batch_size = inputs_embeds.shape[0]
|
1206 |
+
|
1207 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1208 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1209 |
+
if self.config.pad_token_id is None:
|
1210 |
+
sequence_lengths = -1
|
1211 |
+
else:
|
1212 |
+
if input_ids is not None:
|
1213 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1214 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1215 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1216 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1217 |
+
else:
|
1218 |
+
sequence_lengths = -1
|
1219 |
+
|
1220 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1221 |
+
|
1222 |
+
loss = None
|
1223 |
+
if labels is not None:
|
1224 |
+
labels = labels.to(logits.device)
|
1225 |
+
if self.config.problem_type is None:
|
1226 |
+
if self.num_labels == 1:
|
1227 |
+
self.config.problem_type = "regression"
|
1228 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1229 |
+
self.config.problem_type = "single_label_classification"
|
1230 |
+
else:
|
1231 |
+
self.config.problem_type = "multi_label_classification"
|
1232 |
+
|
1233 |
+
if self.config.problem_type == "regression":
|
1234 |
+
loss_fct = MSELoss()
|
1235 |
+
if self.num_labels == 1:
|
1236 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1237 |
+
else:
|
1238 |
+
loss = loss_fct(pooled_logits, labels)
|
1239 |
+
elif self.config.problem_type == "single_label_classification":
|
1240 |
+
loss_fct = CrossEntropyLoss()
|
1241 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1242 |
+
elif self.config.problem_type == "multi_label_classification":
|
1243 |
+
loss_fct = BCEWithLogitsLoss()
|
1244 |
+
loss = loss_fct(pooled_logits, labels)
|
1245 |
+
if not return_dict:
|
1246 |
+
output = (pooled_logits,) + model_outputs[1:]
|
1247 |
+
return ((loss,) + output) if loss is not None else output
|
1248 |
+
|
1249 |
+
return SequenceClassifierOutputWithPast(
|
1250 |
+
loss=loss,
|
1251 |
+
logits=pooled_logits,
|
1252 |
+
past_key_values=model_outputs.past_key_values,
|
1253 |
+
hidden_states=model_outputs.hidden_states,
|
1254 |
+
attentions=model_outputs.attentions,
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
|
1258 |
+
@add_start_docstrings(
|
1259 |
+
"""
|
1260 |
+
PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1261 |
+
Named-Entity-Recognition (NER) tasks.
|
1262 |
+
""",
|
1263 |
+
PHI_START_DOCSTRING,
|
1264 |
+
)
|
1265 |
+
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
|
1266 |
+
class PhiForTokenClassification(PhiPreTrainedModel):
|
1267 |
+
def __init__(self, config: PhiConfig):
|
1268 |
+
super().__init__(config)
|
1269 |
+
self.num_labels = config.num_labels
|
1270 |
+
|
1271 |
+
self.model = PhiModel(config)
|
1272 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
1273 |
+
classifier_dropout = config.classifier_dropout
|
1274 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
1275 |
+
classifier_dropout = config.hidden_dropout
|
1276 |
+
else:
|
1277 |
+
classifier_dropout = 0.1
|
1278 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1279 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1280 |
+
|
1281 |
+
# Initialize weights and apply final processing
|
1282 |
+
self.post_init()
|
1283 |
+
|
1284 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1285 |
+
@add_code_sample_docstrings(
|
1286 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1287 |
+
output_type=TokenClassifierOutput,
|
1288 |
+
config_class=_CONFIG_FOR_DOC,
|
1289 |
+
)
|
1290 |
+
def forward(
|
1291 |
+
self,
|
1292 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1293 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1294 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1295 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1296 |
+
labels: Optional[torch.Tensor] = None,
|
1297 |
+
use_cache: Optional[bool] = None,
|
1298 |
+
output_attentions: Optional[bool] = None,
|
1299 |
+
output_hidden_states: Optional[bool] = None,
|
1300 |
+
return_dict: Optional[bool] = None,
|
1301 |
+
**deprecated_arguments,
|
1302 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
1303 |
+
r"""
|
1304 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1305 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1306 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1307 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1308 |
+
"""
|
1309 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1310 |
+
|
1311 |
+
model_outputs = self.model(
|
1312 |
+
input_ids,
|
1313 |
+
past_key_values=past_key_values,
|
1314 |
+
attention_mask=attention_mask,
|
1315 |
+
inputs_embeds=inputs_embeds,
|
1316 |
+
use_cache=use_cache,
|
1317 |
+
output_attentions=output_attentions,
|
1318 |
+
output_hidden_states=output_hidden_states,
|
1319 |
+
return_dict=return_dict,
|
1320 |
+
)
|
1321 |
+
|
1322 |
+
hidden_states = model_outputs[0]
|
1323 |
+
hidden_states = self.dropout(hidden_states)
|
1324 |
+
logits = self.classifier(hidden_states)
|
1325 |
+
|
1326 |
+
loss = None
|
1327 |
+
if labels is not None:
|
1328 |
+
# move labels to correct device to enable model parallelism
|
1329 |
+
labels = labels.to(logits.device)
|
1330 |
+
batch_size, seq_length = labels.shape
|
1331 |
+
loss_fct = CrossEntropyLoss()
|
1332 |
+
loss = loss_fct(
|
1333 |
+
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
|
1334 |
+
)
|
1335 |
+
|
1336 |
+
if not return_dict:
|
1337 |
+
output = (logits,) + model_outputs[2:]
|
1338 |
+
return ((loss,) + output) if loss is not None else output
|
1339 |
+
|
1340 |
+
return TokenClassifierOutput(
|
1341 |
+
loss=loss,
|
1342 |
+
logits=logits,
|
1343 |
+
hidden_states=model_outputs.hidden_states,
|
1344 |
+
attentions=model_outputs.attentions,
|
1345 |
+
)
|
tinychart/model/language_model/phi/utils.py
ADDED
@@ -0,0 +1,1428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Import utilities: Utilities related to imports and our lazy inits.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import importlib.metadata
|
19 |
+
import importlib.util
|
20 |
+
import json
|
21 |
+
import os
|
22 |
+
import shutil
|
23 |
+
import subprocess
|
24 |
+
import sys
|
25 |
+
import warnings
|
26 |
+
from collections import OrderedDict
|
27 |
+
from functools import lru_cache, wraps
|
28 |
+
from itertools import chain
|
29 |
+
from types import ModuleType
|
30 |
+
from typing import Any, Tuple, Union
|
31 |
+
|
32 |
+
from packaging import version
|
33 |
+
|
34 |
+
from transformers import logging
|
35 |
+
|
36 |
+
|
37 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
38 |
+
|
39 |
+
|
40 |
+
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
|
41 |
+
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
|
42 |
+
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
|
43 |
+
package_exists = importlib.util.find_spec(pkg_name) is not None
|
44 |
+
package_version = "N/A"
|
45 |
+
if package_exists:
|
46 |
+
try:
|
47 |
+
package_version = importlib.metadata.version(pkg_name)
|
48 |
+
package_exists = True
|
49 |
+
except importlib.metadata.PackageNotFoundError:
|
50 |
+
package_exists = False
|
51 |
+
logger.debug(f"Detected {pkg_name} version {package_version}")
|
52 |
+
if return_version:
|
53 |
+
return package_exists, package_version
|
54 |
+
else:
|
55 |
+
return package_exists
|
56 |
+
|
57 |
+
|
58 |
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
59 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
60 |
+
|
61 |
+
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
62 |
+
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
63 |
+
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
64 |
+
|
65 |
+
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
66 |
+
|
67 |
+
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
68 |
+
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
69 |
+
|
70 |
+
ACCELERATE_MIN_VERSION = "0.21.0"
|
71 |
+
FSDP_MIN_VERSION = "1.12.0"
|
72 |
+
|
73 |
+
|
74 |
+
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
75 |
+
_apex_available = _is_package_available("apex")
|
76 |
+
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
77 |
+
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
78 |
+
_bs4_available = importlib.util.find_spec("bs4") is not None
|
79 |
+
_coloredlogs_available = _is_package_available("coloredlogs")
|
80 |
+
# `importlib.metadata.util` doesn't work with `opencv-python-headless`.
|
81 |
+
_cv2_available = importlib.util.find_spec("cv2") is not None
|
82 |
+
_datasets_available = _is_package_available("datasets")
|
83 |
+
_decord_available = importlib.util.find_spec("decord") is not None
|
84 |
+
_detectron2_available = _is_package_available("detectron2")
|
85 |
+
# We need to check both `faiss` and `faiss-cpu`.
|
86 |
+
_faiss_available = importlib.util.find_spec("faiss") is not None
|
87 |
+
try:
|
88 |
+
_faiss_version = importlib.metadata.version("faiss")
|
89 |
+
logger.debug(f"Successfully imported faiss version {_faiss_version}")
|
90 |
+
except importlib.metadata.PackageNotFoundError:
|
91 |
+
try:
|
92 |
+
_faiss_version = importlib.metadata.version("faiss-cpu")
|
93 |
+
logger.debug(f"Successfully imported faiss version {_faiss_version}")
|
94 |
+
except importlib.metadata.PackageNotFoundError:
|
95 |
+
_faiss_available = False
|
96 |
+
_ftfy_available = _is_package_available("ftfy")
|
97 |
+
_g2p_en_available = _is_package_available("g2p_en")
|
98 |
+
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
|
99 |
+
_jieba_available = _is_package_available("jieba")
|
100 |
+
_jinja_available = _is_package_available("jinja2")
|
101 |
+
_kenlm_available = _is_package_available("kenlm")
|
102 |
+
_keras_nlp_available = _is_package_available("keras_nlp")
|
103 |
+
_levenshtein_available = _is_package_available("Levenshtein")
|
104 |
+
_librosa_available = _is_package_available("librosa")
|
105 |
+
_natten_available = _is_package_available("natten")
|
106 |
+
_nltk_available = _is_package_available("nltk")
|
107 |
+
_onnx_available = _is_package_available("onnx")
|
108 |
+
_openai_available = _is_package_available("openai")
|
109 |
+
_optimum_available = _is_package_available("optimum")
|
110 |
+
_auto_gptq_available = _is_package_available("auto_gptq")
|
111 |
+
# `importlib.metadata.version` doesn't work with `awq`
|
112 |
+
_auto_awq_available = importlib.util.find_spec("awq") is not None
|
113 |
+
_pandas_available = _is_package_available("pandas")
|
114 |
+
_peft_available = _is_package_available("peft")
|
115 |
+
_phonemizer_available = _is_package_available("phonemizer")
|
116 |
+
_psutil_available = _is_package_available("psutil")
|
117 |
+
_py3nvml_available = _is_package_available("py3nvml")
|
118 |
+
_pyctcdecode_available = _is_package_available("pyctcdecode")
|
119 |
+
_pytesseract_available = _is_package_available("pytesseract")
|
120 |
+
_pytest_available = _is_package_available("pytest")
|
121 |
+
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
|
122 |
+
_rjieba_available = _is_package_available("rjieba")
|
123 |
+
_sacremoses_available = _is_package_available("sacremoses")
|
124 |
+
_safetensors_available = _is_package_available("safetensors")
|
125 |
+
_scipy_available = _is_package_available("scipy")
|
126 |
+
_sentencepiece_available = _is_package_available("sentencepiece")
|
127 |
+
_is_seqio_available = _is_package_available("seqio")
|
128 |
+
_sklearn_available = importlib.util.find_spec("sklearn") is not None
|
129 |
+
if _sklearn_available:
|
130 |
+
try:
|
131 |
+
importlib.metadata.version("scikit-learn")
|
132 |
+
except importlib.metadata.PackageNotFoundError:
|
133 |
+
_sklearn_available = False
|
134 |
+
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
|
135 |
+
_soundfile_available = _is_package_available("soundfile")
|
136 |
+
_spacy_available = _is_package_available("spacy")
|
137 |
+
_sudachipy_available = _is_package_available("sudachipy")
|
138 |
+
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
|
139 |
+
_tensorflow_text_available = _is_package_available("tensorflow_text")
|
140 |
+
_tf2onnx_available = _is_package_available("tf2onnx")
|
141 |
+
_timm_available = _is_package_available("timm")
|
142 |
+
_tokenizers_available = _is_package_available("tokenizers")
|
143 |
+
_torchaudio_available = _is_package_available("torchaudio")
|
144 |
+
_torchdistx_available = _is_package_available("torchdistx")
|
145 |
+
_torchvision_available = _is_package_available("torchvision")
|
146 |
+
|
147 |
+
|
148 |
+
_torch_version = "N/A"
|
149 |
+
_torch_available = False
|
150 |
+
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
151 |
+
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
|
152 |
+
else:
|
153 |
+
logger.info("Disabling PyTorch because USE_TF is set")
|
154 |
+
_torch_available = False
|
155 |
+
|
156 |
+
|
157 |
+
_tf_version = "N/A"
|
158 |
+
_tf_available = False
|
159 |
+
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
|
160 |
+
_tf_available = True
|
161 |
+
else:
|
162 |
+
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
163 |
+
# Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
|
164 |
+
# with tensorflow-cpu to make sure it still works!
|
165 |
+
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
166 |
+
if _tf_available:
|
167 |
+
candidates = (
|
168 |
+
"tensorflow",
|
169 |
+
"tensorflow-cpu",
|
170 |
+
"tensorflow-gpu",
|
171 |
+
"tf-nightly",
|
172 |
+
"tf-nightly-cpu",
|
173 |
+
"tf-nightly-gpu",
|
174 |
+
"tf-nightly-rocm",
|
175 |
+
"intel-tensorflow",
|
176 |
+
"intel-tensorflow-avx512",
|
177 |
+
"tensorflow-rocm",
|
178 |
+
"tensorflow-macos",
|
179 |
+
"tensorflow-aarch64",
|
180 |
+
)
|
181 |
+
_tf_version = None
|
182 |
+
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
183 |
+
for pkg in candidates:
|
184 |
+
try:
|
185 |
+
_tf_version = importlib.metadata.version(pkg)
|
186 |
+
break
|
187 |
+
except importlib.metadata.PackageNotFoundError:
|
188 |
+
pass
|
189 |
+
_tf_available = _tf_version is not None
|
190 |
+
if _tf_available:
|
191 |
+
if version.parse(_tf_version) < version.parse("2"):
|
192 |
+
logger.info(
|
193 |
+
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
|
194 |
+
)
|
195 |
+
_tf_available = False
|
196 |
+
else:
|
197 |
+
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
198 |
+
|
199 |
+
|
200 |
+
_essentia_available = importlib.util.find_spec("essentia") is not None
|
201 |
+
try:
|
202 |
+
_essentia_version = importlib.metadata.version("essentia")
|
203 |
+
logger.debug(f"Successfully imported essentia version {_essentia_version}")
|
204 |
+
except importlib.metadata.PackageNotFoundError:
|
205 |
+
_essentia_version = False
|
206 |
+
|
207 |
+
|
208 |
+
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
|
209 |
+
try:
|
210 |
+
_pretty_midi_version = importlib.metadata.version("pretty_midi")
|
211 |
+
logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
|
212 |
+
except importlib.metadata.PackageNotFoundError:
|
213 |
+
_pretty_midi_available = False
|
214 |
+
|
215 |
+
|
216 |
+
ccl_version = "N/A"
|
217 |
+
_is_ccl_available = (
|
218 |
+
importlib.util.find_spec("torch_ccl") is not None
|
219 |
+
or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
|
220 |
+
)
|
221 |
+
try:
|
222 |
+
ccl_version = importlib.metadata.version("oneccl_bind_pt")
|
223 |
+
logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
|
224 |
+
except importlib.metadata.PackageNotFoundError:
|
225 |
+
_is_ccl_available = False
|
226 |
+
|
227 |
+
|
228 |
+
_flax_available = False
|
229 |
+
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
230 |
+
_flax_available, _flax_version = _is_package_available("flax", return_version=True)
|
231 |
+
if _flax_available:
|
232 |
+
_jax_available, _jax_version = _is_package_available("jax", return_version=True)
|
233 |
+
if _jax_available:
|
234 |
+
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
235 |
+
else:
|
236 |
+
_flax_available = _jax_available = False
|
237 |
+
_jax_version = _flax_version = "N/A"
|
238 |
+
|
239 |
+
|
240 |
+
_torch_fx_available = False
|
241 |
+
if _torch_available:
|
242 |
+
torch_version = version.parse(_torch_version)
|
243 |
+
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
|
244 |
+
TORCH_FX_REQUIRED_VERSION.major,
|
245 |
+
TORCH_FX_REQUIRED_VERSION.minor,
|
246 |
+
)
|
247 |
+
|
248 |
+
|
249 |
+
def is_kenlm_available():
|
250 |
+
return _kenlm_available
|
251 |
+
|
252 |
+
|
253 |
+
def is_cv2_available():
|
254 |
+
return _cv2_available
|
255 |
+
|
256 |
+
|
257 |
+
def is_torch_available():
|
258 |
+
return _torch_available
|
259 |
+
|
260 |
+
|
261 |
+
def get_torch_version():
|
262 |
+
return _torch_version
|
263 |
+
|
264 |
+
|
265 |
+
def is_torch_sdpa_available():
|
266 |
+
if not is_torch_available():
|
267 |
+
return False
|
268 |
+
elif _torch_version == "N/A":
|
269 |
+
return False
|
270 |
+
|
271 |
+
# NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
|
272 |
+
# - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
|
273 |
+
# - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
|
274 |
+
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
|
275 |
+
return version.parse(_torch_version) >= version.parse("2.1.1")
|
276 |
+
|
277 |
+
|
278 |
+
def is_torchvision_available():
|
279 |
+
return _torchvision_available
|
280 |
+
|
281 |
+
|
282 |
+
def is_pyctcdecode_available():
|
283 |
+
return _pyctcdecode_available
|
284 |
+
|
285 |
+
|
286 |
+
def is_librosa_available():
|
287 |
+
return _librosa_available
|
288 |
+
|
289 |
+
|
290 |
+
def is_essentia_available():
|
291 |
+
return _essentia_available
|
292 |
+
|
293 |
+
|
294 |
+
def is_pretty_midi_available():
|
295 |
+
return _pretty_midi_available
|
296 |
+
|
297 |
+
|
298 |
+
def is_torch_cuda_available():
|
299 |
+
if is_torch_available():
|
300 |
+
import torch
|
301 |
+
|
302 |
+
return torch.cuda.is_available()
|
303 |
+
else:
|
304 |
+
return False
|
305 |
+
|
306 |
+
|
307 |
+
def is_torch_mps_available():
|
308 |
+
if is_torch_available():
|
309 |
+
import torch
|
310 |
+
|
311 |
+
if hasattr(torch.backends, "mps"):
|
312 |
+
return torch.backends.mps.is_available()
|
313 |
+
return False
|
314 |
+
|
315 |
+
|
316 |
+
def is_torch_bf16_gpu_available():
|
317 |
+
if not is_torch_available():
|
318 |
+
return False
|
319 |
+
|
320 |
+
import torch
|
321 |
+
|
322 |
+
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
323 |
+
|
324 |
+
|
325 |
+
def is_torch_bf16_cpu_available():
|
326 |
+
if not is_torch_available():
|
327 |
+
return False
|
328 |
+
|
329 |
+
import torch
|
330 |
+
|
331 |
+
try:
|
332 |
+
# multiple levels of AttributeError depending on the pytorch version so do them all in one check
|
333 |
+
_ = torch.cpu.amp.autocast
|
334 |
+
except AttributeError:
|
335 |
+
return False
|
336 |
+
|
337 |
+
return True
|
338 |
+
|
339 |
+
|
340 |
+
def is_torch_bf16_available():
|
341 |
+
# the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
|
342 |
+
# has become ambiguous and therefore deprecated
|
343 |
+
warnings.warn(
|
344 |
+
"The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
|
345 |
+
"or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
|
346 |
+
FutureWarning,
|
347 |
+
)
|
348 |
+
return is_torch_bf16_gpu_available()
|
349 |
+
|
350 |
+
|
351 |
+
@lru_cache()
|
352 |
+
def is_torch_fp16_available_on_device(device):
|
353 |
+
if not is_torch_available():
|
354 |
+
return False
|
355 |
+
|
356 |
+
import torch
|
357 |
+
|
358 |
+
try:
|
359 |
+
x = torch.zeros(2, 2, dtype=torch.float16).to(device)
|
360 |
+
_ = x @ x
|
361 |
+
except: # noqa: E722
|
362 |
+
# TODO: more precise exception matching, if possible.
|
363 |
+
# most backends should return `RuntimeError` however this is not guaranteed.
|
364 |
+
return False
|
365 |
+
|
366 |
+
return True
|
367 |
+
|
368 |
+
|
369 |
+
@lru_cache()
|
370 |
+
def is_torch_bf16_available_on_device(device):
|
371 |
+
if not is_torch_available():
|
372 |
+
return False
|
373 |
+
|
374 |
+
import torch
|
375 |
+
|
376 |
+
if device == "cuda":
|
377 |
+
return is_torch_bf16_gpu_available()
|
378 |
+
|
379 |
+
try:
|
380 |
+
x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
|
381 |
+
_ = x @ x
|
382 |
+
except: # noqa: E722
|
383 |
+
# TODO: more precise exception matching, if possible.
|
384 |
+
# most backends should return `RuntimeError` however this is not guaranteed.
|
385 |
+
return False
|
386 |
+
|
387 |
+
return True
|
388 |
+
|
389 |
+
|
390 |
+
def is_torch_tf32_available():
|
391 |
+
if not is_torch_available():
|
392 |
+
return False
|
393 |
+
|
394 |
+
import torch
|
395 |
+
|
396 |
+
if not torch.cuda.is_available() or torch.version.cuda is None:
|
397 |
+
return False
|
398 |
+
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
399 |
+
return False
|
400 |
+
if int(torch.version.cuda.split(".")[0]) < 11:
|
401 |
+
return False
|
402 |
+
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
|
403 |
+
return False
|
404 |
+
|
405 |
+
return True
|
406 |
+
|
407 |
+
|
408 |
+
def is_torch_fx_available():
|
409 |
+
return _torch_fx_available
|
410 |
+
|
411 |
+
|
412 |
+
def is_peft_available():
|
413 |
+
return _peft_available
|
414 |
+
|
415 |
+
|
416 |
+
def is_bs4_available():
|
417 |
+
return _bs4_available
|
418 |
+
|
419 |
+
|
420 |
+
def is_tf_available():
|
421 |
+
return _tf_available
|
422 |
+
|
423 |
+
|
424 |
+
def is_coloredlogs_available():
|
425 |
+
return _coloredlogs_available
|
426 |
+
|
427 |
+
|
428 |
+
def is_tf2onnx_available():
|
429 |
+
return _tf2onnx_available
|
430 |
+
|
431 |
+
|
432 |
+
def is_onnx_available():
|
433 |
+
return _onnx_available
|
434 |
+
|
435 |
+
|
436 |
+
def is_openai_available():
|
437 |
+
return _openai_available
|
438 |
+
|
439 |
+
|
440 |
+
def is_flax_available():
|
441 |
+
return _flax_available
|
442 |
+
|
443 |
+
|
444 |
+
def is_ftfy_available():
|
445 |
+
return _ftfy_available
|
446 |
+
|
447 |
+
|
448 |
+
def is_g2p_en_available():
|
449 |
+
return _g2p_en_available
|
450 |
+
|
451 |
+
|
452 |
+
@lru_cache()
|
453 |
+
def is_torch_tpu_available(check_device=True):
|
454 |
+
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
455 |
+
if not _torch_available:
|
456 |
+
return False
|
457 |
+
if importlib.util.find_spec("torch_xla") is not None:
|
458 |
+
if check_device:
|
459 |
+
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
|
460 |
+
try:
|
461 |
+
import torch_xla.core.xla_model as xm
|
462 |
+
|
463 |
+
_ = xm.xla_device()
|
464 |
+
return True
|
465 |
+
except RuntimeError:
|
466 |
+
return False
|
467 |
+
return True
|
468 |
+
return False
|
469 |
+
|
470 |
+
|
471 |
+
@lru_cache()
|
472 |
+
def is_torch_neuroncore_available(check_device=True):
|
473 |
+
if importlib.util.find_spec("torch_neuronx") is not None:
|
474 |
+
return is_torch_tpu_available(check_device)
|
475 |
+
return False
|
476 |
+
|
477 |
+
|
478 |
+
@lru_cache()
|
479 |
+
def is_torch_npu_available(check_device=False):
|
480 |
+
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
|
481 |
+
if not _torch_available or importlib.util.find_spec("torch_npu") is None:
|
482 |
+
return False
|
483 |
+
|
484 |
+
import torch
|
485 |
+
import torch_npu # noqa: F401
|
486 |
+
|
487 |
+
if check_device:
|
488 |
+
try:
|
489 |
+
# Will raise a RuntimeError if no NPU is found
|
490 |
+
_ = torch.npu.device_count()
|
491 |
+
return torch.npu.is_available()
|
492 |
+
except RuntimeError:
|
493 |
+
return False
|
494 |
+
return hasattr(torch, "npu") and torch.npu.is_available()
|
495 |
+
|
496 |
+
|
497 |
+
def is_torchdynamo_available():
|
498 |
+
if not is_torch_available():
|
499 |
+
return False
|
500 |
+
try:
|
501 |
+
import torch._dynamo as dynamo # noqa: F401
|
502 |
+
|
503 |
+
return True
|
504 |
+
except Exception:
|
505 |
+
return False
|
506 |
+
|
507 |
+
|
508 |
+
def is_torch_compile_available():
|
509 |
+
if not is_torch_available():
|
510 |
+
return False
|
511 |
+
|
512 |
+
import torch
|
513 |
+
|
514 |
+
# We don't do any version check here to support nighlies marked as 1.14. Ultimately needs to check version against
|
515 |
+
# 2.0 but let's do it later.
|
516 |
+
return hasattr(torch, "compile")
|
517 |
+
|
518 |
+
|
519 |
+
def is_torchdynamo_compiling():
|
520 |
+
if not is_torch_available():
|
521 |
+
return False
|
522 |
+
try:
|
523 |
+
import torch._dynamo as dynamo # noqa: F401
|
524 |
+
|
525 |
+
return dynamo.is_compiling()
|
526 |
+
except Exception:
|
527 |
+
return False
|
528 |
+
|
529 |
+
|
530 |
+
def is_torch_tensorrt_fx_available():
|
531 |
+
if importlib.util.find_spec("torch_tensorrt") is None:
|
532 |
+
return False
|
533 |
+
return importlib.util.find_spec("torch_tensorrt.fx") is not None
|
534 |
+
|
535 |
+
|
536 |
+
def is_datasets_available():
|
537 |
+
return _datasets_available
|
538 |
+
|
539 |
+
|
540 |
+
def is_detectron2_available():
|
541 |
+
return _detectron2_available
|
542 |
+
|
543 |
+
|
544 |
+
def is_rjieba_available():
|
545 |
+
return _rjieba_available
|
546 |
+
|
547 |
+
|
548 |
+
def is_psutil_available():
|
549 |
+
return _psutil_available
|
550 |
+
|
551 |
+
|
552 |
+
def is_py3nvml_available():
|
553 |
+
return _py3nvml_available
|
554 |
+
|
555 |
+
|
556 |
+
def is_sacremoses_available():
|
557 |
+
return _sacremoses_available
|
558 |
+
|
559 |
+
|
560 |
+
def is_apex_available():
|
561 |
+
return _apex_available
|
562 |
+
|
563 |
+
|
564 |
+
def is_ninja_available():
|
565 |
+
r"""
|
566 |
+
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
567 |
+
[ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
|
568 |
+
"""
|
569 |
+
try:
|
570 |
+
subprocess.check_output("ninja --version".split())
|
571 |
+
except Exception:
|
572 |
+
return False
|
573 |
+
else:
|
574 |
+
return True
|
575 |
+
|
576 |
+
|
577 |
+
def is_ipex_available():
|
578 |
+
def get_major_and_minor_from_version(full_version):
|
579 |
+
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
580 |
+
|
581 |
+
if not is_torch_available() or not _ipex_available:
|
582 |
+
return False
|
583 |
+
|
584 |
+
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
585 |
+
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
586 |
+
if torch_major_and_minor != ipex_major_and_minor:
|
587 |
+
logger.warning(
|
588 |
+
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
|
589 |
+
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
|
590 |
+
)
|
591 |
+
return False
|
592 |
+
return True
|
593 |
+
|
594 |
+
|
595 |
+
@lru_cache
|
596 |
+
def is_torch_xpu_available(check_device=False):
|
597 |
+
"Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
|
598 |
+
if not is_ipex_available():
|
599 |
+
return False
|
600 |
+
|
601 |
+
import intel_extension_for_pytorch # noqa: F401
|
602 |
+
import torch
|
603 |
+
|
604 |
+
if check_device:
|
605 |
+
try:
|
606 |
+
# Will raise a RuntimeError if no XPU is found
|
607 |
+
_ = torch.xpu.device_count()
|
608 |
+
return torch.xpu.is_available()
|
609 |
+
except RuntimeError:
|
610 |
+
return False
|
611 |
+
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
612 |
+
|
613 |
+
|
614 |
+
def is_bitsandbytes_available():
|
615 |
+
if not is_torch_available():
|
616 |
+
return False
|
617 |
+
|
618 |
+
# bitsandbytes throws an error if cuda is not available
|
619 |
+
# let's avoid that by adding a simple check
|
620 |
+
import torch
|
621 |
+
|
622 |
+
return _bitsandbytes_available and torch.cuda.is_available()
|
623 |
+
|
624 |
+
|
625 |
+
def is_flash_attn_2_available():
|
626 |
+
if not is_torch_available():
|
627 |
+
return False
|
628 |
+
|
629 |
+
if not _is_package_available("flash_attn"):
|
630 |
+
return False
|
631 |
+
|
632 |
+
# Let's add an extra check to see if cuda is available
|
633 |
+
import torch
|
634 |
+
|
635 |
+
if not torch.cuda.is_available():
|
636 |
+
return False
|
637 |
+
|
638 |
+
if torch.version.cuda:
|
639 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
|
640 |
+
elif torch.version.hip:
|
641 |
+
# TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
|
642 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
|
643 |
+
else:
|
644 |
+
return False
|
645 |
+
|
646 |
+
|
647 |
+
def is_flash_attn_greater_or_equal_2_10():
|
648 |
+
if not _is_package_available("flash_attn"):
|
649 |
+
return False
|
650 |
+
|
651 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
|
652 |
+
|
653 |
+
|
654 |
+
def is_flash_attn_available():
|
655 |
+
logger.warning(
|
656 |
+
"Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
|
657 |
+
"Please use `is_flash_attn_2_available` instead."
|
658 |
+
)
|
659 |
+
return is_flash_attn_2_available()
|
660 |
+
|
661 |
+
|
662 |
+
def is_torchdistx_available():
|
663 |
+
return _torchdistx_available
|
664 |
+
|
665 |
+
|
666 |
+
def is_faiss_available():
|
667 |
+
return _faiss_available
|
668 |
+
|
669 |
+
|
670 |
+
def is_scipy_available():
|
671 |
+
return _scipy_available
|
672 |
+
|
673 |
+
|
674 |
+
def is_sklearn_available():
|
675 |
+
return _sklearn_available
|
676 |
+
|
677 |
+
|
678 |
+
def is_sentencepiece_available():
|
679 |
+
return _sentencepiece_available
|
680 |
+
|
681 |
+
|
682 |
+
def is_seqio_available():
|
683 |
+
return _is_seqio_available
|
684 |
+
|
685 |
+
|
686 |
+
def is_protobuf_available():
|
687 |
+
if importlib.util.find_spec("google") is None:
|
688 |
+
return False
|
689 |
+
return importlib.util.find_spec("google.protobuf") is not None
|
690 |
+
|
691 |
+
|
692 |
+
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
|
693 |
+
if min_version is not None:
|
694 |
+
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
|
695 |
+
return _accelerate_available
|
696 |
+
|
697 |
+
|
698 |
+
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
|
699 |
+
return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)
|
700 |
+
|
701 |
+
|
702 |
+
def is_optimum_available():
|
703 |
+
return _optimum_available
|
704 |
+
|
705 |
+
|
706 |
+
def is_auto_awq_available():
|
707 |
+
return _auto_awq_available
|
708 |
+
|
709 |
+
|
710 |
+
def is_auto_gptq_available():
|
711 |
+
return _auto_gptq_available
|
712 |
+
|
713 |
+
|
714 |
+
def is_levenshtein_available():
|
715 |
+
return _levenshtein_available
|
716 |
+
|
717 |
+
|
718 |
+
def is_optimum_neuron_available():
|
719 |
+
return _optimum_available and _is_package_available("optimum.neuron")
|
720 |
+
|
721 |
+
|
722 |
+
def is_safetensors_available():
|
723 |
+
return _safetensors_available
|
724 |
+
|
725 |
+
|
726 |
+
def is_tokenizers_available():
|
727 |
+
return _tokenizers_available
|
728 |
+
|
729 |
+
|
730 |
+
def is_vision_available():
|
731 |
+
_pil_available = importlib.util.find_spec("PIL") is not None
|
732 |
+
if _pil_available:
|
733 |
+
try:
|
734 |
+
package_version = importlib.metadata.version("Pillow")
|
735 |
+
except importlib.metadata.PackageNotFoundError:
|
736 |
+
try:
|
737 |
+
package_version = importlib.metadata.version("Pillow-SIMD")
|
738 |
+
except importlib.metadata.PackageNotFoundError:
|
739 |
+
return False
|
740 |
+
logger.debug(f"Detected PIL version {package_version}")
|
741 |
+
return _pil_available
|
742 |
+
|
743 |
+
|
744 |
+
def is_pytesseract_available():
|
745 |
+
return _pytesseract_available
|
746 |
+
|
747 |
+
|
748 |
+
def is_pytest_available():
|
749 |
+
return _pytest_available
|
750 |
+
|
751 |
+
|
752 |
+
def is_spacy_available():
|
753 |
+
return _spacy_available
|
754 |
+
|
755 |
+
|
756 |
+
def is_tensorflow_text_available():
|
757 |
+
return is_tf_available() and _tensorflow_text_available
|
758 |
+
|
759 |
+
|
760 |
+
def is_keras_nlp_available():
|
761 |
+
return is_tensorflow_text_available() and _keras_nlp_available
|
762 |
+
|
763 |
+
|
764 |
+
def is_in_notebook():
|
765 |
+
try:
|
766 |
+
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
|
767 |
+
get_ipython = sys.modules["IPython"].get_ipython
|
768 |
+
if "IPKernelApp" not in get_ipython().config:
|
769 |
+
raise ImportError("console")
|
770 |
+
if "VSCODE_PID" in os.environ:
|
771 |
+
raise ImportError("vscode")
|
772 |
+
if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
|
773 |
+
# Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook
|
774 |
+
# https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
|
775 |
+
raise ImportError("databricks")
|
776 |
+
|
777 |
+
return importlib.util.find_spec("IPython") is not None
|
778 |
+
except (AttributeError, ImportError, KeyError):
|
779 |
+
return False
|
780 |
+
|
781 |
+
|
782 |
+
def is_pytorch_quantization_available():
|
783 |
+
return _pytorch_quantization_available
|
784 |
+
|
785 |
+
|
786 |
+
def is_tensorflow_probability_available():
|
787 |
+
return _tensorflow_probability_available
|
788 |
+
|
789 |
+
|
790 |
+
def is_pandas_available():
|
791 |
+
return _pandas_available
|
792 |
+
|
793 |
+
|
794 |
+
def is_sagemaker_dp_enabled():
|
795 |
+
# Get the sagemaker specific env variable.
|
796 |
+
sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
|
797 |
+
try:
|
798 |
+
# Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
|
799 |
+
sagemaker_params = json.loads(sagemaker_params)
|
800 |
+
if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
|
801 |
+
return False
|
802 |
+
except json.JSONDecodeError:
|
803 |
+
return False
|
804 |
+
# Lastly, check if the `smdistributed` module is present.
|
805 |
+
return _smdistributed_available
|
806 |
+
|
807 |
+
|
808 |
+
def is_sagemaker_mp_enabled():
|
809 |
+
# Get the sagemaker specific mp parameters from smp_options variable.
|
810 |
+
smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
|
811 |
+
try:
|
812 |
+
# Parse it and check the field "partitions" is included, it is required for model parallel.
|
813 |
+
smp_options = json.loads(smp_options)
|
814 |
+
if "partitions" not in smp_options:
|
815 |
+
return False
|
816 |
+
except json.JSONDecodeError:
|
817 |
+
return False
|
818 |
+
|
819 |
+
# Get the sagemaker specific framework parameters from mpi_options variable.
|
820 |
+
mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
|
821 |
+
try:
|
822 |
+
# Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
|
823 |
+
mpi_options = json.loads(mpi_options)
|
824 |
+
if not mpi_options.get("sagemaker_mpi_enabled", False):
|
825 |
+
return False
|
826 |
+
except json.JSONDecodeError:
|
827 |
+
return False
|
828 |
+
# Lastly, check if the `smdistributed` module is present.
|
829 |
+
return _smdistributed_available
|
830 |
+
|
831 |
+
|
832 |
+
def is_training_run_on_sagemaker():
|
833 |
+
return "SAGEMAKER_JOB_NAME" in os.environ
|
834 |
+
|
835 |
+
|
836 |
+
def is_soundfile_availble():
|
837 |
+
return _soundfile_available
|
838 |
+
|
839 |
+
|
840 |
+
def is_timm_available():
|
841 |
+
return _timm_available
|
842 |
+
|
843 |
+
|
844 |
+
def is_natten_available():
|
845 |
+
return _natten_available
|
846 |
+
|
847 |
+
|
848 |
+
def is_nltk_available():
|
849 |
+
return _nltk_available
|
850 |
+
|
851 |
+
|
852 |
+
def is_torchaudio_available():
|
853 |
+
return _torchaudio_available
|
854 |
+
|
855 |
+
|
856 |
+
def is_speech_available():
|
857 |
+
# For now this depends on torchaudio but the exact dependency might evolve in the future.
|
858 |
+
return _torchaudio_available
|
859 |
+
|
860 |
+
|
861 |
+
def is_phonemizer_available():
|
862 |
+
return _phonemizer_available
|
863 |
+
|
864 |
+
|
865 |
+
def torch_only_method(fn):
|
866 |
+
def wrapper(*args, **kwargs):
|
867 |
+
if not _torch_available:
|
868 |
+
raise ImportError(
|
869 |
+
"You need to install pytorch to use this method or class, "
|
870 |
+
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
|
871 |
+
)
|
872 |
+
else:
|
873 |
+
return fn(*args, **kwargs)
|
874 |
+
|
875 |
+
return wrapper
|
876 |
+
|
877 |
+
|
878 |
+
def is_ccl_available():
|
879 |
+
return _is_ccl_available
|
880 |
+
|
881 |
+
|
882 |
+
def is_decord_available():
|
883 |
+
return _decord_available
|
884 |
+
|
885 |
+
|
886 |
+
def is_sudachi_available():
|
887 |
+
return _sudachipy_available
|
888 |
+
|
889 |
+
|
890 |
+
def is_jumanpp_available():
|
891 |
+
return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None)
|
892 |
+
|
893 |
+
|
894 |
+
def is_cython_available():
|
895 |
+
return importlib.util.find_spec("pyximport") is not None
|
896 |
+
|
897 |
+
|
898 |
+
def is_jieba_available():
|
899 |
+
return _jieba_available
|
900 |
+
|
901 |
+
|
902 |
+
def is_jinja_available():
|
903 |
+
return _jinja_available
|
904 |
+
|
905 |
+
|
906 |
+
# docstyle-ignore
|
907 |
+
CV2_IMPORT_ERROR = """
|
908 |
+
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
909 |
+
```
|
910 |
+
pip install opencv-python
|
911 |
+
```
|
912 |
+
Please note that you may need to restart your runtime after installation.
|
913 |
+
"""
|
914 |
+
|
915 |
+
|
916 |
+
# docstyle-ignore
|
917 |
+
DATASETS_IMPORT_ERROR = """
|
918 |
+
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
|
919 |
+
```
|
920 |
+
pip install datasets
|
921 |
+
```
|
922 |
+
In a notebook or a colab, you can install it by executing a cell with
|
923 |
+
```
|
924 |
+
!pip install datasets
|
925 |
+
```
|
926 |
+
then restarting your kernel.
|
927 |
+
|
928 |
+
Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
|
929 |
+
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
|
930 |
+
that python file if that's the case. Please note that you may need to restart your runtime after installation.
|
931 |
+
"""
|
932 |
+
|
933 |
+
|
934 |
+
# docstyle-ignore
|
935 |
+
TOKENIZERS_IMPORT_ERROR = """
|
936 |
+
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
|
937 |
+
```
|
938 |
+
pip install tokenizers
|
939 |
+
```
|
940 |
+
In a notebook or a colab, you can install it by executing a cell with
|
941 |
+
```
|
942 |
+
!pip install tokenizers
|
943 |
+
```
|
944 |
+
Please note that you may need to restart your runtime after installation.
|
945 |
+
"""
|
946 |
+
|
947 |
+
|
948 |
+
# docstyle-ignore
|
949 |
+
SENTENCEPIECE_IMPORT_ERROR = """
|
950 |
+
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
|
951 |
+
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
|
952 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
953 |
+
"""
|
954 |
+
|
955 |
+
|
956 |
+
# docstyle-ignore
|
957 |
+
PROTOBUF_IMPORT_ERROR = """
|
958 |
+
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
|
959 |
+
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
|
960 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
961 |
+
"""
|
962 |
+
|
963 |
+
|
964 |
+
# docstyle-ignore
|
965 |
+
FAISS_IMPORT_ERROR = """
|
966 |
+
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
|
967 |
+
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
|
968 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
969 |
+
"""
|
970 |
+
|
971 |
+
|
972 |
+
# docstyle-ignore
|
973 |
+
PYTORCH_IMPORT_ERROR = """
|
974 |
+
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
|
975 |
+
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
976 |
+
Please note that you may need to restart your runtime after installation.
|
977 |
+
"""
|
978 |
+
|
979 |
+
|
980 |
+
# docstyle-ignore
|
981 |
+
TORCHVISION_IMPORT_ERROR = """
|
982 |
+
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
|
983 |
+
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
984 |
+
Please note that you may need to restart your runtime after installation.
|
985 |
+
"""
|
986 |
+
|
987 |
+
# docstyle-ignore
|
988 |
+
PYTORCH_IMPORT_ERROR_WITH_TF = """
|
989 |
+
{0} requires the PyTorch library but it was not found in your environment.
|
990 |
+
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
991 |
+
with "TF", but are otherwise identically named to our PyTorch classes. This
|
992 |
+
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
993 |
+
If you want to use TensorFlow, please use TF classes instead!
|
994 |
+
|
995 |
+
If you really do want to use PyTorch please go to
|
996 |
+
https://pytorch.org/get-started/locally/ and follow the instructions that
|
997 |
+
match your environment.
|
998 |
+
"""
|
999 |
+
|
1000 |
+
# docstyle-ignore
|
1001 |
+
TF_IMPORT_ERROR_WITH_PYTORCH = """
|
1002 |
+
{0} requires the TensorFlow library but it was not found in your environment.
|
1003 |
+
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
1004 |
+
with "TF", but are otherwise identically named to our TF classes.
|
1005 |
+
If you want to use PyTorch, please use those classes instead!
|
1006 |
+
|
1007 |
+
If you really do want to use TensorFlow, please follow the instructions on the
|
1008 |
+
installation page https://www.tensorflow.org/install that match your environment.
|
1009 |
+
"""
|
1010 |
+
|
1011 |
+
# docstyle-ignore
|
1012 |
+
BS4_IMPORT_ERROR = """
|
1013 |
+
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
|
1014 |
+
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
|
1018 |
+
# docstyle-ignore
|
1019 |
+
SKLEARN_IMPORT_ERROR = """
|
1020 |
+
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
|
1021 |
+
```
|
1022 |
+
pip install -U scikit-learn
|
1023 |
+
```
|
1024 |
+
In a notebook or a colab, you can install it by executing a cell with
|
1025 |
+
```
|
1026 |
+
!pip install -U scikit-learn
|
1027 |
+
```
|
1028 |
+
Please note that you may need to restart your runtime after installation.
|
1029 |
+
"""
|
1030 |
+
|
1031 |
+
|
1032 |
+
# docstyle-ignore
|
1033 |
+
TENSORFLOW_IMPORT_ERROR = """
|
1034 |
+
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
|
1035 |
+
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
|
1036 |
+
Please note that you may need to restart your runtime after installation.
|
1037 |
+
"""
|
1038 |
+
|
1039 |
+
|
1040 |
+
# docstyle-ignore
|
1041 |
+
DETECTRON2_IMPORT_ERROR = """
|
1042 |
+
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the
|
1043 |
+
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
|
1044 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
1045 |
+
"""
|
1046 |
+
|
1047 |
+
|
1048 |
+
# docstyle-ignore
|
1049 |
+
FLAX_IMPORT_ERROR = """
|
1050 |
+
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
1051 |
+
installation page: https://github.com/google/flax and follow the ones that match your environment.
|
1052 |
+
Please note that you may need to restart your runtime after installation.
|
1053 |
+
"""
|
1054 |
+
|
1055 |
+
# docstyle-ignore
|
1056 |
+
FTFY_IMPORT_ERROR = """
|
1057 |
+
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
|
1058 |
+
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
|
1059 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
1060 |
+
"""
|
1061 |
+
|
1062 |
+
LEVENSHTEIN_IMPORT_ERROR = """
|
1063 |
+
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
|
1064 |
+
install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
|
1065 |
+
"""
|
1066 |
+
|
1067 |
+
# docstyle-ignore
|
1068 |
+
G2P_EN_IMPORT_ERROR = """
|
1069 |
+
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
|
1070 |
+
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
|
1071 |
+
"""
|
1072 |
+
|
1073 |
+
# docstyle-ignore
|
1074 |
+
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
|
1075 |
+
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
|
1076 |
+
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
|
1077 |
+
Please note that you may need to restart your runtime after installation.
|
1078 |
+
"""
|
1079 |
+
|
1080 |
+
# docstyle-ignore
|
1081 |
+
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
|
1082 |
+
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
|
1083 |
+
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
|
1084 |
+
"""
|
1085 |
+
|
1086 |
+
# docstyle-ignore
|
1087 |
+
TENSORFLOW_TEXT_IMPORT_ERROR = """
|
1088 |
+
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
|
1089 |
+
explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
|
1090 |
+
Please note that you may need to restart your runtime after installation.
|
1091 |
+
"""
|
1092 |
+
|
1093 |
+
|
1094 |
+
# docstyle-ignore
|
1095 |
+
PANDAS_IMPORT_ERROR = """
|
1096 |
+
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
|
1097 |
+
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
|
1098 |
+
Please note that you may need to restart your runtime after installation.
|
1099 |
+
"""
|
1100 |
+
|
1101 |
+
|
1102 |
+
# docstyle-ignore
|
1103 |
+
PHONEMIZER_IMPORT_ERROR = """
|
1104 |
+
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
|
1105 |
+
`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
|
1106 |
+
"""
|
1107 |
+
|
1108 |
+
|
1109 |
+
# docstyle-ignore
|
1110 |
+
SACREMOSES_IMPORT_ERROR = """
|
1111 |
+
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
|
1112 |
+
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
|
1113 |
+
"""
|
1114 |
+
|
1115 |
+
# docstyle-ignore
|
1116 |
+
SCIPY_IMPORT_ERROR = """
|
1117 |
+
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
|
1118 |
+
`pip install scipy`. Please note that you may need to restart your runtime after installation.
|
1119 |
+
"""
|
1120 |
+
|
1121 |
+
|
1122 |
+
# docstyle-ignore
|
1123 |
+
SPEECH_IMPORT_ERROR = """
|
1124 |
+
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
|
1125 |
+
`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
|
1126 |
+
"""
|
1127 |
+
|
1128 |
+
# docstyle-ignore
|
1129 |
+
TIMM_IMPORT_ERROR = """
|
1130 |
+
{0} requires the timm library but it was not found in your environment. You can install it with pip:
|
1131 |
+
`pip install timm`. Please note that you may need to restart your runtime after installation.
|
1132 |
+
"""
|
1133 |
+
|
1134 |
+
# docstyle-ignore
|
1135 |
+
NATTEN_IMPORT_ERROR = """
|
1136 |
+
{0} requires the natten library but it was not found in your environment. You can install it by referring to:
|
1137 |
+
shi-labs.com/natten . You can also install it with pip (may take longer to build):
|
1138 |
+
`pip install natten`. Please note that you may need to restart your runtime after installation.
|
1139 |
+
"""
|
1140 |
+
|
1141 |
+
|
1142 |
+
# docstyle-ignore
|
1143 |
+
NLTK_IMPORT_ERROR = """
|
1144 |
+
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
|
1145 |
+
https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
|
1146 |
+
"""
|
1147 |
+
|
1148 |
+
|
1149 |
+
# docstyle-ignore
|
1150 |
+
VISION_IMPORT_ERROR = """
|
1151 |
+
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
|
1152 |
+
`pip install pillow`. Please note that you may need to restart your runtime after installation.
|
1153 |
+
"""
|
1154 |
+
|
1155 |
+
|
1156 |
+
# docstyle-ignore
|
1157 |
+
PYTESSERACT_IMPORT_ERROR = """
|
1158 |
+
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
|
1159 |
+
`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
|
1160 |
+
"""
|
1161 |
+
|
1162 |
+
# docstyle-ignore
|
1163 |
+
PYCTCDECODE_IMPORT_ERROR = """
|
1164 |
+
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
|
1165 |
+
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
|
1166 |
+
"""
|
1167 |
+
|
1168 |
+
# docstyle-ignore
|
1169 |
+
ACCELERATE_IMPORT_ERROR = """
|
1170 |
+
{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment.
|
1171 |
+
You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
|
1172 |
+
runtime after installation.
|
1173 |
+
"""
|
1174 |
+
|
1175 |
+
# docstyle-ignore
|
1176 |
+
CCL_IMPORT_ERROR = """
|
1177 |
+
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
|
1178 |
+
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
|
1179 |
+
Please note that you may need to restart your runtime after installation.
|
1180 |
+
"""
|
1181 |
+
|
1182 |
+
# docstyle-ignore
|
1183 |
+
ESSENTIA_IMPORT_ERROR = """
|
1184 |
+
{0} requires essentia library. But that was not found in your environment. You can install them with pip:
|
1185 |
+
`pip install essentia==2.1b6.dev1034`
|
1186 |
+
Please note that you may need to restart your runtime after installation.
|
1187 |
+
"""
|
1188 |
+
|
1189 |
+
# docstyle-ignore
|
1190 |
+
LIBROSA_IMPORT_ERROR = """
|
1191 |
+
{0} requires thes librosa library. But that was not found in your environment. You can install them with pip:
|
1192 |
+
`pip install librosa`
|
1193 |
+
Please note that you may need to restart your runtime after installation.
|
1194 |
+
"""
|
1195 |
+
|
1196 |
+
# docstyle-ignore
|
1197 |
+
PRETTY_MIDI_IMPORT_ERROR = """
|
1198 |
+
{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip:
|
1199 |
+
`pip install pretty_midi`
|
1200 |
+
Please note that you may need to restart your runtime after installation.
|
1201 |
+
"""
|
1202 |
+
|
1203 |
+
DECORD_IMPORT_ERROR = """
|
1204 |
+
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
|
1205 |
+
decord`. Please note that you may need to restart your runtime after installation.
|
1206 |
+
"""
|
1207 |
+
|
1208 |
+
CYTHON_IMPORT_ERROR = """
|
1209 |
+
{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
|
1210 |
+
Cython`. Please note that you may need to restart your runtime after installation.
|
1211 |
+
"""
|
1212 |
+
|
1213 |
+
JIEBA_IMPORT_ERROR = """
|
1214 |
+
{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
|
1215 |
+
jieba`. Please note that you may need to restart your runtime after installation.
|
1216 |
+
"""
|
1217 |
+
|
1218 |
+
PEFT_IMPORT_ERROR = """
|
1219 |
+
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
|
1220 |
+
peft`. Please note that you may need to restart your runtime after installation.
|
1221 |
+
"""
|
1222 |
+
|
1223 |
+
JINJA_IMPORT_ERROR = """
|
1224 |
+
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
|
1225 |
+
jinja2`. Please note that you may need to restart your runtime after installation.
|
1226 |
+
"""
|
1227 |
+
|
1228 |
+
BACKENDS_MAPPING = OrderedDict(
|
1229 |
+
[
|
1230 |
+
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
1231 |
+
("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
|
1232 |
+
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
|
1233 |
+
("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
|
1234 |
+
("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
|
1235 |
+
("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
|
1236 |
+
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
|
1237 |
+
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
|
1238 |
+
("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
|
1239 |
+
("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
|
1240 |
+
("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
|
1241 |
+
("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
|
1242 |
+
("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
|
1243 |
+
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
|
1244 |
+
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
|
1245 |
+
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
|
1246 |
+
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
|
1247 |
+
("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
|
1248 |
+
("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
|
1249 |
+
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
1250 |
+
("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
|
1251 |
+
("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
|
1252 |
+
("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
|
1253 |
+
("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
|
1254 |
+
("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
|
1255 |
+
("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
|
1256 |
+
("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
|
1257 |
+
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
|
1258 |
+
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
|
1259 |
+
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
1260 |
+
("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
|
1261 |
+
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
|
1262 |
+
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
1263 |
+
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
|
1264 |
+
("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
|
1265 |
+
("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
|
1266 |
+
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
|
1267 |
+
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
|
1268 |
+
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
1269 |
+
("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
|
1270 |
+
]
|
1271 |
+
)
|
1272 |
+
|
1273 |
+
|
1274 |
+
def requires_backends(obj, backends):
|
1275 |
+
if not isinstance(backends, (list, tuple)):
|
1276 |
+
backends = [backends]
|
1277 |
+
|
1278 |
+
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
1279 |
+
|
1280 |
+
# Raise an error for users who might not realize that classes without "TF" are torch-only
|
1281 |
+
if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
|
1282 |
+
raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
1283 |
+
|
1284 |
+
# Raise the inverse error for PyTorch users trying to load TF classes
|
1285 |
+
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
|
1286 |
+
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
1287 |
+
|
1288 |
+
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
1289 |
+
failed = [msg.format(name) for available, msg in checks if not available()]
|
1290 |
+
if failed:
|
1291 |
+
raise ImportError("".join(failed))
|
1292 |
+
|
1293 |
+
|
1294 |
+
class DummyObject(type):
|
1295 |
+
"""
|
1296 |
+
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
|
1297 |
+
`requires_backend` each time a user tries to access any method of that class.
|
1298 |
+
"""
|
1299 |
+
|
1300 |
+
def __getattribute__(cls, key):
|
1301 |
+
if key.startswith("_") and key != "_from_config":
|
1302 |
+
return super().__getattribute__(key)
|
1303 |
+
requires_backends(cls, cls._backends)
|
1304 |
+
|
1305 |
+
|
1306 |
+
def torch_required(func):
|
1307 |
+
warnings.warn(
|
1308 |
+
"The method `torch_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
|
1309 |
+
FutureWarning,
|
1310 |
+
)
|
1311 |
+
|
1312 |
+
# Chose a different decorator name than in tests so it's clear they are not the same.
|
1313 |
+
@wraps(func)
|
1314 |
+
def wrapper(*args, **kwargs):
|
1315 |
+
if is_torch_available():
|
1316 |
+
return func(*args, **kwargs)
|
1317 |
+
else:
|
1318 |
+
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
|
1319 |
+
|
1320 |
+
return wrapper
|
1321 |
+
|
1322 |
+
|
1323 |
+
def tf_required(func):
|
1324 |
+
warnings.warn(
|
1325 |
+
"The method `tf_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
|
1326 |
+
FutureWarning,
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
# Chose a different decorator name than in tests so it's clear they are not the same.
|
1330 |
+
@wraps(func)
|
1331 |
+
def wrapper(*args, **kwargs):
|
1332 |
+
if is_tf_available():
|
1333 |
+
return func(*args, **kwargs)
|
1334 |
+
else:
|
1335 |
+
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
1336 |
+
|
1337 |
+
return wrapper
|
1338 |
+
|
1339 |
+
|
1340 |
+
def is_torch_fx_proxy(x):
|
1341 |
+
if is_torch_fx_available():
|
1342 |
+
import torch.fx
|
1343 |
+
|
1344 |
+
return isinstance(x, torch.fx.Proxy)
|
1345 |
+
return False
|
1346 |
+
|
1347 |
+
|
1348 |
+
class _LazyModule(ModuleType):
|
1349 |
+
"""
|
1350 |
+
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
1351 |
+
"""
|
1352 |
+
|
1353 |
+
# Very heavily inspired by optuna.integration._IntegrationModule
|
1354 |
+
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
1355 |
+
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
|
1356 |
+
super().__init__(name)
|
1357 |
+
self._modules = set(import_structure.keys())
|
1358 |
+
self._class_to_module = {}
|
1359 |
+
for key, values in import_structure.items():
|
1360 |
+
for value in values:
|
1361 |
+
self._class_to_module[value] = key
|
1362 |
+
# Needed for autocompletion in an IDE
|
1363 |
+
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
|
1364 |
+
self.__file__ = module_file
|
1365 |
+
self.__spec__ = module_spec
|
1366 |
+
self.__path__ = [os.path.dirname(module_file)]
|
1367 |
+
self._objects = {} if extra_objects is None else extra_objects
|
1368 |
+
self._name = name
|
1369 |
+
self._import_structure = import_structure
|
1370 |
+
|
1371 |
+
# Needed for autocompletion in an IDE
|
1372 |
+
def __dir__(self):
|
1373 |
+
result = super().__dir__()
|
1374 |
+
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
1375 |
+
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
1376 |
+
for attr in self.__all__:
|
1377 |
+
if attr not in result:
|
1378 |
+
result.append(attr)
|
1379 |
+
return result
|
1380 |
+
|
1381 |
+
def __getattr__(self, name: str) -> Any:
|
1382 |
+
if name in self._objects:
|
1383 |
+
return self._objects[name]
|
1384 |
+
if name in self._modules:
|
1385 |
+
value = self._get_module(name)
|
1386 |
+
elif name in self._class_to_module.keys():
|
1387 |
+
module = self._get_module(self._class_to_module[name])
|
1388 |
+
value = getattr(module, name)
|
1389 |
+
else:
|
1390 |
+
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
1391 |
+
|
1392 |
+
setattr(self, name, value)
|
1393 |
+
return value
|
1394 |
+
|
1395 |
+
def _get_module(self, module_name: str):
|
1396 |
+
try:
|
1397 |
+
return importlib.import_module("." + module_name, self.__name__)
|
1398 |
+
except Exception as e:
|
1399 |
+
raise RuntimeError(
|
1400 |
+
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
|
1401 |
+
f" traceback):\n{e}"
|
1402 |
+
) from e
|
1403 |
+
|
1404 |
+
def __reduce__(self):
|
1405 |
+
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
1406 |
+
|
1407 |
+
|
1408 |
+
class OptionalDependencyNotAvailable(BaseException):
|
1409 |
+
"""Internally used error class for signalling an optional dependency was not found."""
|
1410 |
+
|
1411 |
+
|
1412 |
+
def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
|
1413 |
+
"""Imports transformers directly
|
1414 |
+
|
1415 |
+
Args:
|
1416 |
+
path (`str`): The path to the source file
|
1417 |
+
file (`str`, optional): The file to join with the path. Defaults to "__init__.py".
|
1418 |
+
|
1419 |
+
Returns:
|
1420 |
+
`ModuleType`: The resulting imported module
|
1421 |
+
"""
|
1422 |
+
name = "transformers"
|
1423 |
+
location = os.path.join(path, file)
|
1424 |
+
spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
|
1425 |
+
module = importlib.util.module_from_spec(spec)
|
1426 |
+
spec.loader.exec_module(module)
|
1427 |
+
module = sys.modules[name]
|
1428 |
+
return module
|
tinychart/model/llava_arch.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Baichuan Zhou , Junlong Jia, Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from tinychart.model.multimodal_encoder.builder import build_vision_tower
|
22 |
+
from tinychart.model.multimodal_projector.builder import build_vision_projector
|
23 |
+
|
24 |
+
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
25 |
+
|
26 |
+
|
27 |
+
class LlavaMetaModel:
|
28 |
+
|
29 |
+
def __init__(self, config):
|
30 |
+
super(LlavaMetaModel, self).__init__(config)
|
31 |
+
|
32 |
+
if hasattr(config, "mm_vision_tower"):
|
33 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
34 |
+
self.mm_projector = build_vision_projector(config)
|
35 |
+
|
36 |
+
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
|
37 |
+
self.image_newline = nn.Parameter(
|
38 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
39 |
+
)
|
40 |
+
|
41 |
+
def get_vision_tower(self):
|
42 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
43 |
+
if type(vision_tower) is list:
|
44 |
+
vision_tower = vision_tower[0]
|
45 |
+
return vision_tower
|
46 |
+
|
47 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
48 |
+
vision_tower = model_args.vision_tower
|
49 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
50 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
51 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
52 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
53 |
+
|
54 |
+
self.config.mm_vision_tower = vision_tower
|
55 |
+
|
56 |
+
if self.get_vision_tower() is None:
|
57 |
+
vision_tower = build_vision_tower(model_args)
|
58 |
+
|
59 |
+
if fsdp is not None and len(fsdp) > 0:
|
60 |
+
self.vision_tower = [vision_tower]
|
61 |
+
else:
|
62 |
+
self.vision_tower = vision_tower
|
63 |
+
|
64 |
+
elif self.get_vision_tower().vision_tower_name != vision_tower:
|
65 |
+
print(f"rebuilding vision tower! vision tower initialized from: {vision_tower}")
|
66 |
+
vision_tower = build_vision_tower(model_args)
|
67 |
+
if fsdp is not None and len(fsdp) > 0:
|
68 |
+
self.vision_tower = [vision_tower]
|
69 |
+
else:
|
70 |
+
self.vision_tower = vision_tower
|
71 |
+
|
72 |
+
else:
|
73 |
+
if fsdp is not None and len(fsdp) > 0:
|
74 |
+
vision_tower = self.vision_tower[0]
|
75 |
+
else:
|
76 |
+
vision_tower = self.vision_tower
|
77 |
+
vision_tower.load_model()
|
78 |
+
|
79 |
+
self.config.use_mm_proj = True
|
80 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
81 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
82 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
83 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
84 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
85 |
+
|
86 |
+
if getattr(self, 'mm_projector', None) is None:
|
87 |
+
self.mm_projector = build_vision_projector(self.config)
|
88 |
+
|
89 |
+
if 'unpad' in mm_patch_merge_type:
|
90 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
91 |
+
self.image_newline = nn.Parameter(
|
92 |
+
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
# In case it is frozen by LoRA
|
96 |
+
for p in self.mm_projector.parameters():
|
97 |
+
p.requires_grad = True
|
98 |
+
|
99 |
+
if pretrain_mm_mlp_adapter is not None:
|
100 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
101 |
+
def get_w(weights, keyword):
|
102 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
103 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
104 |
+
|
105 |
+
|
106 |
+
def unpad_image(tensor, original_size):
|
107 |
+
"""
|
108 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
112 |
+
original_size (tuple): The original size of the image (height, width).
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
torch.Tensor: The unpadded image tensor.
|
116 |
+
"""
|
117 |
+
original_width, original_height = original_size
|
118 |
+
current_height, current_width = tensor.shape[1:]
|
119 |
+
|
120 |
+
original_aspect_ratio = original_width / original_height
|
121 |
+
current_aspect_ratio = current_width / current_height
|
122 |
+
|
123 |
+
if original_aspect_ratio > current_aspect_ratio:
|
124 |
+
scale_factor = current_width / original_width
|
125 |
+
new_height = int(original_height * scale_factor)
|
126 |
+
padding = (current_height - new_height) // 2
|
127 |
+
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
128 |
+
else:
|
129 |
+
scale_factor = current_height / original_height
|
130 |
+
new_width = int(original_width * scale_factor)
|
131 |
+
padding = (current_width - new_width) // 2
|
132 |
+
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
133 |
+
|
134 |
+
return unpadded_tensor
|
135 |
+
|
136 |
+
|
137 |
+
class LlavaMetaForCausalLM(ABC):
|
138 |
+
|
139 |
+
@abstractmethod
|
140 |
+
def get_model(self):
|
141 |
+
pass
|
142 |
+
|
143 |
+
def get_vision_tower(self):
|
144 |
+
return self.get_model().get_vision_tower()
|
145 |
+
|
146 |
+
def encode_images(self, images):
|
147 |
+
image_features = self.get_model().get_vision_tower()(images)
|
148 |
+
image_features = self.get_model().mm_projector(image_features)
|
149 |
+
return image_features
|
150 |
+
|
151 |
+
def prepare_inputs_labels_for_multimodal(
|
152 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels,
|
153 |
+
images, image_sizes=None
|
154 |
+
):
|
155 |
+
vision_tower = self.get_vision_tower()
|
156 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
157 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
158 |
+
|
159 |
+
if type(images) is list or images.ndim == 5:
|
160 |
+
if type(images) is list:
|
161 |
+
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
162 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
163 |
+
image_features = self.encode_images(concat_images)
|
164 |
+
split_sizes = [image.shape[0] for image in images]
|
165 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
166 |
+
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
|
167 |
+
image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
|
168 |
+
if mm_patch_merge_type == 'flat':
|
169 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
170 |
+
elif mm_patch_merge_type.startswith('spatial'):
|
171 |
+
new_image_features = []
|
172 |
+
for image_idx, image_feature in enumerate(image_features):
|
173 |
+
if image_feature.shape[0] > 1:
|
174 |
+
base_image_feature = image_feature[0]
|
175 |
+
image_feature = image_feature[1:]
|
176 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
177 |
+
assert height * width == base_image_feature.shape[0]
|
178 |
+
if 'unpad' in mm_patch_merge_type:
|
179 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
180 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
181 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
182 |
+
image_feature = torch.cat((
|
183 |
+
image_feature,
|
184 |
+
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
|
185 |
+
), dim=-1)
|
186 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
187 |
+
else:
|
188 |
+
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
189 |
+
image_feature = image_feature.flatten(0, 3)
|
190 |
+
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
191 |
+
else:
|
192 |
+
image_feature = image_feature[0]
|
193 |
+
if 'unpad' in mm_patch_merge_type:
|
194 |
+
image_feature = torch.cat((
|
195 |
+
image_feature,
|
196 |
+
self.model.image_newline[None].to(image_feature.device)
|
197 |
+
), dim=0)
|
198 |
+
new_image_features.append(image_feature)
|
199 |
+
image_features = new_image_features
|
200 |
+
else:
|
201 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
202 |
+
else:
|
203 |
+
image_features = self.encode_images(images)
|
204 |
+
|
205 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
206 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
207 |
+
raise NotImplementedError
|
208 |
+
|
209 |
+
# Let's just add dummy tensors if they do not exist,
|
210 |
+
# it is a headache to deal with None all the time.
|
211 |
+
# But it is not ideal, and if you have a better idea,
|
212 |
+
# please open an issue / submit a PR, thanks.
|
213 |
+
_labels = labels
|
214 |
+
_position_ids = position_ids
|
215 |
+
_attention_mask = attention_mask
|
216 |
+
if attention_mask is None:
|
217 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
218 |
+
else:
|
219 |
+
attention_mask = attention_mask.bool()
|
220 |
+
if position_ids is None:
|
221 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
222 |
+
if labels is None:
|
223 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
224 |
+
|
225 |
+
# remove the padding using attention_mask -- FIXME
|
226 |
+
_input_ids = input_ids
|
227 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
228 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
229 |
+
|
230 |
+
new_input_embeds = []
|
231 |
+
new_labels = []
|
232 |
+
cur_image_idx = 0
|
233 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
234 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
235 |
+
if num_images == 0:
|
236 |
+
cur_image_features = image_features[cur_image_idx]
|
237 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
238 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
239 |
+
new_input_embeds.append(cur_input_embeds)
|
240 |
+
new_labels.append(labels[batch_idx])
|
241 |
+
cur_image_idx += 1
|
242 |
+
continue
|
243 |
+
|
244 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
245 |
+
cur_input_ids_noim = []
|
246 |
+
cur_labels = labels[batch_idx]
|
247 |
+
cur_labels_noim = []
|
248 |
+
for i in range(len(image_token_indices) - 1):
|
249 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
250 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
251 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
252 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
253 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
254 |
+
cur_new_input_embeds = []
|
255 |
+
cur_new_labels = []
|
256 |
+
|
257 |
+
for i in range(num_images + 1):
|
258 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
259 |
+
cur_new_labels.append(cur_labels_noim[i])
|
260 |
+
if i < num_images:
|
261 |
+
cur_image_features = image_features[cur_image_idx]
|
262 |
+
cur_image_idx += 1
|
263 |
+
cur_new_input_embeds.append(cur_image_features)
|
264 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
265 |
+
|
266 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
267 |
+
|
268 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
269 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
270 |
+
|
271 |
+
new_input_embeds.append(cur_new_input_embeds)
|
272 |
+
new_labels.append(cur_new_labels)
|
273 |
+
|
274 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
275 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
276 |
+
if tokenizer_model_max_length is not None:
|
277 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
278 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
279 |
+
|
280 |
+
# Combine them
|
281 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
282 |
+
batch_size = len(new_input_embeds)
|
283 |
+
|
284 |
+
new_input_embeds_padded = []
|
285 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
286 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
287 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
288 |
+
|
289 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
290 |
+
cur_len = cur_new_embed.shape[0]
|
291 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
292 |
+
new_input_embeds_padded.append(torch.cat((
|
293 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
294 |
+
cur_new_embed
|
295 |
+
), dim=0))
|
296 |
+
if cur_len > 0:
|
297 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
298 |
+
attention_mask[i, -cur_len:] = True
|
299 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
300 |
+
else:
|
301 |
+
new_input_embeds_padded.append(torch.cat((
|
302 |
+
cur_new_embed,
|
303 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
304 |
+
), dim=0))
|
305 |
+
if cur_len > 0:
|
306 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
307 |
+
attention_mask[i, :cur_len] = True
|
308 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
309 |
+
|
310 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
311 |
+
|
312 |
+
if _labels is None:
|
313 |
+
new_labels = None
|
314 |
+
else:
|
315 |
+
new_labels = new_labels_padded
|
316 |
+
|
317 |
+
if _attention_mask is None:
|
318 |
+
attention_mask = None
|
319 |
+
else:
|
320 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
321 |
+
|
322 |
+
if _position_ids is None:
|
323 |
+
position_ids = None
|
324 |
+
|
325 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
326 |
+
|
327 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
328 |
+
if model_args.mm_use_im_patch_token:
|
329 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
330 |
+
self.resize_token_embeddings(len(tokenizer))
|
331 |
+
|
332 |
+
if model_args.mm_use_im_start_end:
|
333 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
334 |
+
self.resize_token_embeddings(len(tokenizer))
|
335 |
+
|
336 |
+
if num_new_tokens > 0:
|
337 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
338 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
339 |
+
|
340 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
341 |
+
dim=0, keepdim=True)
|
342 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
343 |
+
dim=0, keepdim=True)
|
344 |
+
|
345 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
346 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
347 |
+
|
348 |
+
if model_args.tune_mm_mlp_adapter:
|
349 |
+
for p in self.get_input_embeddings().parameters():
|
350 |
+
p.requires_grad = True
|
351 |
+
for p in self.get_output_embeddings().parameters():
|
352 |
+
p.requires_grad = False
|
353 |
+
|
354 |
+
if model_args.pretrain_mm_mlp_adapter:
|
355 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
356 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
357 |
+
assert num_new_tokens == 2
|
358 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
359 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
360 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
361 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
362 |
+
else:
|
363 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
364 |
+
elif model_args.tune_embed_tokens:
|
365 |
+
for p in self.get_input_embeddings().parameters():
|
366 |
+
p.requires_grad = True
|
367 |
+
for p in self.get_output_embeddings().parameters():
|
368 |
+
p.requires_grad = False
|
369 |
+
print("Set input embeddings to trainable")
|
370 |
+
|
371 |
+
elif model_args.mm_use_im_patch_token:
|
372 |
+
if model_args.tune_mm_mlp_adapter:
|
373 |
+
for p in self.get_input_embeddings().parameters():
|
374 |
+
p.requires_grad = False
|
375 |
+
for p in self.get_output_embeddings().parameters():
|
376 |
+
p.requires_grad = False
|
377 |
+
|
378 |
+
if model_args.pretrain_mm_mlp_adapter:
|
379 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
|
380 |
+
if 'model.embed_tokens.weight' in mm_projector_weights.keys():
|
381 |
+
def get_w(weights, keyword):
|
382 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
383 |
+
self.get_model().embed_tokens.load_state_dict(get_w(mm_projector_weights, 'model.embed_tokens'))
|
tinychart/model/model_factory.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
|
4 |
+
MODEL_REGISTRY = {}
|
5 |
+
TOKENIZER_REGISTRY = {}
|
6 |
+
|
7 |
+
|
8 |
+
def ModelSelect(model_name_or_path):
|
9 |
+
model = None
|
10 |
+
for name in MODEL_REGISTRY.keys():
|
11 |
+
if name in model_name_or_path.lower():
|
12 |
+
model = MODEL_REGISTRY[name]
|
13 |
+
if model is None:
|
14 |
+
model = MODEL_REGISTRY['llama']
|
15 |
+
return model
|
16 |
+
|
17 |
+
|
18 |
+
def TokenizerSelect(model_name_or_path):
|
19 |
+
tokenizer_init = None
|
20 |
+
for name in TOKENIZER_REGISTRY.keys():
|
21 |
+
if name in model_name_or_path.lower():
|
22 |
+
tokenizer_init = TOKENIZER_REGISTRY[name]
|
23 |
+
if tokenizer_init is None:
|
24 |
+
tokenizer_init = TOKENIZER_REGISTRY['llama']
|
25 |
+
return tokenizer_init
|
26 |
+
|
27 |
+
|
28 |
+
def register_model(name):
|
29 |
+
def register_model_cls(cls):
|
30 |
+
if name in MODEL_REGISTRY:
|
31 |
+
return MODEL_REGISTRY[name]
|
32 |
+
|
33 |
+
MODEL_REGISTRY[name] = cls
|
34 |
+
return cls
|
35 |
+
|
36 |
+
return register_model_cls
|
37 |
+
|
38 |
+
|
39 |
+
def register_tokenizer(name):
|
40 |
+
def register_tokenizer_cls(cls):
|
41 |
+
if name in TOKENIZER_REGISTRY:
|
42 |
+
return TOKENIZER_REGISTRY[name]
|
43 |
+
|
44 |
+
TOKENIZER_REGISTRY[name] = cls
|
45 |
+
return cls
|
46 |
+
|
47 |
+
return register_tokenizer_cls
|
48 |
+
|
49 |
+
|
50 |
+
def import_models(models_dir, namespace):
|
51 |
+
for file in os.listdir(models_dir):
|
52 |
+
path = os.path.join(models_dir, file)
|
53 |
+
if (
|
54 |
+
not file.startswith("_")
|
55 |
+
and not file.startswith(".")
|
56 |
+
and file.endswith(".py")
|
57 |
+
):
|
58 |
+
model_name = file[: file.find(".py")] if file.endswith(".py") else file
|
59 |
+
importlib.import_module(namespace + "." + model_name)
|
60 |
+
|
61 |
+
|
62 |
+
# automatically import any Python files in the models/ directory
|
63 |
+
models_dir = os.path.join(os.path.dirname(__file__), 'language_model')
|
64 |
+
import_models(models_dir, "tinychart.model.language_model")
|
tinychart/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tinychart.model.multimodal_encoder.siglip_encoder import SigLipVisionTower
|
3 |
+
|
4 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
5 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
6 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
7 |
+
return SigLipVisionTower(vision_tower, vision_tower_cfg, **kwargs)
|
tinychart/model/multimodal_encoder/merge.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
from typing import Callable, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def do_nothing(x, mode=None):
|
15 |
+
return x
|
16 |
+
|
17 |
+
|
18 |
+
def bipartite_soft_matching(
|
19 |
+
metric: torch.Tensor,
|
20 |
+
r: int,
|
21 |
+
class_token: bool = False,
|
22 |
+
distill_token: bool = False,
|
23 |
+
) -> Tuple[Callable, Callable]:
|
24 |
+
"""
|
25 |
+
Applies ToMe with a balanced matching set (50%, 50%).
|
26 |
+
|
27 |
+
Input size is [batch, tokens, channels].
|
28 |
+
r indicates the number of tokens to remove (max 50% of tokens).
|
29 |
+
|
30 |
+
Extra args:
|
31 |
+
- class_token: Whether or not there's a class token.
|
32 |
+
- distill_token: Whether or not there's also a distillation token.
|
33 |
+
|
34 |
+
When enabled, the class token and distillation tokens won't get merged.
|
35 |
+
"""
|
36 |
+
protected = 0
|
37 |
+
if class_token:
|
38 |
+
protected += 1
|
39 |
+
if distill_token:
|
40 |
+
protected += 1
|
41 |
+
|
42 |
+
# We can only reduce by a maximum of 50% tokens
|
43 |
+
t = metric.shape[1]
|
44 |
+
r = min(r, (t - protected) // 2)
|
45 |
+
|
46 |
+
if r <= 0:
|
47 |
+
return do_nothing, do_nothing
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
metric = metric / metric.norm(dim=-1, keepdim=True)
|
51 |
+
a, b = metric[..., ::2, :], metric[..., 1::2, :]
|
52 |
+
scores = a @ b.transpose(-1, -2)
|
53 |
+
|
54 |
+
if class_token:
|
55 |
+
scores[..., 0, :] = -math.inf
|
56 |
+
if distill_token:
|
57 |
+
scores[..., :, 0] = -math.inf
|
58 |
+
|
59 |
+
node_max, node_idx = scores.max(dim=-1)
|
60 |
+
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
61 |
+
|
62 |
+
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
|
63 |
+
src_idx = edge_idx[..., :r, :] # Merged Tokens
|
64 |
+
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
|
65 |
+
|
66 |
+
if class_token:
|
67 |
+
# Sort to ensure the class token is at the start
|
68 |
+
unm_idx = unm_idx.sort(dim=1)[0]
|
69 |
+
|
70 |
+
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
71 |
+
src, dst = x[..., ::2, :], x[..., 1::2, :]
|
72 |
+
n, t1, c = src.shape
|
73 |
+
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
74 |
+
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
|
75 |
+
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
76 |
+
|
77 |
+
if distill_token:
|
78 |
+
return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
|
79 |
+
else:
|
80 |
+
return torch.cat([unm, dst], dim=1)
|
81 |
+
|
82 |
+
def unmerge(x: torch.Tensor) -> torch.Tensor:
|
83 |
+
unm_len = unm_idx.shape[1]
|
84 |
+
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
|
85 |
+
n, _, c = unm.shape
|
86 |
+
|
87 |
+
src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))
|
88 |
+
|
89 |
+
out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)
|
90 |
+
|
91 |
+
out[..., 1::2, :] = dst
|
92 |
+
out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
|
93 |
+
out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
return merge, unmerge
|
98 |
+
|
99 |
+
|
100 |
+
def kth_bipartite_soft_matching(
|
101 |
+
metric: torch.Tensor, k: int
|
102 |
+
) -> Tuple[Callable, Callable]:
|
103 |
+
"""
|
104 |
+
Applies ToMe with the two sets as (every kth element, the rest).
|
105 |
+
If n is the number of tokens, resulting number of tokens will be n // z.
|
106 |
+
|
107 |
+
Input size is [batch, tokens, channels].
|
108 |
+
z indicates the stride for the first set.
|
109 |
+
z = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N
|
110 |
+
"""
|
111 |
+
if k <= 1:
|
112 |
+
return do_nothing, do_nothing
|
113 |
+
|
114 |
+
def split(x):
|
115 |
+
t_rnd = (x.shape[1] // k) * k
|
116 |
+
x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
|
117 |
+
a, b = (
|
118 |
+
x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
|
119 |
+
x[:, :, (k - 1), :],
|
120 |
+
)
|
121 |
+
return a, b
|
122 |
+
|
123 |
+
with torch.no_grad():
|
124 |
+
metric = metric / metric.norm(dim=-1, keepdim=True)
|
125 |
+
a, b = split(metric)
|
126 |
+
r = a.shape[1]
|
127 |
+
scores = a @ b.transpose(-1, -2)
|
128 |
+
|
129 |
+
_, dst_idx = scores.max(dim=-1)
|
130 |
+
dst_idx = dst_idx[..., None]
|
131 |
+
|
132 |
+
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
133 |
+
src, dst = split(x)
|
134 |
+
n, _, c = src.shape
|
135 |
+
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
136 |
+
|
137 |
+
return dst
|
138 |
+
|
139 |
+
def unmerge(x: torch.Tensor) -> torch.Tensor:
|
140 |
+
n, _, c = x.shape
|
141 |
+
dst = x
|
142 |
+
|
143 |
+
src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype)
|
144 |
+
|
145 |
+
src = src.view(n, -1, (k - 1), c)
|
146 |
+
dst = dst.view(n, -1, 1, c)
|
147 |
+
|
148 |
+
out = torch.cat([src, dst], dim=-2)
|
149 |
+
out = out.contiguous().view(n, -1, c)
|
150 |
+
|
151 |
+
return out
|
152 |
+
|
153 |
+
return merge, unmerge
|
154 |
+
|
155 |
+
|
156 |
+
def random_bipartite_soft_matching(
|
157 |
+
metric: torch.Tensor, r: int
|
158 |
+
) -> Tuple[Callable, Callable]:
|
159 |
+
"""
|
160 |
+
Applies ToMe with the two sets as (r chosen randomly, the rest).
|
161 |
+
Input size is [batch, tokens, channels].
|
162 |
+
|
163 |
+
This will reduce the number of tokens by r.
|
164 |
+
"""
|
165 |
+
if r <= 0:
|
166 |
+
return do_nothing, do_nothing
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
B, N, _ = metric.shape
|
170 |
+
rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
|
171 |
+
|
172 |
+
a_idx = rand_idx[:, :r, :]
|
173 |
+
b_idx = rand_idx[:, r:, :]
|
174 |
+
|
175 |
+
def split(x):
|
176 |
+
C = x.shape[-1]
|
177 |
+
a = x.gather(dim=1, index=a_idx.expand(B, r, C))
|
178 |
+
b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
|
179 |
+
return a, b
|
180 |
+
|
181 |
+
metric = metric / metric.norm(dim=-1, keepdim=True)
|
182 |
+
a, b = split(metric)
|
183 |
+
scores = a @ b.transpose(-1, -2)
|
184 |
+
|
185 |
+
_, dst_idx = scores.max(dim=-1)
|
186 |
+
dst_idx = dst_idx[..., None]
|
187 |
+
|
188 |
+
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
189 |
+
src, dst = split(x)
|
190 |
+
C = src.shape[-1]
|
191 |
+
dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)
|
192 |
+
|
193 |
+
return dst
|
194 |
+
|
195 |
+
def unmerge(x: torch.Tensor) -> torch.Tensor:
|
196 |
+
C = x.shape[-1]
|
197 |
+
dst = x
|
198 |
+
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))
|
199 |
+
|
200 |
+
out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)
|
201 |
+
|
202 |
+
out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
|
203 |
+
out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)
|
204 |
+
|
205 |
+
return out
|
206 |
+
|
207 |
+
return merge, unmerge
|
208 |
+
|
209 |
+
|
210 |
+
def merge_wavg(
|
211 |
+
merge: Callable, x: torch.Tensor, size: torch.Tensor = None
|
212 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
213 |
+
"""
|
214 |
+
Applies the merge function by taking a weighted average based on token size.
|
215 |
+
Returns the merged tensor and the new token sizes.
|
216 |
+
"""
|
217 |
+
if size is None:
|
218 |
+
size = torch.ones_like(x[..., 0, None])
|
219 |
+
|
220 |
+
x = merge(x * size, mode="sum")
|
221 |
+
size = merge(size, mode="sum")
|
222 |
+
|
223 |
+
x = x / size
|
224 |
+
return x, size
|
225 |
+
|
226 |
+
|
227 |
+
def merge_source(
|
228 |
+
merge: Callable, x: torch.Tensor, source: torch.Tensor = None
|
229 |
+
) -> torch.Tensor:
|
230 |
+
"""
|
231 |
+
For source tracking. Source is an adjacency matrix between the initial tokens and final merged groups.
|
232 |
+
x is used to find out how many tokens there are in case the source is None.
|
233 |
+
"""
|
234 |
+
if source is None:
|
235 |
+
n, t, _ = x.shape
|
236 |
+
source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)
|
237 |
+
|
238 |
+
source = merge(source, mode="amax")
|
239 |
+
return source
|
tinychart/model/multimodal_encoder/siglip_encoder.py
ADDED
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
|
3 |
+
'''
|
4 |
+
|
5 |
+
from typing import Optional, Tuple, Union, Dict
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from functools import partial, reduce
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from torch import nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import os
|
14 |
+
import numpy as np
|
15 |
+
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
16 |
+
from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
|
17 |
+
from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
|
18 |
+
from transformers.activations import ACT2FN
|
19 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
20 |
+
from transformers.modeling_utils import PreTrainedModel
|
21 |
+
from transformers import PretrainedConfig
|
22 |
+
from transformers.utils import ModelOutput
|
23 |
+
from tinychart.model.multimodal_encoder.merge import bipartite_soft_matching, merge_source, merge_wavg
|
24 |
+
|
25 |
+
|
26 |
+
class SigLipImageProcessor:
|
27 |
+
def __init__(self,
|
28 |
+
image_mean=(0.5, 0.5, 0.5),
|
29 |
+
image_std=(0.5, 0.5, 0.5),
|
30 |
+
size=(384, 384),
|
31 |
+
crop_size: Dict[str, int] = None,
|
32 |
+
resample=PILImageResampling.BICUBIC,
|
33 |
+
rescale_factor=1 / 255,
|
34 |
+
data_format=ChannelDimension.FIRST):
|
35 |
+
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
|
36 |
+
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
37 |
+
|
38 |
+
self.image_mean = image_mean
|
39 |
+
self.image_std = image_std
|
40 |
+
self.size = size
|
41 |
+
self.resample = resample
|
42 |
+
self.rescale_factor = rescale_factor
|
43 |
+
self.data_format = data_format
|
44 |
+
self.crop_size = crop_size
|
45 |
+
|
46 |
+
def preprocess(self, images, return_tensors):
|
47 |
+
if isinstance(images, Image.Image):
|
48 |
+
images = [images]
|
49 |
+
else:
|
50 |
+
assert isinstance(images, list)
|
51 |
+
|
52 |
+
transforms = [
|
53 |
+
convert_to_rgb,
|
54 |
+
to_numpy_array,
|
55 |
+
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
|
56 |
+
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
|
57 |
+
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
|
58 |
+
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
|
59 |
+
]
|
60 |
+
|
61 |
+
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
|
62 |
+
data = {"pixel_values": images}
|
63 |
+
|
64 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
65 |
+
|
66 |
+
|
67 |
+
class SigLipVisionConfig(PretrainedConfig):
|
68 |
+
model_type = "siglip_vision_model"
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
hidden_size=1152,
|
73 |
+
image_mean=(0.5, 0.5, 0.5),
|
74 |
+
intermediate_size=4304,
|
75 |
+
num_hidden_layers=27,
|
76 |
+
num_attention_heads=16,
|
77 |
+
num_channels=3,
|
78 |
+
image_size=384,
|
79 |
+
patch_size=14,
|
80 |
+
hidden_act="gelu_pytorch_tanh",
|
81 |
+
layer_norm_eps=1e-6,
|
82 |
+
attention_dropout=0.0,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
super().__init__(**kwargs)
|
86 |
+
|
87 |
+
self.hidden_size = hidden_size
|
88 |
+
self.intermediate_size = intermediate_size
|
89 |
+
self.num_hidden_layers = num_hidden_layers
|
90 |
+
self.num_attention_heads = num_attention_heads
|
91 |
+
self.num_channels = num_channels
|
92 |
+
self.patch_size = patch_size
|
93 |
+
self.image_size = image_size
|
94 |
+
self.attention_dropout = attention_dropout
|
95 |
+
self.layer_norm_eps = layer_norm_eps
|
96 |
+
self.hidden_act = hidden_act
|
97 |
+
self.image_mean = image_mean
|
98 |
+
|
99 |
+
for key, value in kwargs.items():
|
100 |
+
setattr(self, key, value)
|
101 |
+
|
102 |
+
@classmethod
|
103 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
104 |
+
cls._set_token_in_kwargs(kwargs)
|
105 |
+
|
106 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
107 |
+
|
108 |
+
# get the vision config dict if we are loading from SigLipConfig
|
109 |
+
if config_dict.get("model_type") == "siglip":
|
110 |
+
config_dict = config_dict["vision_config"]
|
111 |
+
|
112 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
113 |
+
logger.warning(
|
114 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
115 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
116 |
+
)
|
117 |
+
|
118 |
+
return cls.from_dict(config_dict, **kwargs)
|
119 |
+
|
120 |
+
|
121 |
+
@dataclass
|
122 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
|
123 |
+
class SigLipVisionModelOutput(ModelOutput):
|
124 |
+
"""
|
125 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
129 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
130 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
131 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
132 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
133 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
134 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
135 |
+
|
136 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
137 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
138 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
139 |
+
sequence_length)`.
|
140 |
+
|
141 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
142 |
+
heads.
|
143 |
+
"""
|
144 |
+
|
145 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
146 |
+
last_hidden_state: torch.FloatTensor = None
|
147 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
148 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
149 |
+
|
150 |
+
|
151 |
+
class SigLipVisionEmbeddings(nn.Module):
|
152 |
+
def __init__(self, config: SigLipVisionConfig):
|
153 |
+
super().__init__()
|
154 |
+
self.config = config
|
155 |
+
self.embed_dim = config.hidden_size
|
156 |
+
self.image_size = config.image_size
|
157 |
+
self.patch_size = config.patch_size
|
158 |
+
|
159 |
+
self.patch_embedding = nn.Conv2d(
|
160 |
+
in_channels=config.num_channels,
|
161 |
+
out_channels=self.embed_dim,
|
162 |
+
kernel_size=self.patch_size,
|
163 |
+
stride=self.patch_size,
|
164 |
+
padding="valid",
|
165 |
+
)
|
166 |
+
|
167 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
168 |
+
self.num_positions = self.num_patches
|
169 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
170 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
171 |
+
|
172 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
173 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
174 |
+
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
175 |
+
|
176 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
177 |
+
return embeddings
|
178 |
+
|
179 |
+
|
180 |
+
class SigLipAttentionToMe(nn.Module):
|
181 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
182 |
+
|
183 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
184 |
+
def __init__(self, config):
|
185 |
+
super().__init__()
|
186 |
+
self.config = config
|
187 |
+
self.embed_dim = config.hidden_size
|
188 |
+
self.num_heads = config.num_attention_heads
|
189 |
+
self.head_dim = self.embed_dim // self.num_heads
|
190 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
191 |
+
raise ValueError(
|
192 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
193 |
+
f" {self.num_heads})."
|
194 |
+
)
|
195 |
+
self.scale = self.head_dim ** -0.5
|
196 |
+
self.dropout = config.attention_dropout
|
197 |
+
|
198 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
199 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
200 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
201 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
202 |
+
|
203 |
+
def forward(
|
204 |
+
self,
|
205 |
+
hidden_states: torch.Tensor,
|
206 |
+
attention_mask: Optional[torch.Tensor] = None,
|
207 |
+
output_attentions: Optional[bool] = False,
|
208 |
+
size: torch.Tensor = None,
|
209 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
210 |
+
"""Input shape: Batch x Time x Channel"""
|
211 |
+
|
212 |
+
batch_size, q_len, _ = hidden_states.size()
|
213 |
+
|
214 |
+
query_states = self.q_proj(hidden_states)
|
215 |
+
key_states = self.k_proj(hidden_states)
|
216 |
+
value_states = self.v_proj(hidden_states)
|
217 |
+
|
218 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
219 |
+
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
220 |
+
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
221 |
+
|
222 |
+
k_v_seq_len = key_states.shape[-2]
|
223 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
224 |
+
|
225 |
+
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
226 |
+
raise ValueError(
|
227 |
+
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
228 |
+
f" {attn_weights.size()}"
|
229 |
+
)
|
230 |
+
|
231 |
+
if attention_mask is not None:
|
232 |
+
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
233 |
+
raise ValueError(
|
234 |
+
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
235 |
+
)
|
236 |
+
attn_weights = attn_weights + attention_mask
|
237 |
+
|
238 |
+
# upcast attention to fp32
|
239 |
+
if size is not None:
|
240 |
+
attn_weights += size.log()[:, None, None, :, 0]
|
241 |
+
|
242 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
243 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
244 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
245 |
+
|
246 |
+
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
|
247 |
+
raise ValueError(
|
248 |
+
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
249 |
+
f" {attn_output.size()}"
|
250 |
+
)
|
251 |
+
|
252 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
253 |
+
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
254 |
+
|
255 |
+
attn_output = self.out_proj(attn_output)
|
256 |
+
return attn_output, attn_weights, key_states.mean(dim=1)
|
257 |
+
|
258 |
+
|
259 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
|
260 |
+
class SigLipMLP(nn.Module):
|
261 |
+
def __init__(self, config):
|
262 |
+
super().__init__()
|
263 |
+
self.config = config
|
264 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
265 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
266 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
267 |
+
|
268 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
269 |
+
hidden_states = self.fc1(hidden_states)
|
270 |
+
hidden_states = self.activation_fn(hidden_states)
|
271 |
+
hidden_states = self.fc2(hidden_states)
|
272 |
+
return hidden_states
|
273 |
+
|
274 |
+
|
275 |
+
class SigLipEncoderLayerToMe(nn.Module):
|
276 |
+
def __init__(self, config: SigLipVisionConfig, layer_id=None):
|
277 |
+
super().__init__()
|
278 |
+
self.embed_dim = config.hidden_size
|
279 |
+
self.self_attn = SigLipAttentionToMe(config)
|
280 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
281 |
+
self.mlp = SigLipMLP(config)
|
282 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
283 |
+
self.r = config.tome_r
|
284 |
+
self.layer_id = layer_id
|
285 |
+
|
286 |
+
# Ignore copy
|
287 |
+
def forward(
|
288 |
+
self,
|
289 |
+
hidden_states: torch.Tensor,
|
290 |
+
attention_mask: torch.Tensor,
|
291 |
+
output_attentions: Optional[bool] = False,
|
292 |
+
attention_size=None,
|
293 |
+
source=None,
|
294 |
+
trace_source=False
|
295 |
+
) -> Tuple[torch.FloatTensor]:
|
296 |
+
"""
|
297 |
+
Args:
|
298 |
+
hidden_states (`torch.FloatTensor`):
|
299 |
+
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
300 |
+
attention_mask (`torch.FloatTensor`):
|
301 |
+
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
302 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
303 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
304 |
+
returned tensors for more detail.
|
305 |
+
"""
|
306 |
+
residual = hidden_states
|
307 |
+
|
308 |
+
hidden_states = self.layer_norm1(hidden_states)
|
309 |
+
hidden_states, attn_weights, metric = self.self_attn(
|
310 |
+
hidden_states=hidden_states,
|
311 |
+
attention_mask=attention_mask,
|
312 |
+
output_attentions=output_attentions,
|
313 |
+
size=attention_size
|
314 |
+
)
|
315 |
+
hidden_states = residual + hidden_states
|
316 |
+
|
317 |
+
if self.r > 0:
|
318 |
+
merge, unmerge = bipartite_soft_matching(
|
319 |
+
metric,
|
320 |
+
r=self.r,
|
321 |
+
class_token=False,
|
322 |
+
distill_token=False
|
323 |
+
)
|
324 |
+
if trace_source:
|
325 |
+
source = merge_source(merge, hidden_states, source)
|
326 |
+
hidden_states, attention_size = merge_wavg(merge, hidden_states, attention_size)
|
327 |
+
residual = hidden_states
|
328 |
+
hidden_states = self.layer_norm2(hidden_states)
|
329 |
+
hidden_states = self.mlp(hidden_states)
|
330 |
+
hidden_states = residual + hidden_states
|
331 |
+
|
332 |
+
outputs = (hidden_states,)
|
333 |
+
|
334 |
+
if output_attentions:
|
335 |
+
outputs += (attn_weights,)
|
336 |
+
|
337 |
+
if trace_source:
|
338 |
+
outputs += (source,)
|
339 |
+
|
340 |
+
outputs += (attention_size,)
|
341 |
+
|
342 |
+
return outputs
|
343 |
+
|
344 |
+
|
345 |
+
class SigLipPreTrainedModel(PreTrainedModel):
|
346 |
+
"""
|
347 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
348 |
+
models.
|
349 |
+
"""
|
350 |
+
|
351 |
+
config_class = SigLipVisionConfig
|
352 |
+
base_model_prefix = "siglip"
|
353 |
+
supports_gradient_checkpointing = True
|
354 |
+
|
355 |
+
def _init_weights(self, module):
|
356 |
+
"""Initialize the weights"""
|
357 |
+
pass
|
358 |
+
|
359 |
+
|
360 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
|
361 |
+
class SigLipEncoder(nn.Module):
|
362 |
+
"""
|
363 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
364 |
+
[`SigLipEncoderLayer`].
|
365 |
+
|
366 |
+
Args:
|
367 |
+
config: SigLipVisionConfig
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(self, config: SigLipVisionConfig):
|
371 |
+
super().__init__()
|
372 |
+
self.config = config
|
373 |
+
self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
374 |
+
self.gradient_checkpointing = False
|
375 |
+
|
376 |
+
# Ignore copy
|
377 |
+
def forward(
|
378 |
+
self,
|
379 |
+
inputs_embeds,
|
380 |
+
attention_mask: Optional[torch.Tensor] = None,
|
381 |
+
output_attentions: Optional[bool] = None,
|
382 |
+
output_hidden_states: Optional[bool] = None,
|
383 |
+
return_dict: Optional[bool] = None,
|
384 |
+
) -> Union[Tuple, BaseModelOutput]:
|
385 |
+
r"""
|
386 |
+
Args:
|
387 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
388 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
389 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
390 |
+
than the model's internal embedding lookup matrix.
|
391 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
392 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
393 |
+
|
394 |
+
- 1 for tokens that are **not masked**,
|
395 |
+
- 0 for tokens that are **masked**.
|
396 |
+
|
397 |
+
[What are attention masks?](../glossary#attention-mask)
|
398 |
+
output_attentions (`bool`, *optional*):
|
399 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
400 |
+
returned tensors for more detail.
|
401 |
+
output_hidden_states (`bool`, *optional*):
|
402 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
403 |
+
for more detail.
|
404 |
+
return_dict (`bool`, *optional*):
|
405 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
406 |
+
"""
|
407 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
408 |
+
output_hidden_states = (
|
409 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
410 |
+
)
|
411 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
412 |
+
|
413 |
+
encoder_states = () if output_hidden_states else None
|
414 |
+
all_attentions = () if output_attentions else None
|
415 |
+
|
416 |
+
hidden_states = inputs_embeds
|
417 |
+
for encoder_layer in self.layers:
|
418 |
+
if output_hidden_states:
|
419 |
+
encoder_states = encoder_states + (hidden_states,)
|
420 |
+
if self.gradient_checkpointing and self.training:
|
421 |
+
layer_outputs = self._gradient_checkpointing_func(
|
422 |
+
encoder_layer.__call__,
|
423 |
+
hidden_states,
|
424 |
+
attention_mask,
|
425 |
+
output_attentions,
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
layer_outputs = encoder_layer(
|
429 |
+
hidden_states,
|
430 |
+
attention_mask,
|
431 |
+
output_attentions=output_attentions,
|
432 |
+
)
|
433 |
+
|
434 |
+
hidden_states = layer_outputs[0]
|
435 |
+
|
436 |
+
if output_attentions:
|
437 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
438 |
+
|
439 |
+
if output_hidden_states:
|
440 |
+
encoder_states = encoder_states + (hidden_states,)
|
441 |
+
|
442 |
+
if not return_dict:
|
443 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
444 |
+
return BaseModelOutput(
|
445 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
446 |
+
)
|
447 |
+
|
448 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
|
449 |
+
class SigLipEncoderToMe(nn.Module):
|
450 |
+
"""
|
451 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
452 |
+
[`SigLipEncoderLayerToMe`].
|
453 |
+
|
454 |
+
Args:
|
455 |
+
config: SigLipVisionConfig
|
456 |
+
"""
|
457 |
+
|
458 |
+
def __init__(self, config: SigLipVisionConfig):
|
459 |
+
super().__init__()
|
460 |
+
self.config = config
|
461 |
+
self.layers = nn.ModuleList([SigLipEncoderLayerToMe(config, layer_id=layer_id) for layer_id in range(config.num_hidden_layers)])
|
462 |
+
self.gradient_checkpointing = False
|
463 |
+
self.trace_source = getattr(config, 'trace_source', False)
|
464 |
+
|
465 |
+
# Ignore copy
|
466 |
+
def forward(
|
467 |
+
self,
|
468 |
+
inputs_embeds,
|
469 |
+
attention_mask: Optional[torch.Tensor] = None,
|
470 |
+
output_attentions: Optional[bool] = None,
|
471 |
+
output_hidden_states: Optional[bool] = None,
|
472 |
+
return_dict: Optional[bool] = None,
|
473 |
+
) -> Union[Tuple, BaseModelOutput]:
|
474 |
+
r"""
|
475 |
+
Args:
|
476 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
477 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
478 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
479 |
+
than the model's internal embedding lookup matrix.
|
480 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
481 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
482 |
+
|
483 |
+
- 1 for tokens that are **not masked**,
|
484 |
+
- 0 for tokens that are **masked**.
|
485 |
+
|
486 |
+
[What are attention masks?](../glossary#attention-mask)
|
487 |
+
output_attentions (`bool`, *optional*):
|
488 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
489 |
+
returned tensors for more detail.
|
490 |
+
output_hidden_states (`bool`, *optional*):
|
491 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
492 |
+
for more detail.
|
493 |
+
return_dict (`bool`, *optional*):
|
494 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
495 |
+
"""
|
496 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
497 |
+
output_hidden_states = (
|
498 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
499 |
+
)
|
500 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
501 |
+
|
502 |
+
encoder_states = () if output_hidden_states else None
|
503 |
+
all_attentions = () if output_attentions else None
|
504 |
+
|
505 |
+
hidden_states = inputs_embeds
|
506 |
+
attention_size = None
|
507 |
+
source = None
|
508 |
+
for encoder_layer in self.layers:
|
509 |
+
if output_hidden_states:
|
510 |
+
encoder_states = encoder_states + (hidden_states,)
|
511 |
+
if self.gradient_checkpointing and self.training:
|
512 |
+
|
513 |
+
layer_outputs = self._gradient_checkpointing_func(
|
514 |
+
encoder_layer.__call__,
|
515 |
+
hidden_states,
|
516 |
+
attention_mask,
|
517 |
+
output_attentions,
|
518 |
+
attention_size,
|
519 |
+
source if self.trace_source else None,
|
520 |
+
self.trace_source
|
521 |
+
)
|
522 |
+
else:
|
523 |
+
layer_outputs = encoder_layer(
|
524 |
+
hidden_states,
|
525 |
+
attention_mask,
|
526 |
+
output_attentions=output_attentions,
|
527 |
+
attention_size=attention_size,
|
528 |
+
source=source if self.trace_source else None,
|
529 |
+
trace_source=self.trace_source
|
530 |
+
)
|
531 |
+
|
532 |
+
hidden_states = layer_outputs[0]
|
533 |
+
if self.trace_source:
|
534 |
+
source = layer_outputs[-2]
|
535 |
+
|
536 |
+
attention_size = layer_outputs[-1]
|
537 |
+
|
538 |
+
if output_attentions:
|
539 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
540 |
+
|
541 |
+
if output_hidden_states:
|
542 |
+
encoder_states = encoder_states + (hidden_states,)
|
543 |
+
|
544 |
+
if not return_dict:
|
545 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
546 |
+
return BaseModelOutput(
|
547 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
class SigLipVisionTransformer(nn.Module):
|
552 |
+
def __init__(self, config: SigLipVisionConfig):
|
553 |
+
super().__init__()
|
554 |
+
self.config = config
|
555 |
+
embed_dim = config.hidden_size
|
556 |
+
|
557 |
+
self.embeddings = SigLipVisionEmbeddings(config)
|
558 |
+
self.encoder = SigLipEncoderToMe(config)
|
559 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
560 |
+
self.head = SigLipMultiheadAttentionPoolingHead(config)
|
561 |
+
|
562 |
+
def forward(
|
563 |
+
self,
|
564 |
+
pixel_values,
|
565 |
+
output_attentions: Optional[bool] = None,
|
566 |
+
output_hidden_states: Optional[bool] = None,
|
567 |
+
return_dict: Optional[bool] = None,
|
568 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
569 |
+
r"""
|
570 |
+
Returns:
|
571 |
+
|
572 |
+
"""
|
573 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
574 |
+
output_hidden_states = (
|
575 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
576 |
+
)
|
577 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
578 |
+
|
579 |
+
hidden_states = self.embeddings(pixel_values)
|
580 |
+
|
581 |
+
encoder_outputs = self.encoder(
|
582 |
+
inputs_embeds=hidden_states,
|
583 |
+
output_attentions=output_attentions,
|
584 |
+
output_hidden_states=output_hidden_states,
|
585 |
+
return_dict=return_dict,
|
586 |
+
)
|
587 |
+
|
588 |
+
last_hidden_state = encoder_outputs[0]
|
589 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
590 |
+
|
591 |
+
pooled_output = self.head(last_hidden_state)
|
592 |
+
|
593 |
+
if not return_dict:
|
594 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
595 |
+
|
596 |
+
return BaseModelOutputWithPooling(
|
597 |
+
last_hidden_state=last_hidden_state,
|
598 |
+
pooler_output=pooled_output,
|
599 |
+
hidden_states=encoder_outputs.hidden_states,
|
600 |
+
attentions=encoder_outputs.attentions,
|
601 |
+
)
|
602 |
+
|
603 |
+
|
604 |
+
class SigLipMultiheadAttentionPoolingHead(nn.Module):
|
605 |
+
"""Multihead Attention Pooling."""
|
606 |
+
|
607 |
+
def __init__(self, config: SigLipVisionConfig):
|
608 |
+
super().__init__()
|
609 |
+
|
610 |
+
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
611 |
+
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
|
612 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
613 |
+
self.mlp = SigLipMLP(config)
|
614 |
+
|
615 |
+
def forward(self, hidden_state):
|
616 |
+
batch_size = hidden_state.shape[0]
|
617 |
+
probe = self.probe.repeat(batch_size, 1, 1)
|
618 |
+
|
619 |
+
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
620 |
+
|
621 |
+
residual = hidden_state
|
622 |
+
hidden_state = self.layernorm(hidden_state)
|
623 |
+
hidden_state = residual + self.mlp(hidden_state)
|
624 |
+
|
625 |
+
return hidden_state[:, 0]
|
626 |
+
|
627 |
+
|
628 |
+
class SigLipVisionModel(SigLipPreTrainedModel):
|
629 |
+
config_class = SigLipVisionConfig
|
630 |
+
main_input_name = "pixel_values"
|
631 |
+
_no_split_modules = ["SigLipEncoderLayerToMe"]
|
632 |
+
|
633 |
+
def __init__(self, config: SigLipVisionConfig):
|
634 |
+
super().__init__(config)
|
635 |
+
|
636 |
+
self.vision_model = SigLipVisionTransformer(config)
|
637 |
+
del self.vision_model.encoder.layers[-1:]
|
638 |
+
self.vision_model.head = nn.Identity()
|
639 |
+
# Initialize weights and apply final processing
|
640 |
+
self.post_init()
|
641 |
+
|
642 |
+
def get_input_embeddings(self) -> nn.Module:
|
643 |
+
return self.vision_model.embeddings.patch_embedding
|
644 |
+
|
645 |
+
def forward(
|
646 |
+
self,
|
647 |
+
pixel_values,
|
648 |
+
output_attentions: Optional[bool] = None,
|
649 |
+
output_hidden_states: Optional[bool] = None,
|
650 |
+
return_dict: Optional[bool] = None,
|
651 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
652 |
+
r"""
|
653 |
+
Returns:
|
654 |
+
|
655 |
+
Examples:
|
656 |
+
|
657 |
+
```python
|
658 |
+
>>> from PIL import Image
|
659 |
+
>>> import requests
|
660 |
+
>>> from transformers import AutoProcessor, SigLipVisionModel
|
661 |
+
|
662 |
+
>>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
|
663 |
+
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
664 |
+
|
665 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
666 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
667 |
+
|
668 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
669 |
+
|
670 |
+
>>> outputs = model(**inputs)
|
671 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
672 |
+
>>> pooled_output = outputs.pooler_output # pooled features
|
673 |
+
```"""
|
674 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
675 |
+
|
676 |
+
return self.vision_model(
|
677 |
+
pixel_values=pixel_values,
|
678 |
+
output_attentions=output_attentions,
|
679 |
+
output_hidden_states=output_hidden_states,
|
680 |
+
return_dict=return_dict,
|
681 |
+
)
|
682 |
+
|
683 |
+
class SigLipVisionTower(nn.Module):
|
684 |
+
def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
|
685 |
+
super().__init__()
|
686 |
+
|
687 |
+
self.is_loaded = False
|
688 |
+
|
689 |
+
if vision_tower is not None:
|
690 |
+
self.config = SigLipVisionConfig.from_pretrained(vision_tower)
|
691 |
+
else:
|
692 |
+
self.config = SigLipVisionConfig()
|
693 |
+
|
694 |
+
self.vision_tower_name = vision_tower
|
695 |
+
|
696 |
+
self.image_processor = SigLipImageProcessor(size=(self.config.image_size, self.config.image_size), image_mean=self.config.image_mean)
|
697 |
+
|
698 |
+
if not delay_load:
|
699 |
+
self.load_model()
|
700 |
+
else:
|
701 |
+
self.cfg_only = self.config
|
702 |
+
|
703 |
+
def load_model(self):
|
704 |
+
if self.is_loaded:
|
705 |
+
return
|
706 |
+
|
707 |
+
self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name)
|
708 |
+
|
709 |
+
self.vision_tower.requires_grad_(False)
|
710 |
+
self.vision_tower.eval()
|
711 |
+
|
712 |
+
self.is_loaded = True
|
713 |
+
|
714 |
+
# @torch.no_grad()
|
715 |
+
def forward(self, images):
|
716 |
+
if type(images) is list:
|
717 |
+
image_features = []
|
718 |
+
for image in images:
|
719 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
720 |
+
output_hidden_states=True)
|
721 |
+
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
722 |
+
|
723 |
+
image_features.append(image_feature)
|
724 |
+
else:
|
725 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
|
726 |
+
output_hidden_states=True)
|
727 |
+
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
728 |
+
|
729 |
+
return image_features
|
730 |
+
|
731 |
+
@property
|
732 |
+
def dummy_feature(self):
|
733 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
734 |
+
|
735 |
+
@property
|
736 |
+
def dtype(self):
|
737 |
+
for p in self.vision_tower.parameters():
|
738 |
+
return p.dtype
|
739 |
+
|
740 |
+
@property
|
741 |
+
def device(self):
|
742 |
+
for p in self.vision_tower.parameters():
|
743 |
+
return p.device
|
744 |
+
|
745 |
+
@property
|
746 |
+
def hidden_size(self):
|
747 |
+
return self.config.hidden_size
|
748 |
+
|
749 |
+
@property
|
750 |
+
def num_patches(self):
|
751 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
tinychart/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import re
|
6 |
+
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
|
10 |
+
class IdentityMap(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def forward(self, x, *args, **kwargs):
|
15 |
+
return x
|
16 |
+
|
17 |
+
@property
|
18 |
+
def config(self):
|
19 |
+
return {"mm_projector_type": 'identity'}
|
20 |
+
|
21 |
+
|
22 |
+
class SimpleResBlock(nn.Module):
|
23 |
+
def __init__(self, channels):
|
24 |
+
super().__init__()
|
25 |
+
self.pre_norm = nn.LayerNorm(channels)
|
26 |
+
|
27 |
+
self.proj = nn.Sequential(
|
28 |
+
nn.Linear(channels, channels),
|
29 |
+
nn.GELU(),
|
30 |
+
nn.Linear(channels, channels)
|
31 |
+
)
|
32 |
+
def forward(self, x):
|
33 |
+
x = self.pre_norm(x)
|
34 |
+
return x + self.proj(x)
|
35 |
+
|
36 |
+
|
37 |
+
class ResamplerBlock(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
hidden_size: int = 768,
|
41 |
+
image_hidden_size: int = 1024,
|
42 |
+
num_heads: int = 12,
|
43 |
+
intermediate_size: int = None
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
assert hidden_size % num_heads == 0, "For MHSA, you must have number of heads divisible by initial hidden size"
|
47 |
+
intermediate_size = hidden_size * 4 if intermediate_size is None else intermediate_size
|
48 |
+
# intermediate_size = hidden_size * 4
|
49 |
+
self.scale = 1 / math.sqrt(hidden_size // num_heads)
|
50 |
+
self.num_heads = num_heads
|
51 |
+
self.to_q = nn.Linear(hidden_size, hidden_size, bias=False)
|
52 |
+
self.to_k = nn.Linear(image_hidden_size, hidden_size, bias=False)
|
53 |
+
self.to_v = nn.Linear(image_hidden_size, hidden_size, bias=False)
|
54 |
+
|
55 |
+
self.to_out = nn.Linear(hidden_size, hidden_size, bias=False)
|
56 |
+
|
57 |
+
self.feed_forward = nn.Sequential(
|
58 |
+
*[
|
59 |
+
nn.LayerNorm(hidden_size),
|
60 |
+
nn.Linear(hidden_size, intermediate_size, bias=False),
|
61 |
+
nn.GELU(),
|
62 |
+
nn.Linear(intermediate_size, hidden_size, bias=False),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
# prenorm for image features
|
66 |
+
self.norm_image = nn.LayerNorm(image_hidden_size)
|
67 |
+
self.norm_hidden = nn.LayerNorm(hidden_size)
|
68 |
+
|
69 |
+
def forward(self, hidden_states: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
70 |
+
# prenorm
|
71 |
+
x = self.norm_image(x)
|
72 |
+
residual_hidden_states = hidden_states
|
73 |
+
hidden_states = self.norm_hidden(hidden_states)
|
74 |
+
# compute Q, K, V
|
75 |
+
queries = self.to_q(hidden_states)
|
76 |
+
keys = self.to_k(x)
|
77 |
+
values = self.to_v(x)
|
78 |
+
# rearrange them into multi-head format
|
79 |
+
queries = rearrange(queries, "b n (h d) -> b h n d", h=self.num_heads)
|
80 |
+
keys = rearrange(keys, "b n (h d) -> b h n d", h=self.num_heads)
|
81 |
+
values = rearrange(values, "b n (h d) -> b h n d", h=self.num_heads)
|
82 |
+
# rescale
|
83 |
+
queries = self.scale * queries
|
84 |
+
# compute QK^T
|
85 |
+
scores = torch.einsum("... i d, ... j d -> ... i j", queries, keys)
|
86 |
+
# for stability
|
87 |
+
scores = scores - scores.amax(dim=-1, keepdim=True).detach()
|
88 |
+
# softmax
|
89 |
+
attention_scores = scores.softmax(dim=-1) # b h i j (i: number of queries, j: number of keys)
|
90 |
+
# dot product with V
|
91 |
+
out = torch.einsum("... i j, ... j d -> ... i d", attention_scores, values)
|
92 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=self.num_heads)
|
93 |
+
out = self.to_out(out) + residual_hidden_states
|
94 |
+
residual_out = out
|
95 |
+
out = self.feed_forward(out)
|
96 |
+
return out + residual_out
|
97 |
+
|
98 |
+
|
99 |
+
class Resampler(nn.Module):
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
hidden_size: int = 768,
|
103 |
+
image_hidden_size: int = 1024,
|
104 |
+
final_hidden_size: int = 4096,
|
105 |
+
num_heads: int = 12,
|
106 |
+
intermediate_size: int = None,
|
107 |
+
num_queries: int = 128,
|
108 |
+
num_layers: int = 3,
|
109 |
+
initializer_range: float = 0.02
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
self.resampler_blocks = nn.ModuleList(
|
113 |
+
[
|
114 |
+
ResamplerBlock(
|
115 |
+
hidden_size, image_hidden_size, num_heads, intermediate_size
|
116 |
+
) for _ in range(num_layers)
|
117 |
+
]
|
118 |
+
)
|
119 |
+
self.queries = nn.Parameter(torch.randn(num_queries, hidden_size))
|
120 |
+
self.post_norm = nn.LayerNorm(hidden_size)
|
121 |
+
|
122 |
+
self.final_proj = nn.Linear(hidden_size, final_hidden_size, bias=False)
|
123 |
+
|
124 |
+
# self.initializer_range = initializer_range
|
125 |
+
# for module in self.modules():
|
126 |
+
# if isinstance(module, (nn.Linear, nn.LayerNorm, nn.Conv2d)):
|
127 |
+
# self._init_weights(module)
|
128 |
+
#
|
129 |
+
# def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
130 |
+
# """Initialize the weights"""
|
131 |
+
# if isinstance(module, (nn.Linear, nn.Conv2d)):
|
132 |
+
# # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
133 |
+
# # `trunc_normal_cpu` not implemented in `half` issues
|
134 |
+
# module.weight.data = nn.init.trunc_normal_(
|
135 |
+
# module.weight.data.to(torch.float32), mean=0.0, std=self.initializer_range
|
136 |
+
# ).to(module.weight.dtype)
|
137 |
+
# if module.bias is not None:
|
138 |
+
# module.bias.data.zero_()
|
139 |
+
# elif isinstance(module, nn.LayerNorm):
|
140 |
+
# module.bias.data.zero_()
|
141 |
+
# module.weight.data.fill_(1.0)
|
142 |
+
|
143 |
+
def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
|
144 |
+
b = image_hidden_states.size(0)
|
145 |
+
queries = repeat(self.queries, 'n d -> b n d', b=b)
|
146 |
+
for resampler_block in self.resampler_blocks:
|
147 |
+
queries = resampler_block(queries, image_hidden_states)
|
148 |
+
|
149 |
+
# post norm
|
150 |
+
queries = self.post_norm(queries)
|
151 |
+
return self.final_proj(queries)
|
152 |
+
|
153 |
+
|
154 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
155 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
156 |
+
|
157 |
+
if projector_type == 'linear':
|
158 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
159 |
+
|
160 |
+
if projector_type == 'resampler':
|
161 |
+
hidden_size = getattr(config, 'resampler_hidden_size', 768)
|
162 |
+
image_hidden_size = config.mm_hidden_size
|
163 |
+
num_queries = getattr(config, 'num_queries', 128)
|
164 |
+
final_hidden_size = config.hidden_size
|
165 |
+
num_heads = 12
|
166 |
+
if hidden_size == 512:
|
167 |
+
num_heads = 8
|
168 |
+
num_layers = getattr(config, 'num_resampler_layers', 3)
|
169 |
+
|
170 |
+
initializer_range = getattr(config, 'initializer_range', 0.02)
|
171 |
+
print(
|
172 |
+
f"resampler config: resampler hidden size: {hidden_size}, num_queries: {num_queries}, "
|
173 |
+
f"num_resampler_layers: {num_layers}"
|
174 |
+
)
|
175 |
+
return Resampler(
|
176 |
+
hidden_size=hidden_size,
|
177 |
+
image_hidden_size=image_hidden_size,
|
178 |
+
num_queries=num_queries,
|
179 |
+
final_hidden_size=final_hidden_size,
|
180 |
+
num_layers=num_layers,
|
181 |
+
num_heads=num_heads,
|
182 |
+
initializer_range=initializer_range
|
183 |
+
)
|
184 |
+
|
185 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
186 |
+
if mlp_gelu_match:
|
187 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
188 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
189 |
+
for _ in range(1, mlp_depth):
|
190 |
+
modules.append(nn.GELU())
|
191 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
192 |
+
mlp = nn.Sequential(*modules)
|
193 |
+
if getattr(config, 'load_moe_mm_projector', False):
|
194 |
+
from deepspeed.moe.layer import MoE
|
195 |
+
mlp = MoE(
|
196 |
+
config.mm_hidden_size,
|
197 |
+
expert=mlp,
|
198 |
+
num_experts=4,
|
199 |
+
ep_size=1,
|
200 |
+
k=2,
|
201 |
+
capacity_factor=1.,
|
202 |
+
eval_capacity_factor=1.,
|
203 |
+
min_capacity=4,
|
204 |
+
use_residual=False,
|
205 |
+
)
|
206 |
+
|
207 |
+
def moe_forward_wrapper(forward_func):
|
208 |
+
return lambda *args, **kwargs: forward_func(*args, **kwargs)[0]
|
209 |
+
mlp.forward = moe_forward_wrapper(mlp.forward)
|
210 |
+
return mlp
|
211 |
+
|
212 |
+
if projector_type == 'identity':
|
213 |
+
return IdentityMap()
|
214 |
+
|
215 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
tinychart/utils.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from tinychart.constants import LOGDIR
|
10 |
+
|
11 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
12 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
13 |
+
|
14 |
+
handler = None
|
15 |
+
|
16 |
+
|
17 |
+
def build_logger(logger_name, logger_filename):
|
18 |
+
global handler
|
19 |
+
|
20 |
+
formatter = logging.Formatter(
|
21 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
)
|
24 |
+
|
25 |
+
# Set the format of root handlers
|
26 |
+
if not logging.getLogger().handlers:
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
29 |
+
|
30 |
+
# Redirect stdout and stderr to loggers
|
31 |
+
stdout_logger = logging.getLogger("stdout")
|
32 |
+
stdout_logger.setLevel(logging.INFO)
|
33 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
34 |
+
sys.stdout = sl
|
35 |
+
|
36 |
+
stderr_logger = logging.getLogger("stderr")
|
37 |
+
stderr_logger.setLevel(logging.ERROR)
|
38 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
39 |
+
sys.stderr = sl
|
40 |
+
|
41 |
+
# Get logger
|
42 |
+
logger = logging.getLogger(logger_name)
|
43 |
+
logger.setLevel(logging.INFO)
|
44 |
+
|
45 |
+
# Add a file handler for all loggers
|
46 |
+
if handler is None:
|
47 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
48 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
49 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
50 |
+
filename, when='D', utc=True, encoding='UTF-8')
|
51 |
+
handler.setFormatter(formatter)
|
52 |
+
|
53 |
+
for name, item in logging.root.manager.loggerDict.items():
|
54 |
+
if isinstance(item, logging.Logger):
|
55 |
+
item.addHandler(handler)
|
56 |
+
|
57 |
+
return logger
|
58 |
+
|
59 |
+
|
60 |
+
class StreamToLogger(object):
|
61 |
+
"""
|
62 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
63 |
+
"""
|
64 |
+
def __init__(self, logger, log_level=logging.INFO):
|
65 |
+
self.terminal = sys.stdout
|
66 |
+
self.logger = logger
|
67 |
+
self.log_level = log_level
|
68 |
+
self.linebuf = ''
|
69 |
+
|
70 |
+
def __getattr__(self, attr):
|
71 |
+
return getattr(self.terminal, attr)
|
72 |
+
|
73 |
+
def write(self, buf):
|
74 |
+
temp_linebuf = self.linebuf + buf
|
75 |
+
self.linebuf = ''
|
76 |
+
for line in temp_linebuf.splitlines(True):
|
77 |
+
# From the io.TextIOWrapper LOGS:
|
78 |
+
# On output, if newline is None, any '\n' characters written
|
79 |
+
# are translated to the system default line separator.
|
80 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
81 |
+
# translates them so this is still cross platform.
|
82 |
+
if line[-1] == '\n':
|
83 |
+
self.logger.log(self.log_level, line.rstrip())
|
84 |
+
else:
|
85 |
+
self.linebuf += line
|
86 |
+
|
87 |
+
def flush(self):
|
88 |
+
if self.linebuf != '':
|
89 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
90 |
+
self.linebuf = ''
|
91 |
+
|
92 |
+
|
93 |
+
def disable_torch_init():
|
94 |
+
"""
|
95 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
96 |
+
"""
|
97 |
+
import torch
|
98 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
99 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
100 |
+
|
101 |
+
|
102 |
+
def violates_moderation(text):
|
103 |
+
"""
|
104 |
+
Check whether the text violates OpenAI moderation API.
|
105 |
+
"""
|
106 |
+
url = "https://api.openai.com/v1/moderations"
|
107 |
+
headers = {"Content-Type": "application/json",
|
108 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
109 |
+
text = text.replace("\n", "")
|
110 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
111 |
+
data = data.encode("utf-8")
|
112 |
+
try:
|
113 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
114 |
+
flagged = ret.json()["results"][0]["flagged"]
|
115 |
+
except requests.exceptions.RequestException as e:
|
116 |
+
flagged = False
|
117 |
+
except KeyError as e:
|
118 |
+
flagged = False
|
119 |
+
|
120 |
+
return flagged
|
121 |
+
|
122 |
+
|
123 |
+
def pretty_print_semaphore(semaphore):
|
124 |
+
if semaphore is None:
|
125 |
+
return "None"
|
126 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
local_rank = None
|
131 |
+
|
132 |
+
def rank0_print(*args):
|
133 |
+
if os.environ["RANK"] == '0':
|
134 |
+
print(*args)
|