|
# Image captioning model |
|
|
|
## How to use this model
|
|
|
Adapt the code below to your needs. |
|
```python
|
import os |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel |
|
|
|
class DataProcessing:
    """Shared preprocessing: a GPT-2 tokenizer and the image transform pipeline."""

    def __init__(self):
        # distilgpt2 tokenizer; it ships without a pad token, so reuse eos as pad.
        tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Standard ImageNet normalization statistics (per RGB channel).
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        # Resize to the 224x224 input the vision encoder expects,
        # convert to a tensor, then normalize.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
|
|
|
class GenerateCaptions(DataProcessing):
    """Generate text captions for images with a VisionEncoderDecoder model.

    Inherits the tokenizer and image transform from DataProcessing.
    """

    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        # captioning_model: a loaded VisionEncoderDecoderModel (see module usage below)
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        """Load one image, run the captioning model, and decode the captions.

        :param path: path to a single image file
        :return: list of NUM_RETURN_SEQUENCES caption strings
        :raises FileNotFoundError: if the image file does not exist
        """
        try:
            with Image.open(path) as img:
                # The model expects 3-channel RGB input.
                if img.mode != "RGB":
                    img = img.convert('RGB')
                # Add a batch dimension: (1, C, H, W).
                img_transformed = self.transform(img).unsqueeze(0)
                # model_output has shape num_return_sequences x max_length,
                # where entry ij is a token id.
                model_output = self.captioning_model.generate(
                    img_transformed,
                    num_beams=self.NUM_BEAMS,
                    max_length=self.MAX_LENGTH,
                    early_stopping=self.EARLY_STOPPING,
                    do_sample=self.DO_SAMPLE,
                    top_k=self.TOP_K,
                    num_return_sequences=self.NUM_RETURN_SEQUENCES,
                )
                # Each g is a tensor of token ids, e.g. tensor([50256, 13, 198,
                # ..., 50256]); skip_special_tokens drops the eos/pad padding.
                captions = [
                    self.tokenizer.decode(g, skip_special_tokens=True).strip()
                    for g in model_output
                ]
            return captions
        except FileNotFoundError as err:
            # Chain the original exception instead of discarding it,
            # so the underlying cause stays visible in the traceback.
            raise FileNotFoundError(f"File not found: {path}") from err

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images.

        :param path: path to an image file or a directory of images
        :return: list of captions for a single image, or a list of such
                 lists (one per file) for a directory
        :raises ValueError: if path is neither a file nor a directory
        """
        if os.path.isdir(path):
            self.decoded_predictions = []
            # Walk recursively; every file found is assumed to be an image.
            for root, _dirs, files in os.walk(path):
                for file in files:
                    self.decoded_predictions.append(
                        self.read_img_predict(os.path.join(root, file))
                    )
            return self.decoded_predictions
        elif os.path.isfile(path):
            return self.read_img_predict(path)
        else:
            raise ValueError(f"Invalid path: {path}")
|
|
|
|
|
|
|
|
|
|
|
# Guard the example so importing this module does not trigger a model
# download and inference run as a side effect.
if __name__ == "__main__":
    # Pretrained Swin-tiny encoder + distilgpt2 decoder captioning model.
    image_captioning_model = VisionEncoderDecoderModel.from_pretrained(
        "yesidcanoc/image-captioning-swin-tiny-distilgpt2"
    )

    generate_captions = GenerateCaptions(image_captioning_model)

    # Point this at your own image file or directory of images.
    captions = generate_captions.generate_caption('../data/test_data/images')

    print(captions)
|
|
|
|
|
``` |