import os
import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
def create_caption_transformer(img):
    """
    Create a caption for an image using a transformer model
    that was trained on the Flickr image dataset.

    :param img: a numpy array of the image
    :return: a string containing the image caption
    """
    sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
    caption_ids = model.generate(sample, max_length=15)[0]  # TODO: take care of the caption length
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    caption_text = caption_text.split('.')[0]  # keep only the first sentence
    return caption_text
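
# A possible refinement of the TODO above (a sketch, not part of the original app):
# cap the caption length with max_new_tokens and use beam search for steadier output.
# The parameter values below are illustrative assumptions, not tuned settings.
def create_caption_transformer_beam(img):
    sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
    caption_ids = model.generate(sample, max_new_tokens=20, num_beams=4)[0]
    return tokenizer.decode(caption_ids, skip_special_tokens=True).split('.')[0]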
print(torch.__version__)  # log the installed torch version (handy when debugging the Space build)
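
# Collect the example images bundled with the app so Gradio can show them in the UI.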
IMAGES_EXAMPLES_FOLDER = 'examples/'
images = os.listdir(IMAGES_EXAMPLES_FOLDER)
IMAGES_EXAMPLES = [IMAGES_EXAMPLES_FOLDER + img for img in images]
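
# Load the fine-tuned captioning checkpoint from the repository root, along with the
# ViT feature extractor and GPT-2 tokenizer used for preprocessing and decoding.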
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
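
# Wire the captioning function into a simple image-to-text Gradio demo and launch it.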
iface = gr.Interface(fn=create_caption_transformer,
                     inputs="image",
                     outputs="text",
                     examples=IMAGES_EXAMPLES)
iface.launch(share=True)
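
# Quick local sanity check (a sketch, assuming Pillow and numpy are available and that
# 'examples/' contains at least one image; run it in a separate session rather than
# after launch(), since launch() blocks):
#
#   import numpy as np
#   from PIL import Image
#   img = np.array(Image.open(IMAGES_EXAMPLES[0]))
#   print(create_caption_transformer(img))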