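"""Gradio demo that captions an uploaded image with several models:
BLIP (base and large), GIT (base and large, fine-tuned on COCO), and an
open_clip CoCa model. Each model's caption is shown in its own output box."""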
import gradio as gr
import torch
import open_clip
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
# Load the Blip base model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Load the Blip large model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
# Load the GIT coco base model
preprocessor_git_base_coco = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model_git_base_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
# Load the GIT coco large model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
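# Optional (untested sketch): on a CUDA device the larger checkpoints could be
# loaded in half precision to save memory, e.g.
#   BlipForConditionalGeneration.from_pretrained(
#       "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
#   )
# float16 can speed up inference but may slightly change the generated captions.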
# Load the open_clip CoCa model (a CLIP-style image encoder with a caption decoder)
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Transfer the models to the device
model_blip_base.to(device)
model_blip_large.to(device)
model_git_base_coco.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)
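# BLIP and GIT checkpoints go through the shared generate_caption helper below;
# the CoCa model has its own helper because it uses open_clip's transform and token decoder.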
def generate_caption(
    preprocessor,
    model,
    image,
    tokenizer=None,
    max_length=50,
):
"""
Generate captions for the given image.
-----
Parameters
preprocessor: AutoProcessor
The preprocessor for the model.
model: BlipForConditionalGeneration
The model to use.
image: PIL.Image
The image to generate captions for.
tokenizer: AutoTokenizer
The tokenizer to use. If None, the default tokenizer for the model will be used.
use_float_16: bool
Whether to use float16 precision. This can speed up inference, but may lead to worse results.
-----
Returns
str
The generated caption.
"""
    # Preprocess the image and move the pixel values to the model's device
    pixel_values = preprocessor(images=image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(
        pixel_values=pixel_values,
        max_length=max_length,
    )
    # Decode with the explicit tokenizer if one was given, otherwise fall back
    # to the preprocessor's built-in decoder
    if tokenizer is None:
        generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
def generate_captions_clip(
    model,
    transform,
    image
):
"""
Generate captions for the given image using CLIP.
-----
Parameters
model: VisionEncoderDecoderModel
The CLIP model to use.
transform: Callable
The transform to apply to the image before passing it to the model.
image: PIL.Image
The image to generate captions for.
-----
Returns
str
The generated caption.
"""
    im = transform(image).unsqueeze(0).to(device)
    # Autocast only makes sense on CUDA; keep it disabled on CPU
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=device == "cuda"):
        generated = model.generate(im, seq_len=20)
    generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
    return generated_caption
def generate_captions(
    image,
    max_length,
    temperature,
):
    """
    Generate captions for the given image with every loaded model.

    Parameters
    ----------
    image: PIL.Image
        The image to generate captions for.
    max_length: int
        The maximum caption length, passed to the BLIP and GIT models.
    temperature: float
        Sampling temperature from the UI; currently not forwarded to the models.

    Returns
    -------
    tuple of str
        The captions generated by each model.
    """
    caption_blip_base = ""
    caption_blip_large = ""
    caption_git_base_coco = ""
    caption_git_large_coco = ""
    caption_oc_coca = ""
    # Generate a caption with the BLIP base model
    try:
        caption_blip_base = generate_caption(
            preprocessor_blip_base, model_blip_base, image, max_length=max_length
        ).strip()
    except Exception as e:
        print(e)
    # Generate a caption with the BLIP large model
    try:
        caption_blip_large = generate_caption(
            preprocessor_blip_large, model_blip_large, image, max_length=max_length
        ).strip()
    except Exception as e:
        print(e)
    # Generate a caption with the GIT base COCO model
    try:
        caption_git_base_coco = generate_caption(
            preprocessor_git_base_coco, model_git_base_coco, image, max_length=max_length
        ).strip()
    except Exception as e:
        print(e)
    # Generate a caption with the GIT large COCO model
    try:
        caption_git_large_coco = generate_caption(
            preprocessor_git_large_coco, model_git_large_coco, image, max_length=max_length
        ).strip()
    except Exception as e:
        print(e)
    # Generate a caption with the open_clip CoCa model
    try:
        caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
    except Exception as e:
        print(e)
    return caption_blip_base, caption_blip_large, caption_git_base_coco, caption_git_large_coco, caption_oc_coca
# Create the interface
iface = gr.Interface(
    fn=generate_captions,
    # Define the inputs: image, slider for max length, slider for temperature
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
        gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
    ],
    # Define the outputs: one textbox per model
    outputs=[
        gr.Textbox(label="BLIP base"),
        gr.Textbox(label="BLIP large"),
        gr.Textbox(label="GIT base COCO"),
        gr.Textbox(label="GIT large COCO"),
        gr.Textbox(label="CoCa (open_clip)"),
    ],
    title="Image Captioning",
    description="Generate captions for images with the BLIP base, BLIP large, GIT base COCO, GIT large COCO, and open_clip CoCa models.",
)
# Launch the interface with request queuing enabled
iface.queue()
iface.launch(debug=True)