import traceback

import gradio as gr
import torch
import open_clip
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
# Load the BLIP base captioning model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load the BLIP large captioning model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# Load the GIT base model fine-tuned on COCO
preprocessor_git_base_coco = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model_git_base_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Load the GIT large model fine-tuned on COCO
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
# Load the open_clip CoCa model (ViT-L/14 fine-tuned on MSCOCO)
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Move all models to the selected device
model_blip_base.to(device)
model_blip_large.to(device)
model_git_base_coco.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)
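
# Optional: switch the models to inference mode. This only disables dropout and
# other training-only behaviour; it does not change which checkpoints are used.
for _captioning_model in (
    model_blip_base,
    model_blip_large,
    model_git_base_coco,
    model_git_large_coco,
    model_oc_coca,
):
    _captioning_model.eval()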

def generate_caption(
    preprocessor,
    model,
    image,
    tokenizer=None,
    max_length=50,
    temperature=1.0,
):
    """
    Generate a caption for the given image.

    Parameters
    ----------
    preprocessor: AutoProcessor
        The preprocessor for the model.
    model: BlipForConditionalGeneration or AutoModelForCausalLM
        The captioning model to use.
    image: PIL.Image
        The image to generate a caption for.
    tokenizer: AutoTokenizer, optional
        The tokenizer used to decode the output. If None, the preprocessor's
        tokenizer is used.
    max_length: int
        The maximum length of the generated caption.
    temperature: float
        Sampling temperature. Values other than 1.0 enable sampling.

    Returns
    -------
    str
        The generated caption.
    """
    pixel_values = preprocessor(images=image, return_tensors="pt").pixel_values.to(device)
    generation_kwargs = {"max_length": max_length}
    if temperature is not None and temperature != 1.0:
        # Temperature only has an effect when sampling is enabled.
        generation_kwargs.update(do_sample=True, temperature=float(temperature))
    generated_ids = model.generate(pixel_values=pixel_values, **generation_kwargs)
    if tokenizer is None:
        generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
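
# Minimal usage sketch (kept commented out so nothing runs at Space startup);
# "example.jpg" is a placeholder path, not a file shipped with this Space.
# img = Image.open("example.jpg").convert("RGB")
# print(generate_caption(preprocessor_blip_base, model_blip_base, img))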

def generate_captions_clip(
    model,
    transform,
    image,
):
    """
    Generate a caption for the given image using the open_clip CoCa model.

    Parameters
    ----------
    model: torch.nn.Module
        The open_clip CoCa model to use.
    transform: Callable
        The transform to apply to the image before passing it to the model.
    image: PIL.Image
        The image to generate a caption for.

    Returns
    -------
    str
        The generated caption.
    """
    im = transform(image).unsqueeze(0).to(device)
    # Use mixed precision only when running on GPU; the context is disabled on CPU.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == "cuda")):
        generated = model.generate(im, seq_len=20)
    generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
    return generated_caption
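
# The same sketch for the CoCa captioner (commented out; "img" refers to the
# placeholder image from the example above, which is not part of this Space).
# print(generate_captions_clip(model_oc_coca, transform_oc_coca, img))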

def generate_captions(
    image,
    max_length,
    temperature,
):
    """
    Generate captions for the given image with every loaded model.

    Parameters
    ----------
    image: PIL.Image
        The image to generate captions for.
    max_length: int
        The maximum length of the generated captions.
    temperature: float
        Sampling temperature passed to the transformers captioning models.

    Returns
    -------
    tuple of str
        The captions from the BLIP base, BLIP large, GIT base COCO,
        GIT large COCO, and CoCa models.
    """
    caption_blip_base = ""
    caption_blip_large = ""
    caption_git_base_coco = ""
    caption_git_large_coco = ""
    caption_oc_coca = ""

    # Generate a caption with the BLIP base model
    try:
        caption_blip_base = generate_caption(
            preprocessor_blip_base, model_blip_base, image, max_length=max_length, temperature=temperature
        ).strip()
    except Exception:
        traceback.print_exc()

    # Generate a caption with the BLIP large model
    try:
        caption_blip_large = generate_caption(
            preprocessor_blip_large, model_blip_large, image, max_length=max_length, temperature=temperature
        ).strip()
    except Exception:
        traceback.print_exc()

    # Generate a caption with the GIT base COCO model
    try:
        caption_git_base_coco = generate_caption(
            preprocessor_git_base_coco, model_git_base_coco, image, max_length=max_length, temperature=temperature
        ).strip()
    except Exception:
        traceback.print_exc()

    # Generate a caption with the GIT large COCO model
    try:
        caption_git_large_coco = generate_caption(
            preprocessor_git_large_coco, model_git_large_coco, image, max_length=max_length, temperature=temperature
        ).strip()
    except Exception:
        traceback.print_exc()

    # Generate a caption with the open_clip CoCa model
    try:
        caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image).strip()
    except Exception:
        traceback.print_exc()

    return caption_blip_base, caption_blip_large, caption_git_base_coco, caption_git_large_coco, caption_oc_coca

# Create the interface
iface = gr.Interface(
    fn=generate_captions,
    # Inputs: image, max length slider, temperature slider
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
        gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
    ],
    # Outputs: one textbox per model
    outputs=[
        gr.Textbox(label="BLIP base"),
        gr.Textbox(label="BLIP large"),
        gr.Textbox(label="GIT base COCO"),
        gr.Textbox(label="GIT large COCO"),
        gr.Textbox(label="CoCa"),
    ],
    title="Image Captioning",
    description="Generate captions for images with the BLIP base, BLIP large, GIT base COCO, GIT large COCO, and open_clip CoCa models.",
)

# Launch the interface with request queuing enabled
iface.queue()
iface.launch(debug=True)