# Image captioning model

An image-captioning `VisionEncoderDecoderModel` pairing a Swin Transformer (tiny) image encoder with a DistilGPT-2 text decoder, as the checkpoint name indicates.

## How to use this model

Adapt the code below to your needs. It assumes the `transformers`, `torch`, `torchvision`, and `Pillow` packages are installed.
```python
import os
from PIL import Image
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast, VisionEncoderDecoderModel

class DataProcessing:
    """Shared preprocessing: GPT-2 tokenizer for decoding and image transforms for the encoder."""

    def __init__(self):
        # GPT-2 tokenizer
        self.tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
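        # GPT-2 has no pad token by default; reuse EOS so padding works during decoding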
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Define the transforms to be applied to the images
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
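            # ImageNet mean/std, the standard statistics for pretrained vision encoders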
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

class GenerateCaptions(DataProcessing):
    """Caption generation via beam search combined with top-k sampling."""

    NUM_BEAMS = 3             # beam search width
    MAX_LENGTH = 15           # maximum caption length in tokens
    EARLY_STOPPING = True     # stop once all beams have finished
    DO_SAMPLE = True          # sample among the top-k tokens instead of taking the argmax
    TOP_K = 10                # restrict sampling to the 10 most likely next tokens
    NUM_RETURN_SEQUENCES = 2  # number of captions to generate per image

    def __init__(self, captioning_model):
        self.captioning_model = captioning_model
        super().__init__()

    def read_img_predict(self, path):
        try:
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                img_transformed = self.transform(img).unsqueeze(0)  # add a batch dimension
                # model_output: tensor of shape (num_return_sequences, sequence_length) holding token ids
                model_output = self.captioning_model.generate(
                    img_transformed,
                    num_beams=self.NUM_BEAMS,
                    max_length=self.MAX_LENGTH,
                    early_stopping=self.EARLY_STOPPING,
                    do_sample=self.DO_SAMPLE,
                    top_k=self.TOP_K,
                    num_return_sequences=self.NUM_RETURN_SEQUENCES,
                )
                # each g is a 1-D tensor of token ids, e.g. tensor([50256, 13, 198, ..., 50256])
                captions = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in model_output]

                return captions
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {path}")

    def generate_caption(self, path):
        """
        Generate captions for a single image or a directory of images
        :param path: path to image or directory of images
        :return: captions
        """
        if os.path.isdir(path):
            decoded_predictions = []
            for root, _dirs, files in os.walk(path):
                for file in files:
                    decoded_predictions.append(self.read_img_predict(os.path.join(root, file)))
            return decoded_predictions
        elif os.path.isfile(path):
            return self.read_img_predict(path)
        else:
            raise ValueError(f"Invalid path: {path}")


# Load the pretrained captioning checkpoint from the Hugging Face Hub
image_captioning_model = VisionEncoderDecoderModel.from_pretrained("yesidcanoc/image-captioning-swin-tiny-distilgpt2")

generate_captions = GenerateCaptions(image_captioning_model)

captions = generate_captions.generate_caption('../data/test_data/images')  # point this at your own images


print(captions)
```
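
`generate_caption` also accepts the path to a single image file. A minimal usage sketch (the file name below is a hypothetical placeholder):

```python
# Caption one image; returns NUM_RETURN_SEQUENCES candidate captions
captions = generate_captions.generate_caption('my_photo.jpg')  # hypothetical path
for caption in captions:
    print(caption)
```

Because `DO_SAMPLE` is `True`, repeated runs may produce different captions; set `DO_SAMPLE = False` if you want deterministic beam-search output.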