import re

import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir
class Image2Text:
    """Captions an image using the BLIP image-captioning model."""

    __loaded = False

    def load(self):
        """Lazily load the BLIP processor and model on first use."""
        if self.__loaded:
            return

        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large",
            cache_dir=get_hf_cache_dir(),
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        self.__loaded = True
    def process(self, imageUrl: str) -> str:
        """Download the image at ``imageUrl`` and return a generated caption."""
        self.load()

        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor(image, return_tensors="pt").to("cuda", torch.float16)

        # Greedy decoding; top_p has no effect while do_sample=False.
        output_ids = self.model.generate(
            **inputs, do_sample=False, top_p=0.9, max_length=128
        )

        output_text = self.processor.batch_decode(output_ids)
        print(output_text)
        output_text = output_text[0]
        # Strip closing tags, newlines, and the [SEP] token from the decoded text.
        output_text = re.sub(r"</.>|\n|\[SEP\]", "", output_text)
        return output_text
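

if __name__ == "__main__":
    # Usage sketch, assuming a CUDA device is available and that download_image
    # can fetch the URL below; the URL is a hypothetical placeholder, not one
    # taken from the project. Reuse one Image2Text instance so the model is
    # loaded only once.
    captioner = Image2Text()
    caption = captioner.process("https://example.com/photo.jpg")
    print(caption)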