import torch
import os
from peft import get_peft_model, LoraConfig, TaskType
from safetensors import safe_open
from peft import PeftModel
from tasks.eval.eval_utils import Conversation
from models.pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig
from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_in_model
from accelerate.utils import get_balanced_memory
from transformers import StoppingCriteria


class KeywordsStoppingCriteria(StoppingCriteria):
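    """Stop generation once every stop keyword appears in each decoded continuation."""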
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
if self.start_len is None:
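            # First call: record the prompt length so later calls only decode new tokens.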
self.start_len = self.input_ids.shape[1]
return False
else:
outputs = self.tokenizer.batch_decode(
output_ids[:, self.start_len:], skip_special_tokens=True
)
            # Stop (return True) only once every keyword appears in every decoded sequence.
            for output in outputs:
                for keyword in self.keywords:
                    if keyword not in output:
                        return False
            return True


def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16, 12, 12)):
    kwargs = {
        'num_frames': num_frames,
        'pooling_shape': pooling_shape,  # folded in here to avoid a duplicate-keyword error below
    }
    if num_frames == 0:
        # deliberately produces a bug if the pooling projector is ever used with zero frames
        kwargs.update(pooling_shape=(0, 12, 12))
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        **kwargs,
    )
    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16)
    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        # fall back to the base llava processor when the repo does not ship one
        processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')
    # configure LoRA (the weight_dir check is now an assert so a missing path fails loudly)
    if use_lora:
        assert weight_dir is not None, "pass a folder with your LoRA weights"
        print("Using LoRA")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"],
            r=128, lora_alpha=lora_alpha, lora_dropout=0.
        )
        print("LoRA scaling:", lora_alpha / 128)
        model.language_model = get_peft_model(model.language_model, peft_config)
        print("Finished wrapping the language model with LoRA")
    # load weights
    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        # prefer the single-file checkpoint unless sharded files ("model-0...") are also present
        if "model.safetensors" in save_fnames:
            use_full = False
            for fn in save_fnames:
                if fn.startswith('model-0'):
                    use_full = True
                    break
        else:
            use_full = True

        if not use_full:
            print("Loading weights from", weight_dir, "model.safetensors")
            with safe_open(f"{weight_dir}/model.safetensors", framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
        else:
            print("Loading weights from", weight_dir)
            for fn in save_fnames:
                if fn.startswith('model-0'):
                    with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f:
                        for k in f.keys():
                            state_dict[k] = f.get_tensor(k)

        if 'model' in state_dict:
            msg = model.load_state_dict(state_dict['model'], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        print(msg)
    # dispatch model weights across GPUs
    if use_multi_gpus:
        # balance layers across the visible GPUs without splitting any decoder layer
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
            low_zero=False,
        )
        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
        )
        model = dispatch_model(model, device_map=device_map)
        print(model.hf_device_map)
model = model.eval()
return model, processor


def load_adapters(model, adapter_model_name_or_paths):
for adapter_model_name_or_path in adapter_model_name_or_paths:
if not isinstance(model, PeftModel):
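            # first adapter: wrap the base model; the path doubles as the adapter name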
model = PeftModel.from_pretrained(model, adapter_model_name_or_path, adapter_model_name_or_path)
else:
model.load_adapter(adapter_model_name_or_path, adapter_model_name_or_path)
return model


def pllava_answer(conv: Conversation, model, processor, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
                  repetition_penalty=1.0, length_penalty=1.0, temperature=1.0, stop_criteria_keywords=None, print_res=False):
    prompt = conv.get_prompt()
    inputs = processor(text=prompt, images=img_list, return_tensors="pt")
    if inputs['pixel_values'] is None:
        # text-only turn: drop the empty image slot before moving tensors to the device
        inputs.pop('pixel_values')
    inputs = inputs.to(model.device)
    # set up stopping criteria
    if stop_criteria_keywords is not None:
        stopping_criteria = [KeywordsStoppingCriteria(stop_criteria_keywords, processor.tokenizer, inputs["input_ids"])]
    else:
        stopping_criteria = None
    with torch.no_grad():
        output_token = model.generate(
            **inputs, media_type='video',
            do_sample=do_sample, max_new_tokens=max_new_tokens, num_beams=num_beams, min_length=min_length,
            top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature,
            stopping_criteria=stopping_criteria,
        )
output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
if print_res: # debug usage
print('### PROMPTING LM WITH: ', prompt)
print('### LM OUTPUT TEXT: ', output_text)
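    # the tokenizer decodes the role tag with a space after the special token,
    # hence the slightly different split string below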
if conv.roles[-1] == "<|im_start|>assistant\n":
split_tag = "<|im_start|> assistant\n"
else:
split_tag = conv.roles[-1]
output_text = output_text.split(split_tag)[-1]
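    # strip the conversation separator that trails the generated answer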
ending = conv.sep if isinstance(conv.sep, str) else conv.sep[1]
output_text = output_text.removesuffix(ending)
conv.messages[-1][1] = output_text
return output_text, conv
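

# Minimal usage sketch (illustrative only: the repo id, frame count, and the
# pre-built Conversation and frame list are assumptions, not part of this module):
#
#   model, processor = load_pllava('ermu2001/pllava-7b', num_frames=16)
#   # `conv` is a Conversation holding the dialogue so far; `frames` is a list
#   # of PIL images sampled from the video (both hypothetical names here)
#   response, conv = pllava_answer(conv, model, processor, img_list=frames, do_sample=False)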