Spaces:
Runtime error
Runtime error
import gradio | |
import subprocess | |
from PIL import Image | |
import torch, torch.backends.cudnn, torch.backends.cuda | |
from min_dalle import MinDalle | |
from emoji import demojize | |
import string | |
def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a short, filesystem-safe slug.

    Emoji are first spelled out as words (with no delimiters), then the text
    is lowercased and every character other than ASCII lowercase letters and
    spaces is dropped. The result is capped at 64 characters, and runs of
    spaces are collapsed into single hyphens. Returns 'blank' when nothing
    usable is left.
    """
    spelled_out = demojize(text, delimiters=['', ''])
    # Lowercasing before the ASCII round-trip means every surviving ASCII
    # letter is already lowercase.
    ascii_only = spelled_out.lower().encode('ascii', errors='ignore').decode()
    keep = string.ascii_lowercase + ' '
    filtered = ''.join([ch for ch in ascii_only if ch in keep])
    # str.split() with no argument already ignores leading/trailing spaces.
    slug = '-'.join(filtered[:64].split())
    return slug if slug else 'blank'
def log_gpu_memory() -> None:
    """Print the output of `nvidia-smi` as a diagnostic.

    Best-effort: if `nvidia-smi` is not installed (e.g. on a CPU-only host)
    or exits with an error, a short notice is printed instead of raising, so
    a missing diagnostic tool can never fail the caller.
    """
    try:
        report = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers FileNotFoundError when the binary is absent.
        print('nvidia-smi unavailable:', err)
        return
    # Tolerate any non-UTF-8 bytes rather than crashing while logging.
    print(report.decode('utf-8', errors='replace'))
# Construct the shared MinDalle instance once at import time; run_model uses
# this module-level `model` for every request.
# It is the is_mega=True variant placed on the CPU. No `dtype` is passed, so
# MinDalle's own default applies.
model = MinDalle(
    device='cpu',
    is_mega=True,
    is_reusable=True,
)
def _configure_torch_backends() -> None:
    """Apply the global torch inference settings used for every request."""
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True


def _parse_setting(value, convert, message, is_valid=None):
    """Convert a UI value with `convert` and optionally validate the result.

    Raises ValueError(message) when conversion fails or `is_valid` returns a
    falsy value. Explicit checks are used instead of `assert`, which Python
    strips when run with -O.
    """
    try:
        parsed = convert(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(message) from None
    if is_valid is not None and not is_valid(parsed):
        raise ValueError(message)
    return parsed


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image for `text` with the shared model and save it.

    Args:
        text: Prompt text.
        grid_size: Grid side length; must convert to an int in 1..5.
        is_seamless: Forwarded (as bool) to `model.generate_image`.
        save_as_png: Save as PNG when truthy, otherwise as JPEG.
        temperature: Must convert to a float greater than 1e-6.
        supercondition: Must convert to a float; forwarded as
            `supercondition_factor`.
        top_k: Must convert to an int in 1..16384.

    Returns:
        Path of the saved image, '<slug>.png' or '<slug>.jpg', relative to
        the current working directory.

    Raises:
        ValueError: (an Exception subclass) with a user-facing message when a
            setting cannot be converted or is out of range.
    """
    _configure_torch_backends()
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # `not (t > 1e-6)` (rather than `t <= 1e-6`) also rejects NaN.
    temperature = _parse_setting(
        temperature, float,
        'Temperature must be a positive nonzero number',
        lambda t: t > 1e-6,
    )
    grid_size = _parse_setting(
        grid_size, int,
        'Grid size must be between 1 and 5',
        lambda g: 1 <= g <= 5,
    )
    top_k = _parse_setting(
        top_k, int,
        'Top k must be between 1 and 16384',
        lambda k: 1 <= k <= 16384,
    )
    supercondition_factor = _parse_setting(
        supercondition, float,
        'Super condition must be a number',
    )

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    # nvidia-smi only exists on GPU hosts; calling it unconditionally made
    # every request fail on CPU-only machines after generation finished.
    if torch.cuda.is_available():
        log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    image_path = '{}.{}'.format(filename_from_text(text), ext)
    image.save(image_path)
    return image_path
# Build the Gradio interface: prompt and result on the left, generation
# settings on the right, clickable examples below; then start the server.
# NOTE(review): the original indentation of this section was lost, so the
# row/column nesting below is a reconstruction -- confirm the intended layout.
demo = gradio.Blocks(analytics_enabled=True)
with demo:
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Rusty Iron Man suit found abandoned in the woods being reclaimed by nature',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # run_model returns a path to the saved image, so use the
            # documented 'filepath' type; 'file' is a deprecated spelling.
            # NOTE(review): the examples below reference examples/dog.png;
            # confirm that examples/dog.jpg actually exists.
            output_image = gradio.Image(
                value='examples/dog.jpg',
                label='Output Image',
                type='filepath',
                interactive=False
            )
        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=3,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # Powers of two from 1 to 16384 (2**0 .. 2**14).
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                # Powers of two from 4 to 64 (2**2 .. 2**6).
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )
            gradio.Markdown(
"""
#### Parameter
- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
- **Seamless**: Tile images in image token space instead of pixel space.
- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
- **Top-k**: Each image token is sampled from the top-k scoring tokens.
- **Super Condition**: Higher values can result in better agreement with the text.
####
"""
            )
    # Each example fills the prompt, the grid size and a pre-rendered image.
    gradio.Examples(
        examples=[
            ['A white cat with golden sunglasses on, pink background, studio lighting, 4k, award winning photography', 2, 'examples/cat.png'],
            ['an astronaut dancing on the moon’s surface, close-up photo', 2, 'examples/astronaut.png'],
            ['A photo of a Samoyed dog with its tongue out hugging a white Siamese cat', 5, 'examples/dog.png'],
            ['Dragons of Earth, Wind, Fire, powering up a huge sphere of compressed energy, digital art', 2, 'examples/dragon.png'],
            ['A snowboarder jumping in the air while coming down a ski mountain, concept art, artstation, unreal engine, 3d render, HD, Bokeh', 3, 'examples/snow.png'],
            ['Portrait of a basset hound, 8k, photograph', 3, 'examples/8kdog.png'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )
    # Input order must match run_model's positional parameters.
    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )
demo.launch()