import os
import numpy as np
import datetime
import json
from typing import Optional
import transformers
from dataclasses import dataclass, field
import io
import spaces
import base64
from PIL import Image
import gradio as gr
import time
import hashlib
from utils import build_logger
from conversation import conv_seed_llama2
import hydra
import pyrootutils
import torch
import re
from omegaconf import OmegaConf
from flask import Flask
import cv2
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'
IMG_FLAG = '<image>'
num_img_in_tokens = 64
num_img_out_tokens = 64

instruction_prompt = '{instruction}'

resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2',
                    '2x3', '3x2', '2x4', '4x2']
base_resolution = 448
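# In the prompt fed to the LLM, each image occupies a fixed span of `num_img_in_tokens`
# placeholder tokens wrapped in BOI/EOI markers (built in `generate` below), roughly:
#   <img><img_00000><img_00001> ... <img_00063></img>
# The positions between BOI and EOI are flagged via `ids_cmp_mask` so the agent can
# substitute the ViT image embeddings at those slots.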
app = Flask(__name__)
def decode_image(encoded_image: str) -> Image.Image:
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image
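# Round-trip sketch (illustrative only, 'example.png' is a hypothetical file):
#   img = Image.open('example.png')
#   assert decode_image(encode_image(img)).size == img.size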
@dataclass
class Arguments:
    # config.json 1
    image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
                                           metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml',
                               metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default='configs/visual_tokenizer/qwen_vitg_448.yaml',
                                          metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(
        default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
        metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
                                 metadata={"help": "Hugging Face model path of agent model"})
    diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
                                          metadata={"help": "diffusion model path"})
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mixed precision"})


parser = transformers.HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses()
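# The defaults above can be overridden from the command line via HfArgumentParser,
# e.g. (illustrative): python app.py --dtype bf16 --llm_device cuda:0 --vit_sd_device cuda:1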
class LLMService:
    def __init__(self, args) -> None:
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device

        dtype = args.dtype
        if dtype == 'fp16':
            self.dtype = torch.float16
        elif dtype == 'bf16':
            self.dtype = torch.bfloat16
        else:
            raise ValueError

        image_transform_cfg = OmegaConf.load(args.image_transform)
        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)

        visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
        self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm_cfg = OmegaConf.load(args.llm)
        llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
        print('Init llm done.')

        agent_cfg = OmegaConf.load(args.agent)
        self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        self.agent.llm.base_model.model.use_kv_cache_head = False
        print('Init agent model done.')

        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device,
                                                                                     dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(self.vit_sd_device,
                                                                                              dtype=self.dtype)

        sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
        self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(self.vit_sd_device,
                                                                                       dtype=self.dtype)

        # self.sd_adapter.init_pipe(vae=vae,
        #                           scheduler=noise_scheduler,
        #                           visual_encoder=self.visual_encoder.cpu(),
        #                           image_transform=self.image_transform,
        #                           discrete_model=None,
        #                           dtype=self.dtype,
        #                           device="cpu")
        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder,
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device=self.vit_sd_device)
        print('Init sd adapter pipe done.')

        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        # model_id_or_path = "stablediffusionapi/realistic-vision-v51"
        # self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None,
        #                                                                torch_dtype=torch.float16)

        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]


service = LLMService(args)
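# At this point the agent LLM sits on args.llm_device while the ViT encoder and the SDXL
# detokenizer sit on args.vit_sd_device (both default to cuda:0), in fp16 unless
# --dtype bf16 was passed.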
def generate(text_list, image_list, image_embed_list, max_new_tokens):
    with torch.no_grad():
        print('text_list: {}'.format(text_list))
        text_list = text_list.split(IMG_FLAG)
        # Prefix each intermediate segment (the text between two images) with '[INST]'.
        text_list = [text_list[0]] + ["[INST]" + item for item in text_list[1:-1]] + [text_list[-1]]
        top_p = 0.5
        window_size = 8

        assert len(text_list) == len(image_list) + 1

        image_tokens = BOI_TOKEN + ''.join(
            [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN

        input_images = []
        if len(image_list) > 0:
            image_tensor_list = []
            embeds_cmp_mask = []
            embeds_gen_mask = []

            for idx, image_item in enumerate(image_list):
                if isinstance(image_item, str):
                    image = decode_image(image_item)
                    print('after decode image size:', image.size)
                    input_images.append(image)

                    image_tensor = service.image_transform(image)
                    image_tensor_list.append(image_tensor)
                    embeds_cmp_mask.append(True)
                    embeds_gen_mask.append(False)
                else:
                    raise ValueError

            # pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
            #
            # image_embeds = service.visual_encoder(pixel_values)
            # image_embeds = image_embeds.to(service.llm_device)
            print(image_embed_list)
            image_embed_list = [t.squeeze(0) for t in image_embed_list]
            image_embeds = torch.stack(image_embed_list, dim=0)
            image_embeds = image_embeds.to(service.llm_device)

            embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
            embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
        else:
            image_embeds = None
            patch_position = 0
            embeds_cmp_mask = None
            embeds_gen_mask = None

        input_text = image_tokens.join(text_list)
        print('input_text fed to LLM:', input_text)
        input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)

        # Sliding window: drop the oldest image and its leading text segment until at
        # most `window_size` images of story context remain.
        while image_embeds.shape[0] > window_size:
            eoi_prompt_idx = input_text.index(EOI_TOKEN)
            input_text = input_text[eoi_prompt_idx + len(EOI_TOKEN) + len('[INST]'):]
            image_embeds = image_embeds[1:]
            input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)

        if image_embeds is not None:
            embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(service.llm_device, dtype=torch.bool)

        input_ids = [service.tokenizer.bos_token_id] + input_ids

        input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
        ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)

        boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
        eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()

        # Mark the placeholder slots between each BOI/EOI pair so the agent substitutes
        # the image embeddings at those positions.
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        input_ids = input_ids.unsqueeze(0)
        ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
        ids_gen_mask = ids_gen_mask.unsqueeze(0)

        error_msg = []

        print('image_embeds_shape: ' + str(image_embeds.shape))
        print('image_embeds: {}'.format(image_embeds))
        print('input_ids: ' + str(input_ids))
        print('ids_cmp_mask: ' + str(ids_cmp_mask))

        output = service.agent.generate(
            tokenizer=service.tokenizer,
            input_ids=input_ids,
            image_embeds=image_embeds,
            embeds_cmp_mask=embeds_cmp_mask,
            ids_cmp_mask=ids_cmp_mask,
            num_img_gen_tokens=num_img_out_tokens,
            max_new_tokens=max_new_tokens,
            dtype=service.dtype,
            device=service.llm_device,
            top_p=top_p,
        )

        gen_imgs_base64_list = []
        generated_text = output['text']
        torch.cuda.empty_cache()

        if output['has_img_output']:
            # print('loading visual encoder and llm to CPU, and sd to GPU')
            # a = time.time()
            # service.agent = service.agent.cpu()
            # service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
            # print("Loading finished: ", time.time() - a)

            img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)

            for img_idx in range(output['num_gen_imgs']):
                img_feat = img_gen_feat[img_idx:img_idx + 1]
                generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                gen_imgs_base64_list.append(generated_image)

            # a = time.time()
            # service.sd_adapter = service.sd_adapter.cpu()
            # service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
            # service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
            # print("Loading finished: ", time.time() - a)

        print('[func generate input+output]: {}'.format(input_text + generated_text))

        return {'text': generated_text, 'images': gen_imgs_base64_list, 'image_embeds': img_feat.detach().clone(),
                'error_msg': error_msg}
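# Call pattern (see http_bot below; values are illustrative): the prompt text carries one
# '<image>' flag per entry in `image_list`, and `image_embed_list` holds the matching ViT
# embeddings computed in `add_image`.
#   results = generate(text, images, image_embeds, max_new_tokens=768)
#   results['text']          # generated story continuation
#   results['images']        # PIL images produced by the SDXL detokenizer
#   results['image_embeds']  # feature of the last generated image, kept as context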
def http_bot(dialog_state, input_state, max_new_tokens, max_length,
             request: gr.Request):
    print('input_state:', input_state)
    print(dialog_state.messages)

    if len(dialog_state.messages) == 0 or len(
            dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    if len(dialog_state.messages) >= max_length:
        output_state = init_input_state()
        output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
        dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
        input_state = init_input_state()
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)

    prompt = dialog_state.get_prompt()
    text = prompt['text']
    print('text from http_bot: {}'.format(text))
    max_new_tokens = int(max_new_tokens)
    images = prompt['images']
    image_embeds = prompt['image_embeds']

    results = generate(text, images, image_embeds, max_new_tokens)
    generated_text = results['text']

    # Strip the per-image placeholder tokens from the visible text and turn the closing
    # image tag into the '<image>' flag the chatbot renderer expects.
    pattern = r' <img_000\d{2}>'
    generated_text = re.sub(pattern, '', generated_text)
    generated_text = generated_text.replace(' ' + service.tokenizer.eos_token, '') \
        .replace('[INST]', '').replace(' ' + BOI_TOKEN, '').replace(' ' + EOI_TOKEN, IMG_FLAG)
    results['text'] = generated_text
    print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})

    output_state = init_input_state()
    image_dir = get_conv_image_dir()
    output_state['text'] = results['text']
    output_state['image_embeds'].append(results['image_embeds'])

    for image_base64 in results['images']:
        if image_base64 == '':
            image_path = ''
        else:
            if isinstance(image_base64, Image.Image):
                print('generated image is in Image.Image')
                image = image_base64
            else:
                print('generated image is in Image_base64')
                image = decode_image(image_base64)
            image = image.convert('RGB')
            image_path = get_image_name(image=image, image_dir=image_dir)
            if not os.path.exists(image_path):
                image.save(image_path)
        output_state['images'].append(image_path)

    dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
    vote_last_response(dialog_state, 'common', request)

    input_state = init_input_state()
    chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])

    return (dialog_state, input_state, chatbot) + (enable_btn,) * 4
IMG_FLAG = '<image>'
LOGDIR = 'log'
logger = build_logger("gradio_seed_story", LOGDIR)
headers = {"User-Agent": "SEED-Story Client"}

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

conv_seed_llama = conv_seed_llama2
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_conv_image_dir():
    name = os.path.join(LOGDIR, 'images')
    os.makedirs(name, exist_ok=True)
    return name


def get_image_name(image, image_dir=None):
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    image_bytes = buffer.getvalue()
    md5 = hashlib.md5(image_bytes).hexdigest()

    if image_dir is not None:
        image_name = os.path.join(image_dir, md5 + '.png')
    else:
        image_name = md5 + '.png'
    return image_name
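# The MD5 of the PNG bytes doubles as a content-addressed filename, so identical images
# are only written to disk once (see the os.path.exists checks in http_bot and add_image).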
def resize_image_square(image, target_size=448):
    resized_image = image.resize((target_size, target_size))
    return resized_image


def resize_image(image, max_size=512):
    width, height = image.size
    aspect_ratio = float(width) / float(height)

    if width > height:
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    resized_image = image.resize((new_width, new_height))
    return resized_image


def center_crop_image(image, max_aspect_ratio=1.5):
    width, height = image.size
    aspect_ratio = max(width, height) / min(width, height)

    if aspect_ratio >= max_aspect_ratio:
        if width > height:
            new_width = int(height * max_aspect_ratio)
            left = (width - new_width) // 2
            right = (width + new_width) // 2
            top = 0
            bottom = height
        else:
            new_height = int(width * max_aspect_ratio)
            left = 0
            right = width
            top = (height - new_height) // 2
            bottom = (height + new_height) // 2

        cropped_image = image.crop((left, top, right, bottom))
        return cropped_image
    else:
        return image
def vote_last_response(state, vote_type, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        print(state)
        print(state.dict())
        dic = state.dict()
        # image_embeds hold torch tensors, which json.dumps cannot serialize, so drop
        # them before logging.
        for i in range(len(dic['messages'])):
            dic['messages'][i]['message'].pop('image_embeds')
        print(dic)
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": dic,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
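# Each call appends one JSON line to log/YYYY-MM-DD-conv.json, roughly of the form
# (illustrative): {"tstamp": 1720000000.0, "type": "upvote", "state": {...}, "ip": "..."}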
def upvote_last_response(state, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return (disable_btn,) * 2


def downvote_last_response(state, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return (disable_btn,) * 2
def regenerate(dialog_state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
        dialog_state.messages.pop()
    return (
        dialog_state,
        dialog_state.to_gradio_chatbot(),
    ) + (disable_btn,) * 4


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def init_input_state():
    return {'images': [], 'text': '', 'image_embeds': []}


def add_text(dialog_state, input_state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}.")

    if text is None or len(text) == 0:
        return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
    input_state['text'] += text

    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
    print('add_text: ', dialog_state.to_gradio_chatbot())

    return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def is_blank(image):
    image_array = np.array(image)
    unique_colors = np.unique(image_array)
    print('unique_colors', len(unique_colors))
    return len(unique_colors) == 1
def add_image(dialog_state, input_state, image, request: gr.Request):
    logger.info(f"add_image. ip: {request.client.host}.")
    if image is None:
        return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    image = image.convert('RGB')
    print('image size:', image.size)
    # image = center_crop_image(image, max_aspect_ratio=10)

    image_dir = get_conv_image_dir()
    image_path = get_image_name(image=image, image_dir=image_dir)
    if not os.path.exists(image_path):
        image.save(image_path)
    input_state['images'].append(image_path)

    # Pre-compute the ViT embedding once so generate() can reuse it on every turn.
    image_tensor = service.image_transform(image).unsqueeze(0).to(service.llm_device, dtype=service.dtype)
    image_embeds = service.visual_encoder(image_tensor).detach().clone()
    image_embeds = image_embeds.to(service.llm_device)
    input_state['image_embeds'].append(image_embeds)

    input_state['text'] += IMG_FLAG

    if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
        dialog_state.messages[-1]['message'] = input_state
    else:
        dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
    print('add_image:', dialog_state)

    return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4


def update_error_msg(chatbot, error_msg):
    if len(error_msg) > 0:
        info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
            error_msg)
        chatbot[-1][-1] = chatbot[-1][-1] + info
    return chatbot


def load_demo(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    dialog_state = conv_seed_llama.copy()
    input_state = init_input_state()
    return dialog_state, input_state
title = (""" | |
# SEED-Story | |
[[Paper]](https://arxiv.org/abs/2407.08683) [[Code]](https://github.com/TencentARC/SEED-Story) | |
Demo of the multimodal story generation model SEED-Story-George. It is trained on StoryStream-Curious George subset. | |
SEED-Story is a MLLM capable of generating multimodal long stories consisting of rich and coherent narrative texts, along with images that are consistent in characters and style. | |
## Tips: | |
* Check out the conversation examples (at the bottom) for inspiration. | |
* Our demo requires a mix of an image and a starting sentence as input. You can freely upload an image or enter text, and then click on "Submit". Then, The model generates the next story image and text. | |
* You can click on "Continue Generation" to make the model generate a next story image and text based on all previous story boards. | |
* SEED-Story was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable. | |
""") | |
css = """ | |
img { | |
font-family: 'Helvetica'; | |
font-weight: 300; | |
line-height: 2; | |
text-align: center; | |
width: auto; | |
height: auto; | |
display: block; | |
position: relative; | |
} | |
img:before { | |
content: " "; | |
display: block; | |
position: absolute; | |
top: -10px; | |
left: 0; | |
height: auto; | |
width: 100%; | |
background-color: rgb(230, 230, 230); | |
border: 2px dotted rgb(200, 200, 200); | |
border-radius: 5px; | |
} | |
img:after { | |
content: " "; | |
display: block; | |
font-size: 16px; | |
font-style: normal; | |
font-family: FontAwesome; | |
color: rgb(100, 100, 100); | |
position: absolute; | |
top: 5px; | |
left: 0; | |
width: 100%; | |
text-align: center; | |
} | |
""" | |
if __name__ == '__main__':
    examples_mix = [
        ['https://github.com/TencentARC/SEED-Story/blob/master/assets/demo_examples/2.jpg?raw=true',
         'One day, George, the curious brown monkey, decided to explore a new room. He peeked out from behind a dresser, looking both curious and cautious. The dresser had three drawers, each with a round handle. An electrical outlet was visible on the wall.'],
        ['https://github.com/TencentARC/SEED-Story/blob/master/assets/demo_examples/4.jpg?raw=true',
         'In the bustling city, a beautiful blue and yellow bird took flight, soaring high above the buildings. Among the clouds, a heart-shaped formation appeared, as if nature was sending a love note to the world below. Other birds joined, their silhouettes dancing in the distance.'],
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        dialog_state = gr.State()
        input_state = gr.State()

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input_image')
                with gr.Row():
                    text = gr.Textbox(lines=5,
                                      show_label=False,
                                      label='input_text',
                                      elem_id='textbox',
                                      placeholder="Enter text and image, and press submit.", container=False)
                with gr.Row():
                    # add_image_btn = gr.Button("Add Image")
                    # add_text_btn = gr.Button("Add Text")
                    submit_btn = gr.Button("Submit")
                    continue_btn = gr.Button("Continue Generation")

                with gr.Row():
                    max_new_tokens = gr.Slider(minimum=64,
                                               maximum=1024,
                                               value=768,
                                               step=64,
                                               interactive=True,
                                               label="Max Output Tokens")
                    max_length = gr.Slider(minimum=1, maximum=30, value=10, step=1, interactive=True,
                                           label="Max Story Length")

            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-Story", height=700)
                with gr.Row():
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text],
                            cache_examples=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
        upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])

        regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
            http_bot, [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)

        # add_image_btn.click(add_image, [dialog_state, input_state, image],
        #                     [dialog_state, input_state, image, chatbot] + btn_list)
        #
        # add_text_btn.click(add_text, [dialog_state, input_state, text],
        #                    [dialog_state, input_state, text, chatbot] + btn_list)

        submit_btn.click(
            add_text, [dialog_state, input_state, text],
            [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
            add_image, [dialog_state, input_state, image],
            [dialog_state, input_state, image, chatbot] + btn_list).then(
            http_bot,
            [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)

        continue_btn.click(
            http_bot,
            [dialog_state, input_state, max_new_tokens, max_length],
            [dialog_state, input_state, chatbot] + btn_list)

        clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
        demo.load(load_demo, None, [dialog_state, input_state])

    demo.launch(debug=True)
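# Note: demo.launch() serves on Gradio's default port (7860) unless server_name/server_port
# are passed explicitly; on Hugging Face Spaces the platform sets these for you.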