import logging

import torch
from omegaconf import OmegaConf

from lavis.models import load_preprocess, registry
from ldm.util import instantiate_from_config


def load_blip2_model(cfg, is_eval=False, device="cpu"):
    model_cls = registry.get_model_class(cfg.model_name)

    # Load the model's default config and override it with the values from cfg.
    default_cfg = OmegaConf.load(model_cls.default_config_path(cfg.model_type))
    default_cfg.model.pretrained = cfg.pretrained
    if default_cfg.model.image_size != cfg.params.img_size:
        default_cfg.model.image_size = cfg.params.img_size

    model = model_cls.from_config(default_cfg.model)
    model.cfg = default_cfg.model

    if is_eval:
        model.eval()

    # Build the vision/text preprocessors from the default config, if available.
    if default_cfg is not None:
        preprocess_cfg = default_cfg.preprocess
        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {cfg.model_name} ({cfg.model_type}).
            This can happen if the model is not finetuned on downstream datasets,
            or it is not intended for direct use without finetuning.
            """
        )

    # Use float32 on CPU for compatibility with half-precision checkpoints.
    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors


def load_qformer_model(cfg):
    blip2_model, vis_processor, txt_processor = load_blip2_model(cfg)
    q_former = instantiate_from_config(cfg)

    # Match the number of query tokens before copying over the BLIP-2 weights.
    if blip2_model.query_tokens.shape != q_former.query_tokens.shape:
        blip2_model.query_tokens = q_former.query_tokens

    # Only initialize from the BLIP-2 checkpoint when the Q-Former uses the same
    # bert-base-uncased backbone; otherwise the state dicts would not line up.
    model_name = cfg.params.get("model_name", "bert-base-uncased")
    if model_name == "bert-base-uncased":
        q_former.load_state_dict(blip2_model.state_dict(), strict=False)

    return q_former, (vis_processor, txt_processor)
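

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not part of the loaders above): the config
# keys shown here (model_name, model_type, pretrained, target, params) follow the
# access pattern of load_blip2_model / load_qformer_model, but the concrete values
# (checkpoint path, target class) are placeholders for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    qformer_cfg = OmegaConf.create(
        {
            "model_name": "blip2",      # registered LAVIS model class name
            "model_type": "pretrain",   # selects which default config to load
            "pretrained": "/path/to/blip2_pretrained.pth",  # placeholder checkpoint path
            # Hypothetical target for instantiate_from_config; replace with the
            # Q-Former class actually defined in this repository.
            "target": "ldm.modules.encoders.qformer.QFormer",
            "params": {"img_size": 224, "model_name": "bert-base-uncased"},
        }
    )
    q_former, (vis_proc, txt_proc) = load_qformer_model(qformer_cfg)
    print(type(q_former).__name__, q_former.query_tokens.shape)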