File size: 4,549 Bytes
c389f87 5fc5f75 b0f6fc2 c389f87 3bc3692 5fc5f75 c389f87 3bc3692 c389f87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
---
pipeline_tag: image-to-text
---
# Image captioning model
An end-to-end, Transformer-based image captioning model in which both the encoder and the decoder use standard pre-trained Transformer architectures.
### Repository: https://github.com/yesidc/image-captioning
## Encoder
The encoder is the pre-trained Swin Transformer (Liu et al., 2021), a general-purpose backbone for computer vision. It outperforms ViT, DeiT and ResNe(X)t models on tasks such as image classification, object detection and semantic segmentation. Because this model is not pre-trained to be a 'narrow expert' (a model pre-trained to perform one specific task, e.g., image classification), it is a good candidate for fine-tuning on a downstream task.
## Decoder
The decoder is DistilGPT2 (`distilgpt2`).
## Dataset
The model is fine-tuned and evaluated on the COCO 2017 dataset.
## How to use this model with the Transformers `pipeline`
```python
from transformers import pipeline

# Hugging Face Hub id of this captioning model.
model_id = "yesidcanoc/image-captioning-swin-tiny-distilgpt2"
image_to_text = pipeline("image-to-text", model=model_id)

# Provide the path to the image file you want to caption.
image_path = "./COCO_val2014_000000457986.jpg"
caption = image_to_text(image_path)
print(caption)
```
## How to use this model with the `GenerateCaptions` utility class
Adapt the code below to your needs.
```python
import os
from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel
class DataProcessing:
    """Hold the tokenizer and the image transform used by the captioning helpers."""

    def __init__(self):
        # DistilGPT2 tokenizer; padding reuses the end-of-sequence token
        # (presumably because GPT-2 defines no dedicated pad token).
        tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # Resize every image to 224x224, convert it to a tensor and normalise
        # each channel (the values are the commonly used ImageNet mean/std).
        pipeline_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline_steps)
class GenerateCaptions(DataProcessing):
    """Caption images with a vision encoder-decoder model.

    Decoding settings are class attributes, so they can be overridden on a
    subclass or on an instance. Because ``DO_SAMPLE`` is ``True``, captions are
    sampled and may differ from one call to the next.
    """

    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        """
        :param captioning_model: model exposing ``generate`` (e.g. a
            ``VisionEncoderDecoderModel``) that accepts a batch of transformed images
        """
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        """
        Generate captions for a single image file.

        :param path: path to an image file
        :return: list of ``NUM_RETURN_SEQUENCES`` caption strings
        :raises FileNotFoundError: if ``path`` does not exist
        """
        # Only loading the image can raise FileNotFoundError, so keep the try
        # narrow and chain the original error instead of hiding it.
        try:
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                # Add a leading batch dimension of size 1.
                img_transformed = self.transform(img).unsqueeze(0)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {path}") from err

        # For a batch of one image, generate returns one row of token ids per
        # returned sequence, i.e. NUM_RETURN_SEQUENCES rows of at most MAX_LENGTH ids.
        model_output = self.captioning_model.generate(
            img_transformed,
            num_beams=self.NUM_BEAMS,
            max_length=self.MAX_LENGTH,
            early_stopping=self.EARLY_STOPPING,
            do_sample=self.DO_SAMPLE,
            top_k=self.TOP_K,
            num_return_sequences=self.NUM_RETURN_SEQUENCES,
        )
        # A row looks like tensor([50256, 13, 198, ..., 50256, 50256]). The padding is
        # the EOS id because the tokenizer's pad_token is its eos_token.
        # skip_special_tokens drops those ids before the text is stripped.
        return [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in model_output]

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images.

        Directories are walked recursively in sorted order, so results are
        reproducible across runs and filesystems. Every file found is treated
        as an image; a file PIL cannot open makes this call raise.

        :param path: path to an image file or to a directory of images
        :return: for a file, the list of captions for that image; for a
            directory, a list holding one such list per file, in walk order
            (also stored in ``self.decoded_predictions``)
        :raises ValueError: if ``path`` is neither a file nor a directory
        """
        if os.path.isdir(path):
            self.decoded_predictions = []
            for root, dirs, files in os.walk(path):
                # Sorting dirs in place makes os.walk descend in a stable order.
                dirs.sort()
                for file_name in sorted(files):
                    self.decoded_predictions.append(
                        self.read_img_predict(os.path.join(root, file_name))
                    )
            return self.decoded_predictions
        if os.path.isfile(path):
            return self.read_img_predict(path)
        raise ValueError(f"Invalid path: {path}")
# Hub id of the fine-tuned Swin-tiny encoder + DistilGPT2 decoder.
model_id = "yesidcanoc/image-captioning-swin-tiny-distilgpt2"
image_captioning_model = VisionEncoderDecoderModel.from_pretrained(model_id)
generate_captions = GenerateCaptions(image_captioning_model)

# Either a single image file or a directory of images; replace with your own path.
images_path = '../data/test_data/images'
captions = generate_captions.generate_caption(images_path)
print(captions)
```
## References
- Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., & Guo, B. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. arXiv preprint arXiv:2103.14030. https://arxiv.org/abs/2103.14030