# Source: CM2000112 / internals/pipelines/img_to_text.py
# Uploaded by jayparmr via huggingface_hub (commit 1bc457e, 1.33 kB).
import logging
import re

import torch
from torchvision import transforms
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir
class Image2Text:
    """Caption images with the Salesforce BLIP (large) captioning model.

    The processor and model are loaded lazily on the first ``process`` call
    (or an explicit ``load``) and reused for later calls on the same instance.
    """

    # Class-level default; ``load`` sets the (name-mangled) flag on the instance.
    __loaded = False

    # Hugging Face Hub checkpoint used for both the processor and the model.
    _MODEL_ID = "Salesforce/blip-image-captioning-large"

    # Leftover markup to strip from a decoded caption: any "</x>"-style tag,
    # newlines, and the literal "[SEP]". A raw string avoids the invalid "\["
    # escape of a plain literal. The pattern is compiled once per class, not
    # once per call.
    _SPECIAL_TOKEN_PATTERN = re.compile(r"</.>|\n|\[SEP\]")

    def load(self):
        """Load the BLIP processor and the fp16 model onto CUDA; no-op if already loaded."""
        if self.__loaded:
            return
        self.processor = BlipProcessor.from_pretrained(
            self._MODEL_ID,
            cache_dir=get_hf_cache_dir(),
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            self._MODEL_ID,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        self.__loaded = True

    def process(self, imageUrl: str) -> str:
        """Return a cleaned-up BLIP caption for the image at ``imageUrl``.

        The image is obtained with ``download_image``, resized to 512x512 and
        captioned with greedy decoding (at most 128 tokens).

        Args:
            imageUrl: Location of the image, passed unchanged to ``download_image``.

        Returns:
            The caption with special tokens and stray markup removed, and with
            surrounding whitespace stripped.
        """
        self.load()
        image = download_image(imageUrl).resize((512, 512))
        # Move the inputs to the GPU in fp16 to match the model weights.
        # NOTE(review): relies on transformers' BatchFeature.to accepting
        # (device, dtype) and casting only floating-point tensors. Confirm this
        # for the pinned transformers version.
        inputs = self.processor(image, return_tensors="pt").to("cuda", torch.float16)
        # Greedy decoding. top_p is omitted because it only applies when
        # do_sample=True; with greedy decoding it just triggers a warning.
        output_ids = self.model.generate(**inputs, do_sample=False, max_length=128)
        caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        logging.getLogger(__name__).debug("BLIP caption before cleanup: %r", caption)
        # Safety net for any markup that skip_special_tokens does not remove.
        caption = self._SPECIAL_TOKEN_PATTERN.sub("", caption)
        # Removing tokens can leave leading/trailing spaces (e.g. "... [SEP]").
        return caption.strip()