# Source: Rediones/utils/caption_utils.py (commit e66ee6d)
from PIL import Image
import io
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from utils.image_utils import load_image
# Pick the torch device once at import time: use CUDA when torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class ImageCaptioning:
    """Generate a text caption for an image with a BLIP captioning model.

    The processor and the model are loaded once in ``__init__`` and reused
    for every ``get_caption`` call. The model is moved to the module-level
    ``device``.
    """

    # Checkpoint used when the caller does not ask for a specific one.
    DEFAULT_MODEL = "Salesforce/blip-image-captioning-base"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load the BLIP processor and model.

        Args:
            model_name: Checkpoint identifier passed to both
                ``BlipProcessor.from_pretrained`` and
                ``BlipForConditionalGeneration.from_pretrained``. Defaults to
                ``Salesforce/blip-image-captioning-base``.
        """
        self.model_name = model_name
        # Load processor and model from the same checkpoint so they cannot
        # drift apart.
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)

    def get_caption(self, image, **generate_kwargs) -> str:
        """Return a caption for ``image``.

        Args:
            image: Either the encoded image file contents as a bytes-like
                object (decoded with ``PIL.Image.open``) or an already
                decoded ``PIL.Image.Image``.
            **generate_kwargs: Optional extra keyword arguments forwarded
                unchanged to ``self.model.generate``.

        Returns:
            The first decoded caption, with special tokens removed.

        Raises:
            Whatever ``PIL.Image.open`` raises when the bytes cannot be
            decoded as an image.
        """
        # An already-decoded image is used as-is; the caller owns it, so it
        # is not closed here.
        if isinstance(image, Image.Image):
            return self._caption_image(image, generate_kwargs)

        # Image.open reads pixel data lazily, so the image must still be open
        # while the processor consumes it; the handle is released afterwards.
        with Image.open(io.BytesIO(image)) as img:
            return self._caption_image(img, generate_kwargs)

    def _caption_image(self, img, generate_kwargs) -> str:
        """Run preprocessing, generation and decoding for one PIL image."""
        # Preprocess the image into model inputs on the target device.
        inputs = self.processor(img, return_tensors="pt").to(device)
        # Generate caption token ids.
        output = self.model.generate(**inputs, **generate_kwargs)
        # Decode the token ids; a single image yields a single caption.
        return self.processor.batch_decode(output, skip_special_tokens=True)[0]