File size: 1,333 Bytes
4034b15
6437a53
 
4034b15
18c5ad2
 
4034b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fb4187
6437a53
 
 
7dbd807
4034b15
 
 
 
 
6437a53
4034b15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer


def create_caption_transformer(img, *, max_length=15):
    """
    Create a caption for an image using a vision-encoder/text-decoder model
    that was trained on the 'Flickr image dataset'.

    Relies on the module-level ``feature_extractor``, ``model`` and
    ``tokenizer`` having been loaded before it is called.

    :param img: the input image as handed over by the Gradio "image"
        component (a numpy array by default), or None when no image was given
    :param max_length: maximum length passed to ``model.generate``
        (default 15, the value that used to be hard-coded)
    :return: the caption text up to its first period, with surrounding
        whitespace removed; an empty string when ``img`` is None
    """
    if img is None:
        # Gradio can submit None when the user sends no image; the feature
        # extractor would raise on it, so answer with an empty caption.
        return ""

    # Put the pixel tensor on the same device as the model (CPU in this app).
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to(model.device)
    # generate() returns a batch of sequences; a single image was passed in.
    caption_ids = model.generate(pixel_values, max_length=max_length)[0]
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # Keep only the first sentence; strip the space that may precede the period.
    return caption_text.split('.')[0].strip()


# --- App setup: example images, model loading, and the Gradio UI ---

IMAGES_EXAMPLES_FOLDER = 'examples/'
# os.listdir returns entries in arbitrary order, so sort for a stable example
# gallery. Skip hidden files (e.g. .DS_Store) and subdirectories, which are
# not usable example images.
images = sorted(
    name for name in os.listdir(IMAGES_EXAMPLES_FOLDER)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(IMAGES_EXAMPLES_FOLDER, name))
)
IMAGES_EXAMPLES = [os.path.join(IMAGES_EXAMPLES_FOLDER, img) for img in images]

# The fine-tuned captioning weights are loaded from the current working
# directory; the encoder's feature extractor and the decoder's tokenizer come
# from the public 'google/vit-base-patch16-224-in21k' and 'gpt2' checkpoints.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Keep a handle on the Interface object itself (previously `iface` held the
# return value of launch()), then start the server with a public share link.
iface = gr.Interface(fn=create_caption_transformer,
                     inputs="image",
                     outputs='text',
                     examples=IMAGES_EXAMPLES,
                     )
iface.launch(share=True)