# SmallCapDemo / app.py
# Uploaded by RitaParadaRamos ("Upload app.py", commit 79b66b6, 3.88 kB).
import requests
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig
from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
from src.utils import prep_strings, postprocess_preds
import json
from src.retrieve_caps import *
from PIL import Image
from torchvision import transforms
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32 image preprocessor that produces the captioning model's pixel inputs.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# OPT tokenizer. '!' is used as the padding token and '.' as the end-of-sequence
# token; the latter is passed as eos_token_id to generate(), so decoding stops at
# the first full stop.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'
# Load the SmallCap captioning model with the OPT decoder.
# The custom config/model classes are registered with the transformers Auto*
# factories so that from_pretrained can resolve the custom model types
# ("this_opt", "smallcap"). These are presumably the types named in the
# checkpoint's config.
# For the GPT-2 decoder variant ("Yova/SmallCap7M"), register ThisGPT2Config /
# ThisGPT2LMHeadModel under "this_gpt2" in place of the OPT classes below.
AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)
model = AutoModel.from_pretrained("Yova/SmallCapOPT7M").to(device)
# Prompt template used by prep_strings. A trailing space is appended,
# presumably so that whatever prep_strings places after the template is
# space-separated from it.
# A context manager closes the file instead of leaking the handle.
with open('src/template.txt', encoding='utf-8') as template_file:
    template = template_file.read().strip() + ' '
# Retrieval datastore: the caption texts plus a faiss index over them.
# classify_image maps the index's search results back into `captions`.
# The file is opened with a context manager so the handle is closed after loading.
with open('coco_index_captions.json', encoding='utf-8') as captions_file:
    captions = json.load(captions_file)

# CLIP RN50x64 image encoder and its preprocessing transform, used to embed the
# query image for nearest-neighbour caption retrieval.
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')

# The index is searched on the CPU. To move it to GPU 0 instead (requires faiss-gpu):
#   res = faiss.StandardGpuResources()
#   retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
# Human-readable ImageNet labels, apparently left over from a Gradio
# image-classification template; `labels` appears unused in this script.
# The download is best-effort and bounded by a timeout, so an unreachable URL can
# neither hang nor abort app startup. On failure `labels` is an empty list.
try:
    response = requests.get("https://git.io/JJkYN", timeout=10)
    labels = response.text.split("\n")
except requests.RequestException:
    labels = []
def retrieve_caps(image_embedding, index, k=4):
    """Return the ids of the ``k`` nearest index entries for each query row.

    ``image_embedding`` is cast to a float32 copy (``astype`` copies, so the
    caller's array is not modified). That copy is L2-normalised in place before
    ``index.search`` is called. Only the id matrix from the search is returned;
    the distances are discarded.
    NOTE(review): normalising suggests an inner-product (cosine) index; confirm.
    """
    queries = image_embedding.astype(np.float32)
    faiss.normalize_L2(queries)
    _distances, neighbour_ids = index.search(queries, k)
    return neighbour_ids
# Number of captions retrieved per image. The same value drives retrieval, prompt
# construction and the displayed list, so they cannot drift apart.
_NUM_RETRIEVED_CAPS = 4


def _retrieve_captions(image, k):
    """Return the texts of the ``k`` datastore captions nearest to ``image``.

    The image is embedded with the CLIP retrieval model and searched in
    ``retrieval_index``. The resulting ids index into ``captions``.
    """
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension expected by encode_image.
        image_embedding = retrieval_model.encode_image(
            pixel_values_retrieval.unsqueeze(0)
        ).cpu().numpy()
    # Row 0 holds the neighbour ids of the single query image.
    nns = retrieve_caps(image_embedding, retrieval_index, k=k)[0]
    return [captions[i] for i in nns][:k]


def _generate_caption(image, retrieved_caps, k):
    """Generate a caption for ``image`` conditioned on ``retrieved_caps``.

    The retrieved captions are placed into the decoder prompt via
    ``prep_strings``. Beam search (3 beams, at most 25 new tokens) stops at the
    tokenizer's EOS token. The decoded text is returned after ``postprocess_preds``.
    """
    decoder_input_ids = prep_strings('', tokenizer, template=template,
                                     retrieved_caps=retrieved_caps, k=k, is_test=True)
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
    return str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer))


def classify_image(image):
    """Caption ``image`` with SmallCap and list the captions it retrieved.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when the user
            submits without uploading one.

    Returns:
        The generated caption, a blank gap, then "Retrieved captions:" followed
        by one retrieved caption per line. If no image was given, a short
        prompt asking the user to upload one is returned instead.
    """
    # Gradio passes None when the form is submitted without an image.
    if image is None:
        return "Please upload an image."
    caps = _retrieve_captions(image, _NUM_RETRIEVED_CAPS)
    generated = _generate_caption(image, caps, _NUM_RETRIEVED_CAPS)
    # Joining (rather than a fixed-arity format string) works for any caption count.
    retrieved_caps = "Retrieved captions: \n" + "\n".join(caps)
    return generated + "\n\n\n" + retrieved_caps
# Gradio UI: a PIL image goes in, and a multi-line text box shows the generated
# caption followed by the retrieved captions.
image = gr.Image(type="pil")
textbox = gr.Textbox(
    placeholder="Generated caption and retrieved captions...",
    lines=4,
)
title = "SmallCap Demo"

demo = gr.Interface(fn=classify_image, inputs=image, outputs=textbox, title=title)
demo.launch()