import torch
import clip
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

from model import generate2, ClipCaptionModel
from engine import inference

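# ViT+GPT2 encoder-decoder captioner: load the pretrained checkpoint, then
# overlay locally fine-tuned weights (strict=False tolerates key mismatches).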
model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')), strict=False)
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# A dedicated name keeps this tokenizer from being shadowed by the GPT-2
# tokenizer loaded for the CLIP captioning model below.
vit_tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

def show_n_generate(img, model, greedy=True):
    # Caption an image with the ViT+GPT2 model: greedy decoding by default,
    # top-k sampling otherwise.
    image = Image.open(img).convert("RGB")  # ensure 3-channel input for ViT
    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens=30)
    else:
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens=30,
            top_k=5,
        )
    generated_text = vit_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

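# CLIP prefix-captioning setup: CLIP ViT-B/32 encodes the image and a plain
# GPT-2 tokenizer decodes the generated caption.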
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

prefix_length = 10

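# ClipCaptionModel projects a CLIP image embedding into a sequence of
# prefix_length GPT-2 token embeddings that condition the language model.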
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')), strict=False)
model = model.eval()

# A second ClipCaptionModel with COCO-trained weights; loaded here but not
# currently wired into the UI.
coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')), strict=False)
coco_model = coco_model.eval()

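# Streamlit UI: upload an image, choose one of the three captioning models,
# and display the predicted caption.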
def ui():
    st.markdown("# Image Captioning")

    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])

    if uploaded_file is not None:
        # Decode the upload once and preprocess a copy for the CLIP encoder.
        image = io.imread(uploaded_file)
        pil_image = Image.fromarray(image)
        image = preprocess(pil_image).unsqueeze(0).to(device)

        option = st.selectbox('Please select the Model', ('Clip Captioning', 'Attention Decoder', 'VIT+GPT2'))

        if option == 'Clip Captioning':
            with torch.no_grad():
                # Encode the image with CLIP, project the embedding to a
                # GPT-2 prefix, then decode the caption.
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
                generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)
        elif option == 'Attention Decoder':
            out = inference(uploaded_file)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + out)
        elif option == 'VIT+GPT2':
            # This option was offered by the selectbox but never handled;
            # rewind the buffer (io.imread above consumed it) and caption
            # with the fine-tuned ViT+GPT2 model.
            uploaded_file.seek(0)
            out = show_n_generate(uploaded_file, model_trained)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + out)


if __name__ == '__main__':
    ui()