File size: 5,480 Bytes
f7b7142 1f43fd8 00d336f 6ee4b37 1f43fd8 6ee4b37 1f43fd8 6ee4b37 00d336f 1f43fd8 00d336f 1f43fd8 b836d7e 1f43fd8 abb3b26 1f43fd8 abb3b26 1f43fd8 00d336f abb3b26 b722a02 f7b7142 b722a02 1f43fd8 b722a02 de65f8b 1f43fd8 00d336f 1f43fd8 00d336f 1f43fd8 b836d7e 1f43fd8 00d336f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os, time, copy
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
from PIL import Image
import gradio as gr
import numpy as np
import torch
from transformers import logging
logging.set_verbosity_error()
from fromage import models
from fromage import utils
# NOTE(review): not referenced anywhere else in this file; possibly a leftover — confirm before removing.
BASE_WIDTH: int = 512
# Model directory passed to models.load_fromage() together with the downloaded checkpoint.
# The path is relative, so it is resolved against the current working directory.
MODEL_DIR: str = './fromage_model/fromage_vis4'
class ChatBotCheese:
    """Unofficial Gradio chat demo wrapped around a pretrained FROMAGe model.

    An instance holds the loaded model plus the most recently uploaded image;
    its bound methods are registered as Gradio callbacks in ``main``.
    """

    def __init__(self):
        """Download the pretrained checkpoint from the Hub and load the model."""
        from huggingface_hub import hf_hub_download
        model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
        self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
        # Image sent as visual context with every prompt (None means text only).
        # It persists across turns until reset() is called.
        # NOTE(review): this lives on the instance rather than in a per-session
        # gr.State, so every browser session served by this instance shares it.
        self.curr_image = None

    def add_image(self, state, image_in):
        """Show an uploaded image in the chat and make it the current visual context.

        Args:
            state: list of (user, bot) message pairs displayed by the Chatbot.
            image_in: file object from ``gr.UploadButton``; its ``.name`` is
                used as the path of the uploaded image.

        Returns:
            The extended history twice (for the State and the Chatbot outputs).
        """
        state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
        # Close the file handle once done. convert() loads the pixel data first,
        # so the RGB copy stays valid after the file is closed.
        with Image.open(image_in.name) as uploaded:
            self.curr_image = uploaded.convert('RGB')
        return state, state

    def save_im(self, image_pil):
        """Write a retrieved image to a uniquely named PNG in the working directory.

        The file name combines a timestamp with a uuid4 hex suffix. The uuid
        (rather than a small random number) keeps names unique even when
        several images are saved within the same second.

        NOTE(review): saved files are never deleted by this class.

        Args:
            image_pil: object with a ``save(path)`` method (a PIL image in practice).

        Returns:
            The file name; the chat response embeds it as ``/file=<name>``.
        """
        import uuid

        file_name = f"{int(time.time())}_{uuid.uuid4().hex}.png"
        image_pil.save(file_name)
        return file_name

    def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp, chat_state):
        """Answer one user message with FROMAGe and append the exchange to the history.

        Args:
            input_text: the user's message from the textbox.
            state: list of (user, bot) pairs shown in the Chatbot; mutated in place.
            ret_scale_factor: passed to the model as ``ret_scale_factor``.
            num_ims: passed (as int) to the model as ``max_num_rets``.
            num_words: passed (as int) to the model as ``num_words``.
            temp: passed to the model as ``temperature``.
            chat_state: list of raw "Q: ... A:" / answer strings that are joined
                into the model prompt; mutated in place.

        Returns:
            ``(state, state, chat_state)`` for the State, Chatbot and prompt-State outputs.
        """
        # Blank submissions would otherwise prompt the model with an empty question.
        if not (input_text or "").strip():
            return state, state, chat_state

        chat_state.append(f'Q: {input_text} \nA:')
        chat_history = " ".join(chat_state)
        print(chat_history)

        if self.curr_image is not None:
            model_input = [self.curr_image, chat_history]
        else:
            model_input = [chat_history]

        # gr.Number returns floats unless precision=0, so coerce the counts to int.
        # The int() casts assume the model uses these values as integer counts.
        model_outputs = self.model.generate_for_images_and_texts(
            model_input,
            max_num_rets=int(num_ims),
            num_words=int(num_words),
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )

        # Only the textual parts of the output feed back into the running prompt.
        chat_state.append(' '.join(s for s in model_outputs if isinstance(s, str)) + '\n')

        # The second output, when present and not a string, is a sequence of retrieved images.
        im_names = []
        if len(model_outputs) > 1 and not isinstance(model_outputs[1], str):
            im_names = [self.save_im(im) for im in model_outputs[1]]

        # Guard against an empty output list instead of raising IndexError.
        first = model_outputs[0] if model_outputs else ""
        response = first if isinstance(first, str) else ""
        response += "".join(f'<img src="/file={name}">' for name in im_names)

        state.append((input_text, response.replace("[RET]", "")))
        return state, state, chat_state

    def reset(self):
        """Forget the current image and clear the history, the display and the prompt context."""
        self.curr_image = None
        return [], [], []

    def main(self):
        """Build the Gradio Blocks UI, wire up the callbacks and launch the server."""
        with gr.Blocks(css="#chatbot {height:600px; overflow-y:auto;}") as demo:
            gr.Markdown(
                """
### FROMAGe: Grounding Language Models to Images for Multimodal Generation
Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried <br/>
[Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage) [Official Demo](https://huggingface.co/spaces/jykoh/fromage) <br/>
This is an unofficial Gradio demo for the paper FROMAGe <br/>
- Instructions (in order):
- [Optional] Upload an image (the button with a photo emoji)
- [Optional] Change the parameters
- Send a message by typing into the box and pressing Enter on your keyboard
- Ask about the image! Tell it to find similar images, or ones with different styles.
- Check out the examples at the bottom!
##### Notes
- Please be kind to it!
- It retrieves images from a database, and does not edit images
- If it returns nothing, try resetting and refreshing the page
"""
            )
            chatbot = gr.Chatbot(elem_id="chatbot")
            # Per-session state: displayed (user, bot) pairs and raw prompt history.
            gr_state = gr.State([])
            gr_chat_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Upload an image first [Optional]. Then enter text and press enter.",
                    ).style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    btn = gr.UploadButton("🖼️", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    reset_btn = gr.Button("Reset Messages")
                gr_ret_scale_factor = gr.Number(value=1.0, label="Increased prob of returning images", interactive=True)
                # precision=0 makes Gradio deliver these counts as ints rather than floats.
                gr_num_ims = gr.Number(value=3, precision=0, label="Max # of Images returned", interactive=True)
                gr_num_words = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
                gr_temp = gr.Number(value=0.0, label="Temperature", interactive=True)

            with gr.Row():
                gr.Image("example_1.png", label="Example 1")
                gr.Image("example_2.png", label="Example 2")
                gr.Image("example_3.png", label="Example 3")

            txt.submit(
                self.chat,
                [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp, gr_chat_state],
                [gr_state, chatbot, gr_chat_state],
            )
            # Clear the textbox after each submission.
            txt.submit(lambda: "", None, txt)
            btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
            reset_btn.click(self.reset, [], [gr_state, chatbot, gr_chat_state])

        demo.launch(share=False, server_name="0.0.0.0")
def main():
    """Entry point: construct the demo bot (which loads the model) and start its UI."""
    ChatBotCheese().main()
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()