|
|
|
import torch |
|
import os |
|
import transformers |
|
|
|
from transformers import Idefics2ForConditionalGeneration |
|
from peft import LoraConfig, get_peft_model |
|
from joint_inference import IdeficsJointInferenceModel |
|
|
|
def get_model():
    """Build the two-adapter Idefics2 joint-inference model.

    Steps, in order:

    1. Load the ``HuggingFaceM4/idefics2-8b`` checkpoint in bfloat16 and move
       it to CUDA.
    2. Wrap it with a LoRA adapter named ``"initial"`` and load that adapter's
       weights from revision ``r0_full`` of ``lil-lab/cogen``.
    3. Add a second LoRA adapter named ``"final"`` on the same modules and
       load its weights from revision ``r3_full`` of the same repo.
    4. Wrap the result in ``IdeficsJointInferenceModel``, move it to CUDA and
       switch it to eval mode.

    Both ``.cuda()`` calls are unconditional, so a CUDA device is required.

    Returns:
        The ``IdeficsJointInferenceModel`` instance, in eval mode.
    """
    repo = 'lil-lab/cogen'
    checkpoint = "HuggingFaceM4/idefics2-8b"

    # Prefix removed from parameter names when deriving module paths below
    # (17 characters, the offset the slice was previously hard-coded to).
    # NOTE(review): presumably the prefix the PEFT wrapper prepends to every
    # parameter name -- confirm against the installed peft version.
    peft_prefix = "base_model.model."

    # Hyper-parameters shared by both adapters; kept in one place so the two
    # configs cannot drift apart.
    lora_kwargs = {
        "r": 16,
        "lora_alpha": 8,
        "lora_dropout": 0.1,
        "init_lora_weights": "gaussian",
    }

    model = Idefics2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16
    ).cuda()

    # Regex selecting the modules that receive the first adapter.
    target_modules = r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
    initial_config = LoraConfig(target_modules=target_modules, **lora_kwargs)
    model = get_peft_model(model, initial_config, adapter_name="initial")
    model.load_adapter(repo, "initial", revision="r0_full")

    # Collect the path of every module that now owns a LoRA parameter, so the
    # second adapter can be given an explicit list of targets.
    # NOTE(review): presumably an explicit list is used because re-applying the
    # regex to the wrapped model would also match the LoRA sub-modules -- confirm.
    lora_module_paths = set()
    for param_name, _ in model.named_parameters():
        if 'lora' not in param_name:
            continue
        # Keep everything before the separator that precedes the first "lora".
        module_path = param_name[:param_name.index('lora') - 1]
        if module_path.startswith(peft_prefix):
            module_path = module_path[len(peft_prefix):]
        lora_module_paths.add(module_path)

    # Sorted so the target list is deterministic rather than set-ordered.
    new_targets = sorted(lora_module_paths)

    final_config = LoraConfig(target_modules=new_targets, **lora_kwargs)
    model.add_adapter('final', final_config)
    model.load_adapter(repo, "final", revision="r3_full")

    # NOTE(review): the meaning of the positional arguments 0.5 and 0 is
    # defined by IdeficsJointInferenceModel, which is not visible here.
    model = IdeficsJointInferenceModel(0.5, 0, model=model).cuda()
    model.eval()

    return model
|
|