muneebable committed (verified)
Commit 1b928b3 · 1 Parent(s): af77af7

Update app.py

Files changed (1)
  1. app.py +50 -49
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import numpy as np
-import random
 import torch
 from diffusers import DDPMPipeline, DDIMScheduler
 import open_clip
@@ -17,16 +16,14 @@ clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pr
 clip_model.to(device)

 # Transform to preprocess images
-tfms = torchvision.transforms.Compose(
-    [
-        torchvision.transforms.Resize((224, 224)),
-        torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize(
-            mean=(0.48145466, 0.4578275, 0.40821073),
-            std=(0.26862954, 0.26130258, 0.27577711),
-        ),
-    ]
-)
+tfms = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((224, 224)),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize(
+        mean=(0.48145466, 0.4578275, 0.40821073),
+        std=(0.26862954, 0.26130258, 0.27577711),
+    ),
+])

 # CLIP Loss function
 def clip_loss(image, text_features):
@@ -37,66 +34,70 @@ def clip_loss(image, text_features):
     return loss

 # Load Diffusion model
-model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art" # Replace with desired model repo
+model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"
 image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
 image_pipe.to(device)

 # Load scheduler
 scheduler = DDIMScheduler.from_pretrained(model_repo_id)
-scheduler.set_timesteps(num_inference_steps=40)

-# Gradio Inference Function
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-    if randomize_seed:
-        seed = random.randint(0, np.iinfo(np.int32).max)
-    generator = torch.manual_seed(seed)
+def generate_image(prompt, guidance_scale, num_steps):
+    scheduler.set_timesteps(num_inference_steps=num_steps)

-    # Embed prompt with CLIP
+    # We embed a prompt with CLIP as our target
     text = open_clip.tokenize([prompt]).to(device)
-    with torch.no_grad():
+    with torch.no_grad(), torch.cuda.amp.autocast():
         text_features = clip_model.encode_text(text)
-
-    x = torch.randn(4, 3, 256, 256).to(device)
+
+    x = torch.randn(1, 3, 256, 256).to(device)
+    n_cuts = 4

     for i, t in tqdm(enumerate(scheduler.timesteps)):
         model_input = scheduler.scale_model_input(x, t)
+        # predict the noise residual
         with torch.no_grad():
             noise_pred = image_pipe.unet(model_input, t)["sample"]
         cond_grad = 0
-        for cut in range(4):
+        for cut in range(n_cuts):
+            # Set requires grad on x
             x = x.detach().requires_grad_()
+            # Get the predicted x0:
             x0 = scheduler.step(noise_pred, t, x).pred_original_sample
+            # Calculate loss
             loss = clip_loss(x0, text_features) * guidance_scale
-            cond_grad -= torch.autograd.grad(loss, x)[0] / 4
+            # Get gradient (scale by n_cuts since we want the average)
+            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts
+
+        # Modify x based on this gradient
         alpha_bar = scheduler.alphas_cumprod[i]
         x = x.detach() + cond_grad * alpha_bar.sqrt()
+        # Now step with scheduler
         x = scheduler.step(noise_pred, t, x).prev_sample
-
-    # Convert output to an image
-    grid = torchvision.utils.make_grid(x.detach(), nrow=4)
-    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
-    result_image = Image.fromarray((im.numpy() * 255).astype(np.uint8))

-    return result_image, seed
-
-# Gradio App
-with gr.Blocks() as demo:
-    prompt = gr.Textbox(placeholder="Enter your prompt", label="Prompt")
-    run_button = gr.Button("Generate")
-
-    result = gr.Image(label="Generated Image")
+    grid = torchvision.utils.make_grid(x.detach(), nrow=1)
+    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
+    im = (im * 255).byte().numpy()
+    return Image.fromarray(im)

-    with gr.Accordion("Advanced Settings"):
-        negative_prompt = gr.Textbox(label="Negative Prompt")
-        seed = gr.Slider(0, np.iinfo(np.int32).max, value=0, label="Seed")
-        randomize_seed = gr.Checkbox(True, label="Randomize Seed")
-        width = gr.Slider(256, 1024, value=512, label="Width")
-        height = gr.Slider(256, 1024, value=512, label="Height")
-        guidance_scale = gr.Slider(0.0, 10.0, value=7.5, label="Guidance Scale")
-        num_inference_steps = gr.Slider(1, 50, value=50, label="Steps")
+# Gradio interface
+def gradio_interface(prompt, guidance_scale, num_steps):
+    return generate_image(prompt, guidance_scale, num_steps)

-    run_button.click(infer,
-        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs=[result, seed])
+iface = gr.Interface(
+    fn=gradio_interface,
+    inputs=[
+        gr.Textbox(label="Prompt", value="Red Rose (still life), red flower painting"),
+        gr.Slider(minimum=1, maximum=20, step=1, label="Guidance Scale", value=8),
+        gr.Slider(minimum=10, maximum=100, step=10, label="Number of Steps", value=50)
+    ],
+    outputs=gr.Image(type="pil", label="Generated Image"),
+    title="CLIP-Guided Diffusion Image Generation",
+    description="Generate images using CLIP-guided diffusion. Enter a prompt, adjust the guidance scale, and set the number of steps.",
+    examples=[
+        ["A serene landscape with mountains and a lake", 10, 50],
+        ["A futuristic cityscape at night", 15, 70],
+        ["A cute cartoon character", 5, 30]
+    ]
+)

-demo.queue().launch()
+iface.launch()
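The hunks above call a clip_loss(image, text_features) helper whose body is unchanged by this commit and therefore hidden from the diff. For context, here is a minimal sketch of what a CLIP guidance loss with this signature could look like, assuming the module-level clip_model from app.py; the name clip_loss_sketch and the cosine-distance formulation are illustrative, and the actual body in the file may differ.

# Hedged sketch only: the real clip_loss body is not shown in this diff.
# Assumes the module-level clip_model defined in app.py.
import torch.nn.functional as F
import torchvision

def clip_loss_sketch(image, text_features):
    # Diffusion samples live in [-1, 1]; map to [0, 1] and prepare for CLIP.
    image = (image * 0.5 + 0.5).clamp(0, 1)
    image = torchvision.transforms.functional.resize(image, (224, 224))
    image = torchvision.transforms.functional.normalize(
        image,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    )
    image_features = clip_model.encode_image(image)
    # Cosine distance between image and prompt embeddings, averaged over the batch.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    return (1.0 - (image_features * text_features).sum(dim=-1)).mean()

Keeping the loss differentiable end to end is what lets the loop above take torch.autograd.grad(loss, x) and nudge x toward the prompt.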
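A quick, equally hedged usage sketch: the new generate_image entry point can be exercised without the Gradio UI as a smoke test, assuming app.py's module-level setup (pipeline, scheduler, CLIP model) has already run; the output path is arbitrary.

# Hypothetical smoke test for generate_image; not part of the commit.
if __name__ == "__main__":
    img = generate_image(
        "Red Rose (still life), red flower painting",  # default prompt from the UI
        guidance_scale=8,   # default slider value
        num_steps=50,       # default slider value
    )
    img.save("clip_guided_sample.png")  # arbitrary output file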