Yair
add requirements.txt
1fb4187
raw
history blame
1.33 kB
import os
import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
def create_caption_transformer(img, *, max_length=15):
    """
    create_caption_transformer() creates a caption for an image using a transformer model
    that was trained on the 'Flickr image dataset'.

    Relies on the module-level ``feature_extractor``, ``model`` and ``tokenizer``
    objects, which must be loaded before this function is called.

    :param img: a numpy array of the image, or None when no image was provided
    :param max_length: maximum length passed to ``model.generate`` for the caption
    :return: a string of the image caption (text before the first '.', with
        surrounding whitespace stripped); an empty string when ``img`` is None
    """
    # The UI may submit without an image (img is None); the feature extractor
    # cannot process that, so return an empty caption instead of raising.
    if img is None:
        return ""
    # Move the pixel tensor to wherever the model lives rather than a hard-coded device.
    sample = feature_extractor(img, return_tensors="pt").pixel_values.to(model.device)
    caption_ids = model.generate(sample, max_length=max_length)[0]
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # Keep only the first sentence; generation can run on past the first period.
    return caption_text.split('.')[0].strip()
IMAGES_EXAMPLES_FOLDER = 'examples/'
# Only files with these extensions are offered as examples, so stray entries in
# the folder (e.g. hidden files like .DS_Store) are not handed to the UI as images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
# os.listdir returns entries in arbitrary order; sort for a stable example order.
images = sorted(
    name for name in os.listdir(IMAGES_EXAMPLES_FOLDER)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
IMAGES_EXAMPLES = [os.path.join(IMAGES_EXAMPLES_FOLDER, img) for img in images]

# The captioning model is loaded from the current working directory, so the app
# must be started from the directory holding the model files.
# NOTE(review): presumably that is the repository root — confirm for deployment.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
# Preprocessing and decoding assets from the 'google/vit-base-patch16-224-in21k'
# and 'gpt2' checkpoints (presumably matching the model's encoder and decoder).
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Build the interface and launch it as a separate statement, so ``iface`` holds
# the Interface object itself rather than whatever ``launch()`` returns.
iface = gr.Interface(fn=create_caption_transformer,
                     inputs="image",
                     outputs='text',
                     examples=IMAGES_EXAMPLES,
                     )
iface.launch(share=True)