Image captioning model
How to use this model
Adapt the code below to your needs.
import os
from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel
class DataProcessing:
    """Hold the tokenizer and the image preprocessing pipeline.

    Instances expose two attributes:

    * ``tokenizer`` -- a ``GPT2TokenizerFast`` loaded from ``'distilgpt2'``
      whose padding token is set to its end-of-sequence token.
    * ``transform`` -- a torchvision ``Compose`` that resizes to 224x224,
      converts to a tensor and normalizes each channel.
    """

    def __init__(self):
        """Build ``self.tokenizer`` and ``self.transform``."""
        # Reuse the EOS token as the padding token before exposing the
        # tokenizer on the instance.
        gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
        self.tokenizer = gpt2_tokenizer

        # Per-channel normalization statistics (these look like the usual
        # ImageNet mean/std -- confirm against the encoder's training setup).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        # Steps are applied in list order to every image.
        preprocessing_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(preprocessing_steps)
class GenerateCaptions(DataProcessing):
    """Generate text captions for image files with a captioning model.

    The upper-case class attributes are forwarded as keyword arguments to
    ``captioning_model.generate`` on every prediction.
    """

    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate

    def __init__(self, captioning_model):
        """Store the model and set up the tokenizer and image transform.

        :param captioning_model: object exposing a ``generate`` method that
            accepts the transformed image batch plus the decoding keyword
            arguments defined on this class
        """
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        """
        Caption a single image file

        :param path: path to the image file
        :return: list of decoded caption strings, one per sequence returned
            by the model
        :raises FileNotFoundError: if ``path`` does not exist
        """
        # Only opening/decoding the image can hit a missing file, so the
        # try block is limited to that step.
        try:
            with Image.open(path) as img:
                rgb_img = img if img.mode == "RGB" else img.convert('RGB')
                # Transform inside the ``with`` so pixel data is read while
                # the file is still open; unsqueeze adds the batch axis.
                img_transformed = self.transform(rgb_img).unsqueeze(0)
        except FileNotFoundError as err:
            # Chain the original error so the underlying cause is preserved.
            raise FileNotFoundError(f"File not found: {path}") from err

        # model_output is iterated row by row below: each row ``g`` is the
        # sequence of token ids for one returned caption.
        model_output = self.captioning_model.generate(
            img_transformed,
            num_beams=self.NUM_BEAMS,
            max_length=self.MAX_LENGTH,
            early_stopping=self.EARLY_STOPPING,
            do_sample=self.DO_SAMPLE,
            top_k=self.TOP_K,
            num_return_sequences=self.NUM_RETURN_SEQUENCES,
        )
        # g is a tensor like this one:
        # tensor([50256, 13, 198, 198, 198, 198, 198, 198, 198, 50256,
        #         50256, 50256, 50256, 50256, 50256])
        return [
            self.tokenizer.decode(g, skip_special_tokens=True).strip()
            for g in model_output
        ]

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images

        For a directory, every file found recursively is passed to
        ``read_img_predict``; directories and files are visited in sorted
        order so the result lines up with a sorted listing of the tree.
        Errors raised while reading a file propagate to the caller.

        :param path: path to image or directory of images
        :return: for a file, the list of captions; for a directory, a list
            with one caption list per file (also stored on
            ``self.decoded_predictions``)
        :raises ValueError: if ``path`` is neither a file nor a directory
        """
        if os.path.isdir(path):
            # Keep the attribute the original code exposed, filled as we go.
            predictions = self.decoded_predictions = []
            for root, dirs, files in os.walk(path):
                # Sorting in place makes os.walk descend in a stable order.
                dirs.sort()
                for file_name in sorted(files):
                    predictions.append(
                        self.read_img_predict(os.path.join(root, file_name))
                    )
            return predictions
        if os.path.isfile(path):
            return self.read_img_predict(path)
        raise ValueError(f"Invalid path: {path}")
# Example usage: load the pretrained captioning checkpoint by its identifier.
image_captioning_model = VisionEncoderDecoderModel.from_pretrained(
    "yesidcanoc/image-captioning-swin-tiny-distilgpt2"
)

# Wrap the model, caption every file under the example directory, and show
# the result. Replace the path with your own image file or directory.
generate_captions = GenerateCaptions(captioning_model=image_captioning_model)
captions = generate_captions.generate_caption(path='../data/test_data/images')
print(captions)