# Yair
# update
# 4034b15
# raw
# history blame
# 1.13 kB
import gradio as gr
import os
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
def create_caption_transformer(img, max_length=15):
    """Create a caption for an image using a vision encoder-decoder transformer.

    The model was trained on the 'Flickr image dataset'. This function relies on
    the module-level ``feature_extractor``, ``model`` and ``tokenizer`` objects,
    which must already be loaded when it is called.

    :param img: a numpy array of the image, or ``None`` when no image was given
    :param max_length: maximum number of tokens ``model.generate`` may produce
        for the caption (defaults to the previously hard-coded 15)
    :return: the caption text up to its first period, with surrounding
        whitespace stripped; an empty string when ``img`` is ``None``
    """
    # The UI may submit without an image (presumably as None); the feature
    # extractor cannot process that, so answer with an empty caption.
    if img is None:
        return ""
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
    caption_ids = model.generate(pixel_values, max_length=max_length)[0]
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # The decoded text can run past the first sentence (e.g. when cut off at
    # max_length); keep only the first sentence and trim stray spaces.
    return caption_text.split('.')[0].strip()
# Captioning checkpoint stored in a ``transformer`` directory next to this
# script. Resolving it from the script's own location (instead of the current
# working directory) lets the app be started from any directory.
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'transformer')
model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR).to('cpu')
# Image preprocessing and text decoding come from the public
# 'google/vit-base-patch16-224-in21k' and 'gpt2' checkpoints; these are assumed
# to match the encoder/decoder the local model was trained with — TODO confirm.
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Web UI: an image input (handed to the captioning function as a numpy array)
# and a text output that shows the generated caption.
iface = gr.Interface(
    fn=create_caption_transformer,
    inputs="image",
    outputs='text',
)
# Launch separately so ``iface`` stays bound to the Interface itself;
# ``launch()`` does not return the Interface, so chaining it hid the UI object.
# share=True asks Gradio for a public share link in addition to the local one.
iface.launch(share=True)