import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration, InstructBlipForConditionalGeneration
import torch
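# Run on a GPU when one is available; the single-device models below are moved there explicitly.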
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
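# GIT-large (fine-tuned on COCO) and BLIP-large are loaded in full precision on a single device.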
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco").to(device)
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
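# BLIP-2 (OPT 6.7B) and InstructBLIP (Vicuna 7B) are much larger, so they are loaded in 4-bit with automatic device placement.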
blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b-coco")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b-coco", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)
instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    # Preprocess the image and move the tensors to the target device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)

    # Beam search with a short length budget keeps captions concise.
    generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)

    # Some checkpoints ship a separate tokenizer; otherwise the processor handles decoding.
    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_caption
def generate_caption_blip2(processor, model, image, replace_token=False):
    # BLIP-2 and InstructBLIP are prompted with a short prefix that the model completes.
    prompt = "A photo of"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=model.device, dtype=torch.float16)

    generated_ids = model.generate(**inputs,
                                   num_beams=5, max_length=50, min_length=1, top_p=0.9,
                                   repetition_penalty=1.5, length_penalty=1.0, temperature=1)

    if replace_token:
        # TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
        generated_ids[generated_ids == 0] = 2

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def generate_captions(image):
    # Caption the same image with all four models and return the results in the order of the output widgets.
    caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)
    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
    caption_blip2 = generate_caption_blip2(blip2_processor, blip2_model, image).strip()
    caption_instructblip = generate_caption_blip2(instructblip_processor, instructblip_model, image, replace_token=True)

    return caption_git_large_coco, caption_blip_large, caption_blip2, caption_instructblip
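# Gradio UI: one image input, an example gallery, and one output textbox per model.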
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
    gr.Textbox(label="Caption generated by InstructBLIP"),
]
title = "Interactive demo: comparing image captioning models"
description = "Gradio Demo to compare GIT, BLIP, BLIP-2 and InstructBLIP, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
interface = gr.Interface(fn=generate_captions,
                         inputs=gr.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue()
interface.launch(debug=True)