import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration, InstructBlipForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

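# Fetch example images for the demo: COCO cats, a TextCaps stop sign, and a DALL-E 2 astronaut photo.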
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

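# GIT-large (fine-tuned on COCO) and BLIP-large are loaded in full precision on `device`; the larger
# BLIP-2 (OPT 6.7b) and InstructBLIP (Vicuna 7b) checkpoints are loaded in 4-bit (bitsandbytes) with fp16 compute.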
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco").to(device)

blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b-coco")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b-coco", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)

instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)

def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
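    """Caption `image` with a GIT- or BLIP-style model: preprocess, beam-search generate, decode the token ids."""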
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)
    
    generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5) 

    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
   
    return generated_caption


def generate_caption_blip2(processor, model, image, replace_token=False):    
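    """Caption `image` with BLIP-2 or InstructBLIP, prompting with 'A photo of'."""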
    prompt = "A photo of"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=model.device, dtype=torch.float16)

    # Beam search decoding; sampling-only arguments (top_p, temperature) are omitted because
    # do_sample defaults to False, so they would have no effect.
    generated_ids = model.generate(**inputs,
                                   num_beams=5, max_length=50, min_length=1,
                                   repetition_penalty=1.5, length_penalty=1.0)
    if replace_token:
        # TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
        generated_ids[generated_ids == 0] = 2 
    
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_captions(image):
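    """Run all four captioning models on `image` and return their captions in a fixed order."""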
    caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)

    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)

    caption_blip2 = generate_caption_blip2(blip2_processor, blip2_model, image).strip()

    caption_instructblip = generate_caption_blip2(instructblip_processor, instructblip_model, image, replace_token=True)

    return caption_git_large_coco, caption_blip_large, caption_blip2, caption_instructblip

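# Optional local sanity check (a sketch, not part of the demo; assumes the example images downloaded
# above are present). Uncomment to print all four captions for cats.jpg without launching the UI:
#
# from PIL import Image
# for label, caption in zip(["GIT-large (COCO)", "BLIP-large", "BLIP-2", "InstructBLIP"],
#                           generate_captions(Image.open("cats.jpg"))):
#     print(f"{label}: {caption}")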
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
    gr.Textbox(label="Caption generated by InstructBLIP"),
]

title = "Interactive demo: comparing image captioning models"
description = "Gradio Demo to compare GIT, BLIP, BLIP-2 and InstructBLIP, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_captions,
                         inputs=gr.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue().launch(debug=True)