## All Generation Gradio Interface import uuid import time from .utils import * from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger from .constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH with open(TEXT_PROMPT_PATH, 'r') as f: prompt_list = json.load(f) class State: def __init__(self, model_name, i2s_mode=False, offline=False, prompt=None, image=None, offline_idx=None, normal_video=None , rgb_video=None): self.conv_id = uuid.uuid4().hex self.model_name = model_name self.i2s_mode = i2s_mode self.offline = offline self.prompt = prompt self.image = image self.offline_idx = offline_idx # self.output = None self.normal_video = normal_video self.rgb_video = rgb_video def dict(self): base = { "conv_id": self.conv_id, "model_name": self.model_name, "i2s_mode": self.i2s_mode, "offline": self.offline, "prompt": self.prompt } if not self.offline and not self.offline_idx: base['offline_idx'] = self.offline_idx return base # class StateI2S: # def __init__(self, model_name): # self.conv_id = uuid.uuid4().hex # self.model_name = model_name # self.image = None # self.output = None # def dict(self): # base = { # "conv_id": self.conv_id, # "model_name": self.model_name, # } # return base def sample_t2s_model(state_0, state_1, model_list): model_name_0, model_name_1 = random.sample(eval(model_list), 2) if state_0 is None: state_0 = State(model_name_0, i2s_mode=False) if state_1 is None: state_1 = State(model_name_1, i2s_mode=False) state_0.model_name = model_name_0 state_0.i2s_mode = False state_1.model_name = model_name_1 state_1.i2s_mode = False return state_0, state_1, model_name_0, model_name_1 def sample_i2s_model(state_0, state_1, model_list): model_name_0, model_name_1 = random.sample(eval(model_list), 2) if state_0 is None: state_0 = State(model_name_0, i2s_mode=True) if state_1 is None: state_1 = State(model_name_1, i2s_mode=True) state_0.model_name = model_name_0 state_0.i2s_mode = True state_1.model_name = model_name_1 state_1.i2s_mode = True return state_0, state_1, model_name_0, model_name_1 def sample_prompt(state, model_name): if state is None: state = State(model_name) idx = random.randint(0, len(prompt_list)-1) prompt = prompt_list[idx] state.model_name = model_name state.prompt = prompt return state, prompt def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1): if state_0 is None: state_0 = State(model_name_0) if state_1 is None: state_1 = State(model_name_1) idx = random.randint(0, len(prompt_list)-1) prompt = prompt_list[idx] state_0.offline, state_1.offline = True, True state_0.offline_idx, state_1.offline_idx = idx, idx state_0.prompt, state_1.prompt = prompt, prompt return state_0, state_1, prompt def sample_image(state, model_name): if state is None: state = State(model_name) idx = random.randint(0, len(prompt_list)-1) prompt = prompt_list[idx] state.model_name = model_name state.prompt = prompt return state, prompt def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1): if state_0 is None: state_0 = State(model_name_0) if state_1 is None: state_1 = State(model_name_1) idx = random.randint(0, len(prompt_list)-1) prompt = prompt_list[idx] state_0.offline, state_1.offline = True, True state_0.offline_idx, state_1.offline_idx = idx, idx state_0.prompt, state_1.prompt = prompt, prompt return state_0, state_1, prompt def generate_t2s(gen_func, render_func, state, text, model_name, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name: raise gr.Warning("Model name cannot be empty.") if state is None: state = State(model_name, i2s_mode=False, offline=False) ip = get_ip(request) t2s_logger.info(f"generate. ip: {ip}") state.model_name = model_name state.prompt = text try: idx = prompt_list.index(text) state.offline = True state.offline_idx = idx except: state.offline = False state.offline_idx = None if not state.offline and not state.offline_idx: start_time = time.time() normal_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "normal", f"{state.offline_idx}.mp4") rgb_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "rgb", f"{state.offline_idx}.mp4") state.normal_video = normal_video state.rgb_video = rgb_video yield state, normal_video, rgb_video # logger.info(f"===output===: {output}") data = { "ip": ip, "model": model_name, "type": "offline", "gen_params": {}, "state": state.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape = gen_func(text, model_name) generate_time = time.time() - start_time normal_video, rgb_video = render_func(shape, model_name) finish_time = time.time() render_time = finish_time - start_time - generate_time state.normal_video = normal_video state.rgb_video = rgb_video yield state, normal_video, rgb_video # logger.info(f"===output===: {output}") data = { "ip": ip, "model": model_name, "type": "online", "gen_params": {}, "state": state.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' # os.makedirs(os.path.dirname(output_file), exist_ok=True) # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(output_file) def generate_t2s_multi(gen_func, render_func, state_0, state_1, text, model_name_0, model_name_1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name_0: raise gr.Warning("Model name A cannot be empty.") if not model_name_1: raise gr.Warning("Model name B cannot be empty.") if state_0 is None: state_0 = State(model_name_0, i2s_mode=False, offline=False) if state_1 is None: state_1 = State(model_name_1, i2s_mode=False, offline=False) ip = get_ip(request) t2s_multi_logger.info(f"generate. ip: {ip}") state_0.model_name, state_1.model_name = model_name_0, model_name_1 state_0.prompt, state_1.prompt = text, text try: idx = prompt_list.index(text) state_0.offline, state_1.offline = True, True state_0.offline_idx, state_1.offline_idx = idx, idx except: state_0.offline, state_1.offline = False, False state_0.offline_idx, state_1.offline_idx = None, None if not state_0.offline and not state_0.offline_idx: start_time = time.time() normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4") rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4") normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4") rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4") state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1 # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "offline", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "offline", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape_0, shape_1 = gen_func(text, model_name_0, model_name_1) generate_time = time.time() - start_time normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0, shape_1, model_name_1) finish_time = time.time() render_time = finish_time - start_time - generate_time state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1 # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "online", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "online", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data_0) + "\n") fout.write(json.dumps(data_1) + "\n") append_json_item_on_log_server(data_0, get_conv_log_filename()) append_json_item_on_log_server(data_1, get_conv_log_filename()) # for i, state in enumerate([state_0, state_1]): # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' # os.makedirs(os.path.dirname(output_file), exist_ok=True) # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(output_file) def generate_t2s_multi_annoy(gen_func, render_func, state_0, state_1, text, model_name_0, model_name_1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if state_0 is None: state_0 = State(model_name_0, i2s_mode=False, offline=False) if state_1 is None: state_1 = State(model_name_1, i2s_mode=False, offline=False) ip = get_ip(request) t2s_multi_logger.info(f"generate. ip: {ip}") state_0.model_name, state_1.model_name = model_name_0, model_name_1 state_0.prompt, state_1.prompt = text, text try: idx = prompt_list.index(text) state_0.offline, state_1.offline = True, True state_0.offline_idx, state_1.offline_idx = idx, idx except: state_0.offline, state_1.offline = False, False state_0.offline_idx, state_1.offline_idx = None, None if not state_0.offline and not state_0.offline_idx: start_time = time.time() normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4") rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4") normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4") rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4") state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}") # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "offline", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "offline", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape_0, shape_1 = gen_func(text, model_name_0, model_name_1) generate_time = time.time() - start_time normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0, shape_1, model_name_1) finish_time = time.time() render_time = finish_time - start_time - generate_time state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}") # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "online", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "online", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data_0) + "\n") fout.write(json.dumps(data_1) + "\n") append_json_item_on_log_server(data_0, get_conv_log_filename()) append_json_item_on_log_server(data_1, get_conv_log_filename()) # for i, state in enumerate([state_0, state_1]): # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' # os.makedirs(os.path.dirname(output_file), exist_ok=True) # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(output_file) def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request): if not image: raise gr.Warning("Image cannot be empty.") if not model_name: raise gr.Warning("Model name cannot be empty.") if state is None: state = State(model_name, i2s_mode=True, offline=False) ip = get_ip(request) t2s_logger.info(f"generate. ip: {ip}") state.model_name = model_name state.image = image if not state.offline and not state.offline_idx: start_time = time.time() normal_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "normal", f"{state.offline_idx}.mp4") rgb_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "rgb", f"{state.offline_idx}.mp4") state.normal_video = normal_video state.rgb_video = rgb_video yield state, normal_video, rgb_video # logger.info(f"===output===: {output}") data = { "ip": ip, "model": model_name, "type": "offline", "gen_params": {}, "state": state.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape = gen_func(image, model_name) generate_time = time.time() - start_time normal_video, rgb_video = render_func(shape, model_name) finish_time = time.time() render_time = finish_time - start_time - generate_time state.normal_video = normal_video state.rgb_video = rgb_video yield state, normal_video, rgb_video # logger.info(f"===output===: {output}") data = { "ip": ip, "model": model_name, "type": "online", "gen_params": {}, "state": state.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) # with open(src_img_file, 'w') as f: # state.source_image.save(f, 'PNG') # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(src_img_file) # save_image_file_on_log_server(output_file) def generate_i2s_multi(gen_func, render_func, state_0, state_1, image, model_name_0, model_name_1, request: gr.Request): if not image: raise gr.Warning("Image cannot be empty.") if not model_name_0: raise gr.Warning("Model name A cannot be empty.") if not model_name_1: raise gr.Warning("Model name B cannot be empty.") if state_0 is None: state_0 = State(model_name_0, i2s_mode=True, offline=False) if state_1 is None: state_1 = State(model_name_1, i2s_mode=True, offline=False) ip = get_ip(request) t2s_multi_logger.info(f"generate. ip: {ip}") state_0.model_name, state_1.model_name = model_name_0, model_name_1 state_0.image, state_1.image = image, image if not state_0.offline and not state_0.offline_idx and \ not state_1.offline and not state_1.offline_idx: start_time = time.time() normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4") rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4") normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4") rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4") state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}") # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "offline", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "offline", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape_0, shape_1 = gen_func(image, model_name_0, model_name_1) generate_time = time.time() - start_time normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0, shape_1, model_name_1) finish_time = time.time() render_time = finish_time - start_time - generate_time state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1 # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "online", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "online", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data_0) + "\n") fout.write(json.dumps(data_1) + "\n") append_json_item_on_log_server(data_0, get_conv_log_filename()) append_json_item_on_log_server(data_1, get_conv_log_filename()) # for i, state in enumerate([state_0, state_1]): # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) # with open(src_img_file, 'w') as f: # state.source_image.save(f, 'PNG') # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(src_img_file) # save_image_file_on_log_server(output_file) def generate_i2s_multi_annoy(gen_func, state_0, state_1, image, model_name_0, model_name_1, request: gr.Request): if not image: raise gr.Warning("Image cannot be empty.") if state_0 is None: state_0 = State(model_name_0, i2s_mode=True, offline=False) if state_1 is None: state_1 = State(model_name_1, i2s_mode=True, offline=False) ip = get_ip(request) t2s_multi_logger.info(f"generate. ip: {ip}") state_0.model_name, state_1.model_name = model_name_0, model_name_1 state_0.image, state_1.image = image, image if not state_0.offline and not state_0.offline_idx and \ not state_1.offline and not state_1.offline_idx: start_time = time.time() normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4") rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4") normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4") rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4") state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}") # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "offline", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "offline", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), } else: start_time = time.time() shape_0, shape_1 = gen_func(image, model_name_0, model_name_1) generate_time = time.time() - start_time normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0, shape_1, model_name_1) finish_time = time.time() render_time = finish_time - start_time - generate_time state_0.normal_video = normal_video_0 state_0.rgb_video = rgb_video_0 state_1.normal_video = normal_video_1 state_1.rgb_video = rgb_video_1 yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}") # logger.info(f"===output===: {output}") data_0 = { "ip": get_ip(request), "model": model_name_0, "type": "online", "gen_params": {}, "state": state_0.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } data_1 = { "ip": get_ip(request), "model": model_name_1, "type": "online", "gen_params": {}, "state": state_1.dict(), "start": round(start_time, 4), "time": round(finish_time - start_time, 4), "generate_time": round(generate_time, 4), "render_time": round(render_time, 4), } with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps(data_0) + "\n") fout.write(json.dumps(data_1) + "\n") append_json_item_on_log_server(data_0, get_conv_log_filename()) append_json_item_on_log_server(data_1, get_conv_log_filename()) # for i, state in enumerate([state_0, state_1]): # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) # with open(src_img_file, 'w') as f: # state.source_image.save(f, 'PNG') # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' # with open(output_file, 'w') as f: # state.output.save(f, 'PNG') # save_image_file_on_log_server(src_img_file) # save_image_file_on_log_server(output_file)