|
from PIL import Image |
|
import io |
|
import torch |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from utils.image_utils import load_image |
|
|
|
# Pick the compute device once at import time: prefer the GPU when CUDA is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
class ImageCaptioning:
    """Generate natural-language captions for images with the BLIP base model.

    Loads the Salesforce BLIP image-captioning checkpoint once at
    construction time and exposes a single `get_caption` method.
    """

    def __init__(self):
        # The processor handles resizing/normalization; the model is moved to
        # the module-level `device` ("cuda" when available, else "cpu").
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    def get_caption(self, image):
        """Return a generated caption for `image`.

        Args:
            image: Raw encoded image bytes (e.g. the contents of a JPEG/PNG
                file), or an already-decoded `PIL.Image.Image`.

        Returns:
            str: the generated caption, with special tokens stripped.
        """
        # Accept either raw bytes (original behavior) or a decoded PIL image.
        if isinstance(image, Image.Image):
            img = image
        else:
            img = Image.open(io.BytesIO(image))

        # BLIP expects 3-channel RGB input; converting guards against
        # RGBA/palette/grayscale images that the processor cannot handle.
        img = img.convert("RGB")

        inputs = self.processor(img, return_tensors="pt").to(device)

        # Pure inference: no_grad avoids building the autograd graph,
        # reducing memory use during generation.
        with torch.no_grad():
            output = self.model.generate(**inputs)

        caption = self.processor.batch_decode(output, skip_special_tokens=True)[0]
        return caption
|
|