Spaces:
Build error
Build error
import os | |
import torch | |
import gradio as gr | |
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer | |
def create_caption_transformer(img):
    """
    Create a caption for an image using a transformer model that was trained
    on the 'Flickr image dataset'.

    The function relies on the module-level ``feature_extractor``, ``model``
    and ``tokenizer`` objects, which must be initialised before it is called.

    :param img: a numpy array of the image; ``None`` (no image supplied)
        yields an empty caption instead of an error
    :return: a string of the image caption (the generated text up to, but not
        including, the first '.')
    """
    if img is None:
        # Nothing to caption (e.g. the form was submitted without an image);
        # return an empty caption rather than failing inside the extractor.
        return ''

    # Convert the raw image into the pixel tensor the encoder expects (CPU only).
    sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
    # generate() returns a batch of sequences; a single image was passed, so take the first.
    caption_ids = model.generate(sample, max_length=15)[0]  # TODO: take care of the caption length
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # Keep only the text before the first full stop.
    caption_text = caption_text.split('.')[0]
    return caption_text
# Example images offered in the UI: every entry of the examples folder, in a
# deterministic (sorted) order since os.listdir() returns entries arbitrarily.
IMAGES_EXAMPLES_FOLDER = 'examples/'
images = sorted(os.listdir(IMAGES_EXAMPLES_FOLDER))
IMAGES_EXAMPLES = [os.path.join(IMAGES_EXAMPLES_FOLDER, img) for img in images]

# The captioning model is loaded from the current working directory; the
# feature extractor and tokenizer are loaded by their pretrained names.
# All three are module-level because create_caption_transformer() reads them.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Build the interface first so that `iface` refers to the Interface object
# itself (not to the return value of launch()), then start the app.
iface = gr.Interface(
    fn=create_caption_transformer,
    inputs="image",
    outputs='text',
    examples=IMAGES_EXAMPLES,
)
iface.launch(share=True)