|
import logging
import re

import torch
from torchvision import transforms
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir
|
|
|
|
|
class Image2Text:
    """Caption an image with the BLIP large image-captioning model.

    The processor and model are loaded lazily on the first call to
    :meth:`load` (which :meth:`process` invokes) and are then reused by this
    instance.
    """

    # Hugging Face Hub id shared by the processor and the model.
    MODEL_ID = "Salesforce/blip-image-captioning-large"

    # Fragments stripped from the decoded caption: two-character closing tags
    # such as "</s>", newlines, and the literal "[SEP]" separator token.
    # Raw string so "\[" is a regex escape, not an (invalid) Python escape.
    _CLEANUP_RE = re.compile(r"</.>|\n|\[SEP\]")

    _logger = logging.getLogger(__name__)

    __loaded = False

    def load(self):
        """Load the BLIP processor and fp16 model onto CUDA.

        Does nothing if this instance has already loaded them.
        """
        if self.__loaded:
            return

        self.processor = BlipProcessor.from_pretrained(
            self.MODEL_ID,
            cache_dir=get_hf_cache_dir(),
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        self.__loaded = True

    def process(self, imageUrl: str) -> str:
        """Download the image at ``imageUrl`` and return a text caption for it.

        The downloaded image is resized to 512x512 before being handed to the
        processor. Decoding is greedy and capped at 128 tokens.

        Args:
            imageUrl: Location passed to ``download_image``.

        Returns:
            The first decoded caption, with markup/separator fragments removed
            and surrounding whitespace stripped.
        """
        self.load()

        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor(image, return_tensors="pt").to("cuda", torch.float16)

        # Greedy decoding. top_p only affects sampling, so it is not passed:
        # with do_sample=False it was ignored and only triggered a warning.
        output_ids = self.model.generate(**inputs, do_sample=False, max_length=128)

        # skip_special_tokens keeps tokenizer special tokens out of the text;
        # the regex cleanup below remains as a fallback.
        decoded = self.processor.batch_decode(output_ids, skip_special_tokens=True)
        self._logger.debug("BLIP decoded caption(s): %s", decoded)

        return self._clean_caption(decoded[0])

    @classmethod
    def _clean_caption(cls, raw: str) -> str:
        """Strip leftover tag/newline/[SEP] fragments and outer whitespace."""
        return cls._CLEANUP_RE.sub("", raw).strip()
|
|