# Ferret-UI-Llama8b / builder.py
# Uploaded by: jadechoghari
# Commit message: "Create builder.py"
# Commit: 7ddba36 (verified)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import pdb
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
# Controller/worker heartbeat settings and log directory. They are not
# referenced in this file; presumably consumed by serving code elsewhere.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."

# Model Constants
# Label value to ignore; -100 matches the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss (presumably the intended consumer -- confirm).
IGNORE_INDEX = -100
# Negative sentinel id for the image placeholder in token ids; presumably
# replaced with image features by the model code (not handled in this file).
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
# Patch / start / end tokens are added to the tokenizer by
# load_pretrained_model depending on the model config flags.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

# Added by Ferret
# Region-feature token; added to the tokenizer by load_pretrained_model when
# the model config sets ``im_region_fea_token``. Defined exactly once here.
DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
# Presumably the coordinate range box coordinates are quantized to for
# grounding output -- confirm against the Ferret grounding code.
VOCAB_IMAGE_W = 1000
VOCAB_IMAGE_H = 1000

# GROUNDING PROMPTS
# Instruction suffixes asking the model to answer with bounding boxes. Each
# starts with a newline; presumably appended to a user prompt.
GROUNDING_TEMPLATES = [
    '\nProvide the bounding boxes of the mentioned objects.',
    '\nInclude the coordinates for each mentioned object.',
    '\nLocate the objects with their coordinates.',
    '\nAnswer in [x1, y1, x2, y2] format.',
    '\nMention the objects and their locations using the format [x1, y1, x2, y2].',
    '\nDraw boxes around the mentioned objects.',
    '\nUse boxes to show where each thing is.',
    '\nTell me where the objects are with coordinates.',
    '\nList where each object is with boxes.',
    '\nShow me the regions with boxes.',
]
def _build_model_kwargs(load_8bit, load_4bit, device_map):
    """Build the shared ``from_pretrained`` keyword arguments for a precision mode.

    Args:
        load_8bit: Quantize weights to 8 bit via bitsandbytes. Takes precedence
            over ``load_4bit`` when both are set.
        load_4bit: Quantize weights to 4-bit NF4 with double quantization and
            fp16 compute dtype.
        device_map: Forwarded unchanged as ``device_map``.

    Returns:
        A dict holding ``device_map`` plus either ``quantization_config``
        (quantized modes) or ``torch_dtype=torch.float16`` (unquantized).
    """
    kwargs = {"device_map": device_map}
    # Quantization is expressed only through ``quantization_config``: the bare
    # ``load_in_8bit``/``load_in_4bit`` kwargs are deprecated in transformers,
    # and passing them together with ``quantization_config`` raises ValueError.
    if load_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_4bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    else:
        kwargs["torch_dtype"] = torch.float16
    return kwargs


def _load_non_lora_trainables(model_path):
    """Load ``non_lora_trainables.bin`` for a LoRA checkpoint and normalize its keys.

    The file is read from the local directory ``model_path`` if present;
    otherwise ``model_path`` is treated as a Hugging Face Hub repo id and the
    file is downloaded. A leading ``base_model.`` is stripped from every key,
    and if any key then starts with ``model.model.``, one leading ``model.``
    is stripped from every key that has it.
    """
    filename = 'non_lora_trainables.bin'
    local_file = os.path.join(model_path, filename)
    if os.path.exists(local_file):
        weights_file = local_file
    else:
        # this is probably from HF Hub
        from huggingface_hub import hf_hub_download
        weights_file = hf_hub_download(repo_id=model_path, filename=filename)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state = torch.load(weights_file, map_location='cpu')
    base_prefix = 'base_model.'
    state = {(k[len(base_prefix):] if k.startswith(base_prefix) else k): v for k, v in state.items()}
    if any(k.startswith('model.model.') for k in state):
        model_prefix = 'model.'
        state = {(k[len(model_prefix):] if k.startswith(model_prefix) else k): v for k, v in state.items()}
    return state


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    """Load a tokenizer, causal-LM model and, for LLaVA/Ferret models, an image processor.

    Dispatch is driven by case-insensitive substrings of ``model_name``:

    * contains "llava" or "ferret" (multimodal):
        - also contains "lora" and ``model_base`` is given: load ``model_base``
          with the config of ``model_path``, load ``non_lora_trainables.bin``,
          then apply and merge the LoRA adapter found at ``model_path``;
        - ``model_base`` is given otherwise: load ``model_base`` with the config
          of ``model_path`` and overlay ``mm_projector.bin`` from ``model_path``;
        - otherwise: load tokenizer and model directly from ``model_path``.
    * anything else (plain language model): if ``model_base`` is given, merge
      the PEFT adapter at ``model_path`` into it and cast to fp16; otherwise
      load directly from ``model_path``.

    Args:
        model_path: Local directory or Hugging Face Hub repo id of the
            checkpoint / adapter.
        model_base: Base model path or repo id, or None.
        model_name: Name used only for the substring dispatch above.
        load_8bit: Load weights in 8 bit (bitsandbytes).
        load_4bit: Load weights in 4-bit NF4 (ignored if ``load_8bit``).
        device_map: Passed to ``from_pretrained`` as ``device_map``.
        device: Device the vision tower is moved to (multimodal models only).
            Defaults to "cuda", the previously hard-coded value.

    Returns:
        Tuple ``(tokenizer, model, image_processor, context_len)``.
        ``image_processor`` is None for non-multimodal models; ``context_len``
        is ``model.config.max_sequence_length`` if present, else 2048.
    """
    kwargs = _build_model_kwargs(load_8bit, load_4bit, device_map)
    name = model_name.lower()
    is_multimodal = 'llava' in name or 'ferret' in name

    if is_multimodal:
        # Load LLaVA/FERRET model
        if 'lora' in name and model_base is None:
            import warnings
            warnings.warn(
                "'lora' appears in model_name but model_base is None; loading "
                "model_path directly, which may not work if it only contains "
                "a LoRA adapter."
            )
        if 'lora' in name and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA/FERRET from base model...')
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            # NOTE(review): out_features normally equals weight.shape[0], so this
            # re-allocation looks like a guard for unusual load paths -- confirm.
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
            print('Loading additional LLaVA/FERRET weights...')
            model.load_state_dict(_load_non_lora_trainables(model_path), strict=False)
            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA/FERRET from base model...')
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            # NOTE(security): torch.load unpickles; only load trusted checkpoints.
            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            # strict=False: the file only holds a subset (the projector) of the state dict.
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            # NOTE(review): quantization kwargs are deliberately not applied here;
            # presumably because the adapter is merged and the result cast to fp16.
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None
    if is_multimodal:
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        mm_im_region_fea_token = getattr(model.config, "im_region_fea_token", None)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_im_region_fea_token is not None:
            tokenizer.add_tokens([DEFAULT_REGION_FEA_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        # Keep the embedding matrix in sync with any tokens added above.
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        vision_tower_path = os.path.join(model_path, 'vision_tower')
        # A local ``vision_tower`` directory is (re)loaded even if the tower is
        # already loaded; otherwise load only if not yet loaded.
        if not vision_tower.is_loaded or os.path.exists(vision_tower_path):
            if os.path.exists(vision_tower_path):
                print(f'Start Loading vision tower from {vision_tower_path}')
                vision_tower.load_model(vision_tower_path=vision_tower_path)
                print(f'Finish Loading vision tower from {vision_tower_path}')
            else:
                vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
    return tokenizer, model, image_processor, context_len