# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

# Added by Ferret
DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
VOCAB_IMAGE_W = 1000
VOCAB_IMAGE_H = 1000

# GROUNDING PROMPTS
GROUNDING_TEMPLATES = [
    '\nProvide the bounding boxes of the mentioned objects.',
    '\nInclude the coordinates for each mentioned object.',
    '\nLocate the objects with their coordinates.',
    '\nAnswer in [x1, y1, x2, y2] format.',
    '\nMention the objects and their locations using the format [x1, y1, x2, y2].',
    '\nDraw boxes around the mentioned objects.',
    '\nUse boxes to show where each thing is.',
    '\nTell me where the objects are with coordinates.',
    '\nList where each object is with boxes.',
    '\nShow me the regions with boxes.'
]


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
    """Load a FERRET/LLaVA (or plain language) model and return (tokenizer, model, image_processor, context_len)."""
    kwargs = {"device_map": device_map}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if 'llava' in model_name.lower() or 'ferret' in model_name.lower():
        # Load LLaVA/FERRET model
        if 'lora' in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True)
            print('Loading LLaVA/FERRET from base model...')
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs, trust_remote_code=True)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                # Vocabulary was extended during fine-tuning; re-allocate the head and embeddings to match.
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA/FERRET weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')

                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
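            # Keys in this checkpoint were saved from the PEFT-wrapped training model, so they
            # carry a 'base_model.' prefix (and sometimes an extra 'model.'); strip those
            # prefixes below so the tensors line up with the bare model's state dict.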
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path, trust_remote_code=True)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA/FERRET from base model...')
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs, trust_remote_code=True)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs, trust_remote_code=True)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto", trust_remote_code=True)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path, trust_remote_code=True)
            print("Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs, trust_remote_code=True)

    image_processor = None

    if 'llava' in model_name.lower() or 'ferret' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        mm_im_region_fea_token = getattr(model.config, "im_region_fea_token", None)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_im_region_fea_token is not None:
            tokenizer.add_tokens([DEFAULT_REGION_FEA_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        vision_tower_path = os.path.join(model_path, 'vision_tower')
        if not vision_tower.is_loaded or os.path.exists(vision_tower_path):
            if os.path.exists(vision_tower_path):
                print(f'Start Loading vision tower from {vision_tower_path}')
                vision_tower.load_model(vision_tower_path=vision_tower_path)
                print(f'Finish Loading vision tower from {vision_tower_path}')
            else:
                vision_tower.load_model()
        vision_tower.to(device='cuda', dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
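

# Minimal usage sketch (not part of the original module): the checkpoint path and
# model name below are placeholders, assuming a fully merged (non-LoRA) FERRET
# checkpoint on local disk; adjust them to your own setup.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="./checkpoints/ferret-7b",  # hypothetical local checkpoint directory
        model_base=None,                       # None because the weights are already merged
        model_name="ferret-7b",                # must contain 'ferret' or 'llava' to take the multimodal branch
        load_4bit=False,
    )
    print(f"Loaded {type(model).__name__} with context length {context_len}")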