sagar007 committed
Commit 5f38454 · verified · 1 Parent(s): 8ed4956

Create app.py

Files changed (1)
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
+ import os
+ import torch
+ import gradio as gr
+ from tqdm import tqdm
+ from PIL import Image
+ import torch.nn.functional as F
+ from torchvision import transforms as tfms
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
+
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ if torch_device == "mps":
+     os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+
+ # Load the pipeline
+ model_path = "CompVis/stable-diffusion-v1-4"
+ sd_pipeline = DiffusionPipeline.from_pretrained(
+     model_path,
+     low_cpu_mem_usage=True,
+     torch_dtype=torch.float32
+ ).to(torch_device)
+
+ # Load textual inversion concepts (style embeddings)
+ sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
+
+ # Map style names shown in the UI to their textual-inversion placeholder tokens
+ style_token_dict = {
+     "Illustration Style": '<illustration-style>',
+     "Line Art": '<line-art>',
+     "Hitokomoru Style": '<hitokomoru-style-nao>',
+     "Marc Allante": '<Marc_Allante>',
+     "Midjourney": '<midjourney-style>',
+     "Hanfu Anime": '<hanfu-anime-style>',
+     "Birb Style": '<birb-style>'
+ }
+
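+ # Note: apply_guidance below relies on pil_to_latent and latents_to_pil, which this
+ # commit never defines. A minimal sketch of those helpers is added here, assuming the
+ # pipeline's VAE and the standard SD v1 latent scaling factor 0.18215.
+ def pil_to_latent(input_im):
+     # Encode a PIL image into the scaled VAE latent space
+     with torch.no_grad():
+         latent = sd_pipeline.vae.encode(
+             tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
+         )
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # Decode a batch of latents back into a list of PIL images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = sd_pipeline.vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     return [Image.fromarray(img) for img in images]
+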
+ def apply_guidance(latents, guidance_method, loss_scale):
+     # Nudge the latents toward an edited copy of themselves; loss_scale controls the strength
+     if guidance_method == 'Grayscale':
+         rgb = latents_to_pil(latents)[0]
+         gray = rgb.convert('L')
+         gray_latents = pil_to_latent(gray.convert('RGB'))
+         return latents + (gray_latents - latents) * loss_scale
+     elif guidance_method == 'Bright':
+         bright_latents = F.relu(latents)  # Simple brightness increase
+         return latents + (bright_latents - latents) * loss_scale
+     elif guidance_method == 'Contrast':
+         mean = latents.mean()
+         contrast_latents = (latents - mean) * 2 + mean
+         return latents + (contrast_latents - latents) * loss_scale
+     elif guidance_method == 'Symmetry':
+         flipped_latents = torch.flip(latents, [3])  # Flip horizontally
+         return latents + (flipped_latents - latents) * loss_scale
+     elif guidance_method == 'Saturation':
+         rgb = latents_to_pil(latents)[0]
+         saturated = tfms.functional.adjust_saturation(tfms.ToTensor()(rgb), 2)
+         saturated_latents = pil_to_latent(tfms.ToPILImage()(saturated))
+         return latents + (saturated_latents - latents) * loss_scale
+     else:
+         return latents
+
+ def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
+     generator = torch.Generator(device=torch_device).manual_seed(seed)
+
+     # Get the text embeddings; classifier-free guidance also needs an unconditional (empty-prompt) embedding,
+     # since the loop below doubles the latent batch and splits the prediction into uncond/text halves
+     text_input = sd_pipeline.tokenizer(prompt, padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     uncond_input = sd_pipeline.tokenizer("", padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, return_tensors="pt")
+     with torch.no_grad():
+         text_embeddings = sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0]
+         uncond_embeddings = sd_pipeline.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Set the timesteps
+     sd_pipeline.scheduler.set_timesteps(num_inference_steps)
+
+     # Prepare latents
+     latents = torch.randn(
+         (1, sd_pipeline.unet.config.in_channels, 64, 64),
+         generator=generator,
+         device=torch_device
+     )
+     latents = latents * sd_pipeline.scheduler.init_noise_sigma
+
+     # Denoising loop
+     for t in tqdm(sd_pipeline.scheduler.timesteps):
+         # Expand the latents for classifier-free guidance
+         latent_model_input = torch.cat([latents] * 2)
+         latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, timestep=t)
+
+         # Predict the noise residual
+         with torch.no_grad():
+             noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+         # Perform guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # Apply custom guidance
+         latents = apply_guidance(latents, guidance_method, loss_scale / 10000)  # Normalize loss_scale
+
+         # Compute the previous noisy sample x_t -> x_t-1
+         latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample
+
+     # Scale and decode the image latents with vae
+     latents = 1 / 0.18215 * latents
+     with torch.no_grad():
+         image = sd_pipeline.vae.decode(latents).sample
+
+     # Convert to PIL Image
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     image = (image * 255).round().astype("uint8")[0]
+     image = Image.fromarray(image)
+
+     return image
+
+ def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
+     prompt = text + " " + style_token_dict[style]
+
+     # Generate an image with the standard pipeline
+     image_pipeline = sd_pipeline(
+         prompt,
+         num_inference_steps=inference_step,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator(device=torch_device).manual_seed(seed)
+     ).images[0]
+
+     # Generate an image with the custom guidance loop
+     image_guide = generate_with_guidance(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale)
+
+     return image_pipeline, image_guide
+
+ title = "Generative Art with Textual Inversion and Guidance"
+ description = "A Gradio interface for running Stable Diffusion to generate images in different art styles, with and without custom latent guidance"
+ examples = [
+     ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200],
+     ["A cyberpunk cityscape at night", 'Midjourney', 25, 8.0, 123, 'Contrast', 300]
+ ]
+
+ demo = gr.Interface(
+     inference,
+     inputs=[
+         gr.Textbox(label="Prompt", type="text"),
+         gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
+         gr.Slider(1, 50, 10, step=1, label="Inference steps"),
+         gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
+         gr.Slider(0, 10000, 42, step=1, label="Seed"),
+         gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale"),
+         gr.Slider(100, 10000, 200, step=100, label="Loss scale")
+     ],
+     outputs=[
+         gr.Image(width=512, height=512, label="Generated art"),
+         gr.Image(width=512, height=512, label="Generated art with guidance")
+     ],
+     title=title,
+     description=description,
+     examples=examples
+ )
+
+ demo.launch()