|
--- |
|
pipeline_tag: image-to-text |
|
--- |
|
# Image captioning model |
|
|
|
An end-to-end Transformer-based image captioning model in which both the encoder and the decoder use standard pre-trained Transformer architectures.
|
|
|
## Encoder |
|
The encoder is the pre-trained Swin Transformer (Liu et al., 2021), a general-purpose backbone for computer vision. It outperforms ViT, DeiT and ResNe(X)t models on tasks such as image classification, object detection and semantic segmentation. Because it is not pre-trained to be a 'narrow expert' (a model pre-trained to perform one specific task, e.g. image classification), it is a good candidate for fine-tuning on a downstream task.
|
|
|
## Decoder |
|
|
|
The decoder is the pre-trained DistilGPT2 model (`distilgpt2`), a distilled version of GPT-2.
|
|
|
## Dataset |
|
|
|
The model is fine-tuned and evaluated on the COCO 2017 dataset. |
|
|
|
|
|
## How to use this model
|
|
|
Adapt the code below to your needs. |
|
``` |
|
import os |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel |
|
|
|
class DataProcessing:
    """Hold the text tokenizer and image preprocessing shared by captioning helpers."""

    def __init__(self):
        # DistilGPT2 tokenizer; the end-of-sequence token doubles as the padding token.
        tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Per-channel normalization constants (the commonly used ImageNet mean/std).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        # Resize to 224x224, convert to a tensor, then normalize each channel.
        pipeline = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(pipeline)
|
|
|
class GenerateCaptions(DataProcessing):
    """
    Generate captions for images with a vision encoder-decoder captioning model.

    The generation settings below are class attributes, so they can be
    overridden on a subclass or an instance before calling `generate_caption`.
    Because DO_SAMPLE is True, the captions differ from run to run.
    """

    NUM_BEAMS = 3
    MAX_LENGTH = 15  # maximum length of each generated token sequence
    EARLY_STOPPING = True
    DO_SAMPLE = True  # sample during beam search instead of decoding deterministically
    TOP_K = 10  # when sampling, only the 10 most likely next tokens are candidates
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        """
        :param captioning_model: loaded model exposing `generate(...)` and `device`
            (e.g. a transformers VisionEncoderDecoderModel)
        """
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        """
        Generate captions for a single image file.

        :param path: path to an image file
        :return: list of NUM_RETURN_SEQUENCES caption strings
        :raises FileNotFoundError: if `path` does not exist
        """
        try:
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                # Add a batch dimension and place the pixels on the model's device,
                # so the example also works when the model has been moved to a GPU.
                pixel_values = self.transform(img).unsqueeze(0).to(self.captioning_model.device)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {path}") from err

        # One row of token ids per returned caption (batch size is 1 here).
        model_output = self.captioning_model.generate(
            pixel_values,
            num_beams=self.NUM_BEAMS,
            max_length=self.MAX_LENGTH,
            early_stopping=self.EARLY_STOPPING,
            do_sample=self.DO_SAMPLE,
            top_k=self.TOP_K,
            num_return_sequences=self.NUM_RETURN_SEQUENCES,
        )

        # Rows look like tensor([50256, 13, 198, ..., 50256, 50256]); padding and
        # end-of-sequence ids are special tokens, which skip_special_tokens drops.
        return [self.tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in model_output]

    def generate_caption(self, path):
        """
        Generate captions for a single image or for every image under a directory.

        :param path: path to an image file, or to a directory that is searched recursively
        :return: for a file, a list of captions; for a directory, a list holding one
            caption list per readable image. Directories are visited in sorted order
            and files are sorted within each directory. The directory result is also
            stored in `self.decoded_predictions`.
        :raises ValueError: if `path` is neither a file nor a directory
        """
        if os.path.isdir(path):
            import warnings

            from PIL import UnidentifiedImageError

            self.decoded_predictions = []
            for root, dirs, files in os.walk(path):
                # Sorting `dirs` in place makes os.walk descend in a deterministic order.
                dirs.sort()
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    try:
                        captions = self.read_img_predict(file_path)
                    except UnidentifiedImageError:
                        # Non-image files (README, .DS_Store, ...) should not abort the run.
                        warnings.warn(f"Skipping non-image file: {file_path}", stacklevel=2)
                        continue
                    self.decoded_predictions.append(captions)
            return self.decoded_predictions

        if os.path.isfile(path):
            return self.read_img_predict(path)

        raise ValueError(f"Invalid path: {path}")
|
|
|
|
|
|
|
|
|
|
|
# Load the fine-tuned captioning model from the Hugging Face Hub.
image_captioning_model = VisionEncoderDecoderModel.from_pretrained(
    "yesidcanoc/image-captioning-swin-tiny-distilgpt2"
)

# Wrap the model with the preprocessing and generation helper.
generate_captions = GenerateCaptions(image_captioning_model)

# The argument may be a single image file or a directory of images.
captions = generate_captions.generate_caption('../data/test_data/images')
print(captions)
|
|
|
|
|
``` |
|
|
|
## References |
|
- Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., & Guo, B. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. arXiv:2103.14030. https://arxiv.org/abs/2103.14030