from io import BytesIO

import gradio as gr
import requests
from PIL import Image


def encode_image(image):
    # Serialize the PIL image to an in-memory JPEG buffer for the multipart upload.
    # Convert to RGB first so images with an alpha channel (e.g. RGBA PNGs) can be saved as JPEG.
    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")
    buffered.seek(0)

    return buffered


def query_api(image, prompt, decoding_method):
    # Remote inference endpoint backing this demo.
    url = "http://34.132.142.70:5000/api/generate"

    headers = {
        'User-Agent': 'BLIP-2 HuggingFace Space'
    }

    data = {"prompt": prompt, "use_nucleus_sampling": decoding_method == "Nucleus sampling"}

    image = encode_image(image)
    files = {"image": image}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        # Return a one-element list so the caller's output[0] indexing also works on errors.
        return ["Error: " + response.text]


def prepend_question(text):
    text = text.strip().lower()

    return "question: " + text
    

def prepend_answer(text):
    text = text.strip().lower()

    return "answer: " + text


def get_prompt_from_history(history):
    prompts = []

    for i in range(len(history)):
        if i % 2 == 0:
            prompts.append(prepend_question(history[i]))
        else:
            prompts.append(prepend_answer(history[i]))
    
    return "\n".join(prompts)


def postp_answer(text):
    # Strip a leading "answer: " or "a: " prefix from the model output.
    if text.startswith("answer: "):
        return text[len("answer: "):]
    elif text.startswith("a: "):
        return text[len("a: "):]
    else:
        return text


def prep_question(text):
    # Strip a leading "question: " or "q: " prefix and make sure the question ends with "?".
    if text.startswith("question: "):
        text = text[len("question: "):]
    elif text.startswith("q: "):
        text = text[len("q: "):]

    if not text.endswith("?"):
        text += "?"

    return text
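
# For example:
#     prep_question("q: what is this")  -> "what is this?"
#     postp_answer("answer: a dog")     -> "a dog"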


def inference(image, text_input, decoding_method, history=[]):
    # Guard against a None state on the first turn and avoid mutating the shared default list.
    history = history or []

    text_input = prep_question(text_input)
    history.append(text_input)

    prompt = get_prompt_from_history(history)

    output = query_api(image, prompt, decoding_method)
    output = [postp_answer(output[0])]
    history += output

    # Pair alternating (question, answer) turns for the chatbot display.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    return chat, history
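
# After two turns, the state and the chatbot payload look like:
#     history = ["what is this?", "a dog", "what color is it?", "brown"]
#     chat    = [("what is this?", "a dog"), ("what color is it?", "brown")]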


inputs = [gr.inputs.Image(type='pil'),
          gr.inputs.Textbox(lines=2, label="Text input"),
          gr.inputs.Radio(choices=['Nucleus sampling', 'Beam search'], type="value", default="Nucleus sampling", label="Text Decoding Method"),
          "state",
         ]

outputs = ["chatbot", "state"]
           
title = "BLIP-2"
description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image and ask a question. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p><strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data, including but not limited to text and images, is collected.</p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
iface.launch(enable_queue=True)