yesidcanoc's picture
Create README.md
5fc5f75
|
raw
history blame
3.19 kB
# Image captioning model
## How to use this model
Adapt the code below to your needs.
```python
import os
from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel
class DataProcessing:
    """Bundle the text tokenizer and the image preprocessing pipeline.

    Attributes set on construction:
        tokenizer: ``GPT2TokenizerFast`` loaded from the ``distilgpt2``
            checkpoint, with its padding token set to its end-of-sequence
            token.
        transform: torchvision pipeline that resizes an image to 224x224,
            converts it to a tensor and normalizes each channel.
    """

    def __init__(self):
        # GPT-2 tokenizer; the EOS token is reused as the padding token.
        gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
        self.tokenizer = gpt2_tokenizer

        # Per-channel statistics used by the normalization step
        # (presumably the usual ImageNet mean/std -- confirm against the
        # preprocessing the model was trained with).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        # Transforms are applied to each image in the order listed.
        steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)
class GenerateCaptions(DataProcessing):
    """Generate text captions for images with an encoder-decoder model.

    The tokenizer and image transforms come from ``DataProcessing``; the
    model is supplied by the caller and only needs a ``generate`` method that
    accepts the keyword arguments listed below.
    """

    # Generation settings forwarded verbatim to ``captioning_model.generate``.
    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True  # sampling is enabled, so captions can vary between calls
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        """
        :param captioning_model: model whose ``generate`` method produces
            caption token ids from a batch of image tensors
        """
        self.captioning_model = captioning_model
        super().__init__()

    def _load_image_tensor(self, path):
        """Open *path*, convert it to RGB and return a batch-of-one tensor."""
        # Only opening and decoding the image is guarded, so that a
        # FileNotFoundError raised elsewhere is never misreported as a
        # missing image.
        try:
            with Image.open(path) as img:
                rgb_img = img if img.mode == "RGB" else img.convert('RGB')
                # The transform must run while the file is still open.
                # unsqueeze(0) adds the leading batch dimension.
                return self.transform(rgb_img).unsqueeze(0)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {path}") from err

    def read_img_predict(self, path):
        """
        Generate captions for a single image file
        :param path: path to the image file
        :return: list of ``NUM_RETURN_SEQUENCES`` caption strings
        :raises FileNotFoundError: if ``path`` does not exist
        """
        img_transformed = self._load_image_tensor(path)
        # model_output holds one row of token ids per returned sequence,
        # i.e. num_return_sequences rows for this single image.
        model_output = self.captioning_model.generate(
            img_transformed,
            num_beams=self.NUM_BEAMS,
            max_length=self.MAX_LENGTH,
            early_stopping=self.EARLY_STOPPING,
            do_sample=self.DO_SAMPLE,
            top_k=self.TOP_K,
            num_return_sequences=self.NUM_RETURN_SEQUENCES,
        )
        # Each row is a 1-D tensor of token ids such as
        # tensor([50256, 13, 198, ..., 50256]); special tokens are dropped
        # while decoding and surrounding whitespace is stripped.
        return [
            self.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
            for token_ids in model_output
        ]

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images
        :param path: path to image or directory of images
        :return: for a file, the list of captions for that image; for a
            directory, a list with one list of captions per file found
            (recursively), ordered by sorted directory and file name
        :raises ValueError: if ``path`` is neither a file nor a directory
        """
        if os.path.isdir(path):
            # Kept as an attribute, as before, so partial results remain
            # available if a later file fails.
            self.decoded_predictions = []
            for root, dirs, files in os.walk(path):
                # Sorting in place makes os.walk descend in a stable order,
                # so the positional result list is reproducible.
                dirs.sort()
                for name in sorted(files):
                    # Every file is attempted; files PIL cannot open will
                    # raise from read_img_predict.
                    self.decoded_predictions.append(
                        self.read_img_predict(os.path.join(root, name))
                    )
            return self.decoded_predictions
        if os.path.isfile(path):
            return self.read_img_predict(path)
        raise ValueError(f"Invalid path: {path}")
# Load the pretrained vision encoder-decoder captioning model by its model id.
image_captioning_model = VisionEncoderDecoderModel.from_pretrained("yesidcanoc/image-captioning-swin-tiny-distilgpt2")
# Pair the model with the tokenizer and image transforms.
generate_captions = GenerateCaptions(image_captioning_model)
# The path may be a single image file or a directory. A directory is walked
# recursively and yields one list of captions per file found.
captions = generate_captions.generate_caption('../data/test_data/images')
print(captions)
```