import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration, ViTFeatureExtractor, ViTModel
from torch import nn
from typing import Tuple
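
# Image-captioning demo: a ViT encoder produces image features that a T5 decoder
# turns into a caption (see EncoderDecoder below); a Gradio interface exposes
# greedy caption generation for uploaded images.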

class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """
        Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained model

        """

        super(Encoder, self).__init__()

        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, input):
        out = self.encoder(pixel_values=input)

        return out

class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained model
            encoder_modeldim (int): hidden dimension of the encoder output

        """

        super(Decoder, self).__init__()

        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        # Projects the encoder hidden states to the decoder's hidden dimension when they differ
        self.linear = nn.Linear(encoder_modeldim, self.decoder.model_dim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):

        # Project the encoder hidden states if the encoder and decoder dimensions differ
        if self.decoder.model_dim != self.encoder_modeldim:
            output_encoder.last_hidden_state = self.linear(output_encoder.last_hidden_state)

        # Validation/Testing
        if decoder_ids is not None:
            out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)

        # Training
        else:
            out = self.decoder(encoder_outputs=output_encoder, labels=targets)

        return out

class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Implements a model that combines Encoder and Decoder.

        Args:
            pretrained_model (tuple): names of the pretrained encoder and decoder models
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding

        """

        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel

        # Freeze parameters from encoder
        #for p in self.encoder.parameters():
        #    p.requires_grad=False

        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        output_encoder = self.encoder(images)
        out = self.decoder(output_encoder, targets, decoder_ids)

        return out

# Model loading and setting up the device
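# Note: the checkpoint below stores a full pickled EncoderDecoder module, so the class
# definitions above must match the ones used when it was saved; on recent PyTorch
# versions torch.load may additionally need weights_only=False for pickled modules.
# A model of this shape could (hypothetically) have been built as:
#   model = EncoderDecoder(("google/vit-base-patch16-224-in21k", "t5-base"))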

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)

# Tokenizer and Feature Extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

# Define the image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

def preprocess_image(image):
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image)
    return image.unsqueeze(0)

def generate_caption(image):
    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)

        # Greedy decoding: repeatedly append the most likely next token, up to 50 tokens
        for _ in range(50):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = next_token_logits.argmax(1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)

            # Stop once the end-of-sequence token has been generated
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break

        caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
    return caption

sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg"
]

# Define Gradio interface
interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy", label="Upload an image or take a photo"),
    outputs='text',
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
)

# Run the interface
interface.launch(debug=True)
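
# Example (hypothetical, for local testing without the UI): generate_caption can also be
# called directly on a NumPy RGB image array, e.g.
#   import numpy as np
#   print(generate_caption(np.array(Image.open("sample_image1.jpg").convert("RGB"))))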