# Image captioning model

## How to use this model

Adapt the code below to your needs.

```python
import os

from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel


class DataProcessing:
    """Hold the tokenizer and the image transform used for captioning."""

    def __init__(self):
        # GPT-2 tokenizer; the EOS token is reused as the padding token.
        self.tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Define the transforms to be applied to the images:
        # resize to 224x224, convert to a tensor, then normalize per channel.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])


class GenerateCaptions(DataProcessing):
    """Generate captions for images with a pretrained captioning model.

    The class attributes below are passed straight to the model's
    ``generate`` call; override them on a subclass or instance to tune
    generation.
    """

    NUM_BEAMS = 3
    MAX_LENGTH = 15
    EARLY_STOPPING = True
    DO_SAMPLE = True
    TOP_K = 10
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate

    def __init__(self, captioning_model):
        """
        :param captioning_model: model object exposing a ``generate`` method
            (e.g. a ``VisionEncoderDecoderModel``)
        """
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        """
        Read one image file and generate captions for it

        :param path: path to an image file
        :return: list with ``NUM_RETURN_SEQUENCES`` decoded caption strings
        :raises FileNotFoundError: if ``path`` does not exist
        """
        try:
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                # Add a leading batch dimension of size 1.
                img_transformed = self.transform(img).unsqueeze(0)

            # The image file is closed at this point; only the tensor is needed.
            # model_output holds one row of token ids per returned sequence
            # (NUM_RETURN_SEQUENCES rows, each up to MAX_LENGTH token ids).
            model_output = self.captioning_model.generate(
                img_transformed,
                num_beams=self.NUM_BEAMS,
                max_length=self.MAX_LENGTH,
                early_stopping=self.EARLY_STOPPING,
                do_sample=self.DO_SAMPLE,
                top_k=self.TOP_K,
                num_return_sequences=self.NUM_RETURN_SEQUENCES,
            )
            # g is a tensor like this one:
            # tensor([50256, 13, 198, 198, 198, 198, 198, 198, 198, 50256,
            #         50256, 50256, 50256, 50256, 50256])
            captions = [
                self.tokenizer.decode(g, skip_special_tokens=True).strip()
                for g in model_output
            ]
            return captions
        except FileNotFoundError as err:
            # Re-raise with a clearer message while keeping the original cause.
            raise FileNotFoundError(f"File not found: {path}") from err

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images

        For a directory, every file found by walking it recursively is passed
        to ``read_img_predict``; the results are also stored on
        ``self.decoded_predictions``.

        :param path: path to image or directory of images
        :return: captions (a list of strings for a single image, a list of
            such lists for a directory)
        :raises ValueError: if ``path`` is neither a file nor a directory
        """
        if os.path.isdir(path):
            self.decoded_predictions = []
            for root, _dirs, files in os.walk(path):
                for file in files:
                    self.decoded_predictions.append(
                        self.read_img_predict(os.path.join(root, file))
                    )
            return self.decoded_predictions
        elif os.path.isfile(path):
            return self.read_img_predict(path)
        else:
            raise ValueError(f"Invalid path: {path}")


image_captioning_model = VisionEncoderDecoderModel.from_pretrained("yesidcanoc/image-captioning-swin-tiny-distilgpt2")
generate_captions = GenerateCaptions(image_captioning_model)
captions = generate_captions.generate_caption('../data/test_data/images')
print(captions)
```