import re

import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir


class Image2Text:
    __loaded = False

    def load(self):
        # Lazily load the BLIP captioning processor and model on first use.
        if self.__loaded:
            return
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large",
            cache_dir=get_hf_cache_dir(),
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        self.__loaded = True

    def process(self, imageUrl: str) -> str:
        self.load()

        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor(image, return_tensors="pt").to("cuda", torch.float16)

        # Greedy decoding; top_p was dropped because it has no effect when
        # do_sample=False.
        output_ids = self.model.generate(**inputs, do_sample=False, max_length=128)

        # skip_special_tokens drops [CLS]/[SEP]; the remaining regex only has
        # to strip stray newlines.
        output_text = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        output_text = re.sub(r"\n", "", output_text).strip()
        return output_text
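

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumes a CUDA
    # device is available and that download_image can fetch the URL; the URL
    # below is a hypothetical placeholder.
    captioner = Image2Text()
    caption = captioner.process("https://example.com/sample.jpg")
    print(caption)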