"""Gradio demo that generates image grids from text prompts with min-dalle.

Importing this module loads the ``MinDalle`` model, builds the Gradio
``Blocks`` UI and launches the web server.
"""

import string
import subprocess

import gradio
import torch
import torch.backends.cuda
import torch.backends.cudnn
from emoji import demojize
from min_dalle import MinDalle
from PIL import Image


# Characters allowed to survive into a generated filename.
_FILENAME_CHARS = frozenset(string.ascii_lowercase + ' ')

# Help text rendered under the settings. Built line by line so that the
# Markdown never depends on how this source file is indented.
_PARAMETER_HELP = '\n'.join([
    '#### Parameter',
    '- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.',
    '- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.',
    '- **Seamless**: Tile images in image token space instead of pixel space.',
    '- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.',
    '- **Top-k**: Each image token is sampled from the top-k scoring tokens.',
    '- **Super Condition**: Higher values can result in better agreement with the text.',
    '',
    '#### ',
])


def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a short, filesystem-friendly name.

    Emoji are replaced by their names, the text is lower-cased, everything
    except ASCII letters and spaces is dropped, the result is cut to 64
    characters and whitespace runs are collapsed into single hyphens.

    Args:
        text: The prompt entered by the user.

    Returns:
        The hyphenated name, or ``'blank'`` when no usable character is left.

    Example:
        ``filename_from_text('Rusty Iron Man!')`` gives ``'rusty-iron-man'``.
    """
    text = demojize(text, delimiters=['', ''])
    text = text.lower().encode('ascii', errors='ignore').decode()
    text = ''.join(char for char in text if char in _FILENAME_CHARS)
    # Truncate before hyphenating, so the 64-character limit applies to the
    # raw letters-and-spaces text.
    text = '-'.join(text[:64].split())
    return text or 'blank'


def log_gpu_memory() -> None:
    """Print the output of ``nvidia-smi`` as a best-effort diagnostic.

    The model below is created with ``device='cpu'``, so the host may have no
    NVIDIA tooling at all. A missing or failing ``nvidia-smi`` is reported
    and otherwise ignored, so that this logging can never fail a request
    whose image has already been generated.
    """
    try:
        report = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        print('nvidia-smi unavailable, skipping GPU memory log:', err)
        return
    print(report.decode('utf-8', errors='replace'))


model = MinDalle(
    is_mega=True,
    is_reusable=True,
    device='cpu',
    # dtype=torch.float32
)


def _parse_int_in_range(value, low: int, high: int, message: str) -> int:
    """Convert *value* to ``int`` and require ``low <= value <= high``."""
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err
    if not low <= number <= high:
        raise ValueError(message)
    return number


def _parse_float(value, message: str) -> float:
    """Convert *value* to ``float``, raising ``ValueError(message)`` on failure."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for *text* and save it to disk.

    Args:
        text: Prompt to illustrate.
        grid_size: Side length of the image grid, from 1 to 5.
        is_seamless: Passed to the model as ``is_seamless``.
        save_as_png: Save as PNG when true, otherwise as JPG.
        temperature: Sampling temperature; must be greater than ``1e-6``.
        supercondition: Super-condition factor, as a number or numeric string.
        top_k: Top-k sampling cutoff, from 1 to 16384 (number or numeric
            string).

    Returns:
        Path of the saved image, relative to the current working directory.

    Raises:
        ValueError: If temperature, grid size, top-k or the super-condition
            factor cannot be parsed or is out of range.
    """
    # NOTE(review): grad mode is per-thread in PyTorch, so this presumably has
    # to stay inside the handler rather than move to import time - confirm
    # before hoisting any of these settings.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Explicit checks instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable the validation.
    temperature_message = 'Temperature must be a positive nonzero number'
    temperature = _parse_float(temperature, temperature_message)
    # Written as ``not (x > limit)`` so that NaN is rejected as well.
    if not temperature > 1e-6:
        raise ValueError(temperature_message)

    grid_size = _parse_int_in_range(
        grid_size, 1, 5, 'Grid size must be between 1 and 5'
    )
    top_k = _parse_int_in_range(
        top_k, 1, 16384, 'Top k must be between 1 and 16384'
    )
    supercondition_factor = _parse_float(
        supercondition, 'Super condition must be a number'
    )

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    filename = filename_from_text(text)
    image_path = '{}.{}'.format(filename, ext)
    image.save(image_path)
    return image_path


# NOTE(review): the layout nesting below was reconstructed from the order of
# the statements - confirm it against the rendered page.
demo = gradio.Blocks(analytics_enabled=True)

with demo:
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Rusty Iron Man suit found abandoned in the woods being reclaimed by nature',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            output_image = gradio.Image(
                value='examples/dog.jpg',
                label='Output Image',
                type='file',
                interactive=False
            )

        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=3,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )
            gradio.Markdown(_PARAMETER_HELP)

    # Each example row supplies a prompt, a grid size and a pre-rendered image.
    gradio.Examples(
        examples=[
            ['A white cat with golden sunglasses on, pink background, studio lighting, 4k, award winning photography', 2, 'examples/cat.png'],
            ['an astronaut dancing on the moon’s surface, close-up photo', 2, 'examples/astronaut.png'],
            ['A photo of a Samoyed dog with its tongue out hugging a white Siamese cat', 5, 'examples/dog.png'],
            ['Dragons of Earth, Wind, Fire, powering up a huge sphere of compressed energy, digital art', 2, 'examples/dragon.png'],
            ['A snowboarder jumping in the air while coming down a ski mountain, concept art, artstation, unreal engine, 3d render, HD, Bokeh', 3, 'examples/snow.png'],
            ['Portrait of a basset hound, 8k, photograph', 3, 'examples/8kdog.png'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )

    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )

demo.launch()