import json

import clip
import faiss
import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig

from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
from src.utils import prep_strings, postprocess_preds

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP feature extractor that produces the encoder inputs.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# Load and configure the tokenizer for the OPT decoder.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'

# Register and load the model.
# GPT-2 decoder variant (uncomment to use the SmallCap7M checkpoint instead):
# AutoConfig.register("this_gpt2", ThisGPT2Config)
# AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoConfig.register("smallcap", SmallCapConfig)
# AutoModel.register(SmallCapConfig, SmallCap)
# model = AutoModel.from_pretrained("Yova/SmallCap7M")

# OPT decoder variant:
AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)

model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
model = model.to(device)

# Prompt template used to format the retrieved captions.
template = open('src/template.txt').read().strip() + ' '

# Load the pre-computed caption datastore and its FAISS index, plus the CLIP
# retrieval encoder used to embed query images.
captions = json.load(open('coco_index_captions.json'))
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')
# To run retrieval on GPU instead:
# res = faiss.StandardGpuResources()
# retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
def retrieve_caps(image_embedding, index, k=4):
    """Return the indices of the k nearest captions in the index for an image embedding."""
    xq = image_embedding.astype(np.float32)
    faiss.normalize_L2(xq)
    D, I = index.search(xq, k)
    return I


def generate_caption(image):
    # Embed the input image with the CLIP retrieval encoder and fetch the
    # nearest captions from the datastore.
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
    nns = retrieve_caps(image_embedding, retrieval_index)[0]
    caps = [captions[i] for i in nns][:4]

    # Build the decoder prompt from the template and the retrieved captions.
    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)

    # Generate the caption conditioned on the image features and the prompt.
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25,
                              no_repeat_ngram_size=0,
                              length_penalty=0,
                              min_length=1,
                              num_beams=3,
                              eos_token_id=tokenizer.eos_token_id)

    generated_caption = postprocess_preds(tokenizer.decode(pred[0]), tokenizer)
    retrieved_caps = "Retrieved captions:\n{}\n{}\n{}\n{}".format(*caps)
    return str(generated_caption) + "\n\n\n" + retrieved_caps


image = gr.Image(type="pil")
textbox = gr.Textbox(placeholder="Generated caption and retrieved captions...", lines=4)
title = "SmallCap Demo"

gr.Interface(
    fn=generate_caption,
    inputs=image,
    outputs=textbox,
    title=title,
).launch()