File size: 5,016 Bytes
f7b7142
1f43fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7b7142
 
 
 
1f43fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os, time, copy
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

from PIL import Image

import gradio as gr

import numpy as np
import torch
from transformers import logging
logging.set_verbosity_error()

from fromage import models
from fromage import utils

BASE_WIDTH = 512
MODEL_DIR = './fromage_model/fromage_vis4'


class ChatBotCheese:
    """Gradio chat front-end around a FROMAGe model.

    ``__init__`` downloads the pretrained checkpoint and loads the model;
    ``main`` builds and launches the Blocks UI whose event handlers are this
    instance's bound methods.

    NOTE(review): ``chat_history`` and ``curr_image`` are stored on the
    instance, and every UI event is routed to the same bound methods, so all
    browser sessions served by one ``demo.launch`` share a single
    conversation. Moving them into ``gr.State`` would isolate sessions but
    requires changing the handler signatures.
    """

    def __init__(self):
        """Download the pretrained checkpoint from the HF Hub and load the model."""
        from huggingface_hub import hf_hub_download
        model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
        self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
        # Image to send with the next message only; cleared after each chat turn.
        self.curr_image = None
        # Running "Q: ... \nA: ..." transcript fed to the model every turn.
        self.chat_history = ''

    def add_image(self, state, image_in):
        """Attach an uploaded image to the next message.

        Args:
            state: Chat history as a list of (user, bot) message pairs.
            image_in: Gradio upload object; ``image_in.name`` is the path of
                the uploaded file.

        Returns:
            ``(state, state)``: the new history for both the ``gr.State`` and
            the ``gr.Chatbot`` outputs. The input list is not mutated.
        """
        state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
        # Close the file handle opened by Image.open; convert() returns a
        # separate, fully loaded image, so it stays usable after the close.
        with Image.open(image_in.name) as uploaded:
            self.curr_image = uploaded.convert('RGB')
        return state, state

    def save_im(self, image_pil):
        """Save ``image_pil`` as a PNG in the working directory and return the file name.

        The name is ``<unix-seconds>_<uuid4-hex>.png``. The random uuid keeps
        names unique even when several images are saved within the same second.
        """
        import uuid

        file_name = f"{int(time.time())}_{uuid.uuid4().hex}.png"
        image_pil.save(file_name)
        return file_name

    def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
        """Run one conversation turn through the model and append it to ``state``.

        Args:
            input_text: The user's message.
            state: Chat history as a list of (user, bot) pairs; appended to
                in place.
            ret_scale_factor: Forwarded as ``ret_scale_factor`` (UI label
                "Increased prob of returning images").
            num_ims: Maximum number of images to return (``max_num_rets``).
            num_words: Maximum number of words to generate.
            temp: Sampling temperature.

        Returns:
            ``(state, state)`` for the ``gr.State`` and ``gr.Chatbot`` outputs.
        """
        prompt = self.chat_history + f'Q: {input_text} \nA:'
        inputs = [prompt] if self.curr_image is None else [self.curr_image, prompt]
        # gr.Number may hand back floats (e.g. 3.0); these are integer limits.
        model_outputs = self.model.generate_for_images_and_texts(
            inputs,
            num_words=int(num_words),
            max_num_rets=int(num_ims),
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )
        # Commit the question only after generation succeeded, so a failed
        # call does not leave a dangling "Q: ... A:" in the transcript.
        self.chat_history = prompt + ' '.join(s for s in model_outputs if isinstance(s, str)) + '\n'

        # The outputs mix text (str) with collections of retrieved images --
        # presumably interleaved as [text, images, text, images, ...]; verify
        # against fromage.models. Render every part in model order.
        parts = []
        for output in model_outputs:
            if isinstance(output, str):
                parts.append(output)
            else:
                parts.extend(f'<img src="/file={self.save_im(im)}">' for im in output)
        response = ''.join(parts).replace("[RET]", "")

        state.append((input_text, response))
        self.curr_image = None
        return state, state

    def reset(self):
        """Clear the transcript and any pending image; return empty histories."""
        self.chat_history = ""
        self.curr_image = None
        return [], []

    def main(self):
        """Build the Gradio Blocks UI, wire its events to this instance, and launch it."""
        with gr.Blocks(css="#chatbot .overflow-y-auto{height:1500px}") as demo:
            gr.Markdown(
                """
                ## FROMAGe
                ### Grounding Language Models to Images for Multimodal Generation
                Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried <br/>
                [Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage) <br/>
                - Instructions:
                  - [Optional] Upload an image
                  - [Optional] Change the parameters
                  - Send a message by typing into the box and pressing Enter on your keyboard
                - Check out the examples at the bottom!
                """
            )

            chatbot = gr.Chatbot(elem_id="chatbot")
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    txt = gr.Textbox(show_label=False, placeholder="Upload an image first [Optional]. Then enter text and press enter,").style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    btn = gr.UploadButton("🖼️", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    reset_btn = gr.Button("Reset Messages")
                gr_ret_scale_factor = gr.Number(value=1.0, label="Increased prob of returning images", interactive=True)
                # precision=0 makes gr.Number round and return an int, which is
                # what these count limits are.
                gr_num_ims = gr.Number(value=3, precision=0, label="Max # of Images returned", interactive=True)
                gr_num_words = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
                gr_temp = gr.Number(value=0.0, label="Temperature", interactive=True)

            with gr.Row():
                gr.Image("example_1.png", label="Example 1")
                gr.Image("example_2.png", label="Example 2")
                gr.Image("example_3.png", label="Example 3")

            txt.submit(self.chat, [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp], [gr_state, chatbot])
            # Second listener on the same event: clears the textbox.
            txt.submit(lambda :"", None, txt)
            btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
            reset_btn.click(self.reset, [], [gr_state, chatbot])

            # chatbot.change(fn = upload_button_config, outputs=btn_upload)
            # text_in.submit(None, [], [], _js = "() => document.getElementById('#chatbot-component').scrollTop = document.getElementById('#chatbot-component').scrollHeight")

        demo.launch(share=False, server_name="0.0.0.0")

def main():
    """Construct the FROMAGe chatbot (loading the model) and launch its UI."""
    ChatBotCheese().main()

if __name__ == "__main__":
    # Delegate to main() so the script entry point and main() cannot drift apart.
    main()