Spaces:
Runtime error
Runtime error
import gradio as gr | |
from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel | |
from transformers import CLIPTextModel, CLIPTokenizer | |
import torch | |
from tqdm.auto import tqdm | |
from time import time | |
from PIL import Image | |
# Hub repository hosting every sub-model in the diffusers sub-folder layout.
MODEL_ID = "KBlueLeaf/Kohaku-XL-Epsilon-rev3"

# Only documented loader arguments are passed: `allow_pickle` is not a public
# `from_pretrained` option, and unknown keyword arguments may be forwarded to
# the model constructor and rejected there.
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
textEncoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
scheduler = DPMSolverMultistepScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# Fall back to the CPU when no GPU is visible so that importing this module
# does not fail on CPU-only hosts.
torchDevice = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(torchDevice)
textEncoder.to(torchDevice)
unet.to(torchDevice)
def _encodeText(texts, maxLength):
    """Tokenize *texts* to exactly *maxLength* tokens and return the text encoder's hidden states."""
    tokens = tokenizer(texts, padding="max_length", max_length=maxLength, truncation=True, return_tensors="pt")
    # Both classifier-free-guidance branches must be encoded identically, so the
    # attention-mask decision is made once here: a mask is passed only when the
    # encoder config asks for one (the convention used by diffusers pipelines).
    attentionMask = None
    if getattr(textEncoder.config, "use_attention_mask", False):
        attentionMask = tokens.attention_mask.to(torchDevice)
    with torch.no_grad():
        return textEncoder(tokens.input_ids.to(torchDevice), attention_mask=attentionMask)[0]


def _decodeToImage(latents):
    """Decode a single-sample latent batch with the VAE and return it as an RGB PIL image."""
    # Use the scaling factor stored in the VAE config; 0.18215 is only the
    # fallback for configs that do not define one.
    scalingFactor = getattr(vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        decoded = vae.decode(latents / scalingFactor).sample
    # Map [-1, 1] -> [0, 1], take the only sample and go CHW -> HWC.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0)
    # Scale to 0..255 exactly once; round before the cast so values are not truncated.
    array = (pixels.float() * 255).round().to(torch.uint8).cpu().numpy()
    return Image.fromarray(array)


def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
    """Generate one image from a text prompt with classifier-free guidance.

    Args:
        prompt: Text describing the desired image.
        negativePrompt: Text to steer away from; it is encoded as the
            unconditional branch of classifier-free guidance.
        steps: Number of scheduler steps (must be >= 1).
        cfg: Guidance scale applied to (conditional - unconditional) noise.
        seed: Seed for the initial latent noise, used when ``randomized`` is false.
        randomized: When true, ``seed`` is ignored and a fresh seed is drawn.
        width: Output width in pixels; the latent grid is ``width // 8`` wide.
        height: Output height in pixels; the latent grid is ``height // 8`` tall.

    Returns:
        A ``PIL.Image.Image`` holding the decoded picture.

    Raises:
        ValueError: If ``steps`` is below 1 or a dimension is below 8 pixels.

    NOTE(review): the checkpoint name suggests an SDXL model. SDXL UNets
    normally also need a second text encoder and ``added_cond_kwargs``
    (``text_embeds`` / ``time_ids``), which this single-encoder loop does not
    supply -- confirm against the UNet config before relying on this path.
    """
    # UI widgets may deliver floats (or None for an emptied number box).
    steps, width, height = int(steps), int(width), int(height)
    if steps < 1:
        raise ValueError("steps must be at least 1")
    if width < 8 or height < 8:
        raise ValueError("width and height must be at least 8 pixels")

    if randomized:
        # The upper bound is exclusive and has to fit into a signed 64-bit integer.
        seed = int(torch.randint(10000, 2**63 - 1, (1,)).item())
    else:
        seed = int(seed or 0)
    # A private CPU generator keeps the result reproducible for a given seed
    # and avoids reseeding torch's global RNG on every request.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    # A single prompt yields a single image; the string is wrapped in a list so
    # the batch size is the number of prompts, not the number of characters.
    prompts = [prompt or ""]
    negativePrompts = [negativePrompt or ""]
    batchSize = len(prompts)

    maxLength = tokenizer.model_max_length
    textEmbeddings = _encodeText(prompts, maxLength)
    unconditionedEmbeddings = _encodeText(negativePrompts, maxLength)
    # Order matters: the loop below splits the prediction as (unconditional, conditional).
    textEmbeddings = torch.cat([unconditionedEmbeddings, textEmbeddings])

    # Sample on the CPU (where the generator lives), then move to the model's
    # device and dtype.
    latents = torch.randn((batchSize, unet.config.in_channels, height // 8, width // 8), generator=generator)
    latents = latents.to(device=torchDevice, dtype=unet.dtype)

    # Configure the schedule before reading any value derived from it.
    scheduler.set_timesteps(steps)
    latents = latents * scheduler.init_noise_sigma

    with torch.no_grad():
        for t in tqdm(scheduler.timesteps):
            # One forward pass serves both guidance branches.
            latentModelInput = torch.cat([latents] * 2)
            latentModelInput = scheduler.scale_model_input(latentModelInput, timestep=t)
            noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
            unconditionedNoisePred, noisePredText = noisePred.chunk(2)
            noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
            latents = scheduler.step(noisePred, t, latents).prev_sample

    return _decodeToImage(latents)
# Web UI: the inputs are listed in the same order as the parameters of `generate`.
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
        gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
        # At least one denoising step is required; zero is not a valid step count.
        gr.Slider(1, 1000, step=1, label="Steps", value=20),
        gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
        # precision=0 makes the component hand over an integer instead of a float.
        gr.Number(label="Seed", value=0, precision=0),
        gr.Checkbox(label="Randomize Seed", value=True),
        # The upper bound sits on the 64-px grid and stays within sizes that can
        # realistically be allocated.
        gr.Slider(256, 2048, step=64, label="Width", value=512),
        gr.Slider(256, 2048, step=64, label="Height", value=512),
    ],
    outputs="image",
)

if __name__ == "__main__":
    interface.launch()