radames commited on
Commit
7ad11e7
·
1 Parent(s): ddaada9

QRCode pipeline

Browse files
qr-code.png ADDED
server/pipelines/controlnetLoraSD15QRCode.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionControlNetImg2ImgPipeline,
3
+ ControlNetModel,
4
+ LCMScheduler,
5
+ AutoencoderTiny,
6
+ )
7
+ from compel import Compel
8
+ import torch
9
+
10
+ try:
11
+ import intel_extension_for_pytorch as ipex # type: ignore
12
+ except:
13
+ pass
14
+
15
+ import psutil
16
+ from config import Args
17
+ from pydantic import BaseModel, Field
18
+ from PIL import Image
19
+ import math
20
+
21
+ taesd_model = "madebyollin/taesd"
22
+ controlnet_model = "monster-labs/control_v1p_sd15_qrcode_monster"
23
+ base_model = "nitrosocke/mo-di-diffusion"
24
+ lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
25
+ default_prompt = "abstract art of a men with curly hair by Pablo Picasso"
26
+ page_content = """
27
+ <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model SDv1.5</h1>
28
+ <h3 class="text-xl font-bold">LCM + LoRA + Controlnet + QRCode</h3>
29
+ <p class="text-sm">
30
+ This demo showcases
31
+ <a
32
+ href="https://huggingface.co/blog/lcm_lora"
33
+ target="_blank"
34
+ class="text-blue-500 underline hover:no-underline">LCM LoRA</a>
35
+ + ControlNet + Image to Imasge pipeline using
36
+ <a
37
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/lcm#performing-inference-with-lcm"
38
+ target="_blank"
39
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
40
+ > with a MJPEG stream server.
41
+ </p>
42
+ <p class="text-sm text-gray-500">
43
+ Change the prompt to generate different images, accepts <a
44
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
45
+ target="_blank"
46
+ class="text-blue-500 underline hover:no-underline">Compel</a
47
+ > syntax.
48
+ </p>
49
+ """
50
+
51
+
52
+ class Pipeline:
53
+ class Info(BaseModel):
54
+ name: str = "controlnet+loras+sd15"
55
+ title: str = "LCM + LoRA + Controlnet"
56
+ description: str = "Generates an image from a text prompt"
57
+ input_mode: str = "image"
58
+ page_content: str = page_content
59
+
60
+ class InputParams(BaseModel):
61
+ prompt: str = Field(
62
+ default_prompt,
63
+ title="Prompt",
64
+ field="textarea",
65
+ id="prompt",
66
+ )
67
+ seed: int = Field(
68
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
69
+ )
70
+ steps: int = Field(
71
+ 5, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
72
+ )
73
+ width: int = Field(
74
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
75
+ )
76
+ height: int = Field(
77
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
78
+ )
79
+ guidance_scale: float = Field(
80
+ 1.0,
81
+ min=0,
82
+ max=2,
83
+ step=0.001,
84
+ title="Guidance Scale",
85
+ field="range",
86
+ hide=True,
87
+ id="guidance_scale",
88
+ )
89
+ strength: float = Field(
90
+ 0.6,
91
+ min=0.25,
92
+ max=1.0,
93
+ step=0.001,
94
+ title="Strength",
95
+ field="range",
96
+ hide=True,
97
+ id="strength",
98
+ )
99
+ controlnet_scale: float = Field(
100
+ 1.0,
101
+ min=0,
102
+ max=1.0,
103
+ step=0.001,
104
+ title="Controlnet Scale",
105
+ field="range",
106
+ hide=True,
107
+ id="controlnet_scale",
108
+ )
109
+ controlnet_start: float = Field(
110
+ 0.0,
111
+ min=0,
112
+ max=1.0,
113
+ step=0.001,
114
+ title="Controlnet Start",
115
+ field="range",
116
+ hide=True,
117
+ id="controlnet_start",
118
+ )
119
+ controlnet_end: float = Field(
120
+ 1.0,
121
+ min=0,
122
+ max=1.0,
123
+ step=0.001,
124
+ title="Controlnet End",
125
+ field="range",
126
+ hide=True,
127
+ id="controlnet_end",
128
+ )
129
+ blend: float = Field(
130
+ 0.1,
131
+ min=0.0,
132
+ max=1.0,
133
+ step=0.001,
134
+ title="Blend",
135
+ field="range",
136
+ hide=True,
137
+ id="blend",
138
+ )
139
+
140
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
141
+ controlnet_qrcode = ControlNetModel.from_pretrained(
142
+ controlnet_model, torch_dtype=torch_dtype, subfolder="v2"
143
+ ).to(device)
144
+
145
+ if args.safety_checker:
146
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
147
+ base_model,
148
+ controlnet=controlnet_qrcode,
149
+ )
150
+ else:
151
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
152
+ base_model,
153
+ safety_checker=None,
154
+ controlnet=controlnet_qrcode,
155
+ )
156
+
157
+ self.control_image = Image.open(
158
+ "qr-code.png").convert("RGB").resize((512, 512))
159
+
160
+ self.pipe.scheduler = LCMScheduler.from_config(
161
+ self.pipe.scheduler.config)
162
+ self.pipe.set_progress_bar_config(disable=True)
163
+ if device.type != "mps":
164
+ self.pipe.unet.to(memory_format=torch.channels_last)
165
+
166
+ if args.taesd:
167
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
168
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
169
+ ).to(device)
170
+
171
+ # Load LCM LoRA
172
+ self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
173
+ self.pipe.to(device=device, dtype=torch_dtype).to(device)
174
+ if args.compel:
175
+ self.compel_proc = Compel(
176
+ tokenizer=self.pipe.tokenizer,
177
+ text_encoder=self.pipe.text_encoder,
178
+ truncate_long_prompts=False,
179
+ )
180
+ if args.torch_compile:
181
+ self.pipe.unet = torch.compile(
182
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
183
+ )
184
+ self.pipe.vae = torch.compile(
185
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
186
+ )
187
+ self.pipe(
188
+ prompt="warmup",
189
+ image=[Image.new("RGB", (512, 512))],
190
+ control_image=[Image.new("RGB", (512, 512))],
191
+ )
192
+
193
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
194
+ generator = torch.manual_seed(params.seed)
195
+
196
+ prompt = f"modern disney style {params.prompt}"
197
+ prompt_embeds = None
198
+ prompt = params.prompt
199
+ if hasattr(self, "compel_proc"):
200
+ prompt_embeds = self.compel_proc(prompt)
201
+ prompt = None
202
+
203
+ steps = params.steps
204
+ strength = params.strength
205
+ if int(steps * strength) < 1:
206
+ steps = math.ceil(1 / max(0.10, strength))
207
+
208
+ blend_qr_image = Image.blend(
209
+ params.image,
210
+ self.control_image,
211
+ alpha=params.blend
212
+ )
213
+ results = self.pipe(
214
+ image=blend_qr_image,
215
+ control_image=self.control_image,
216
+ prompt=prompt,
217
+ prompt_embeds=prompt_embeds,
218
+ generator=generator,
219
+ strength=strength,
220
+ num_inference_steps=steps,
221
+ guidance_scale=params.guidance_scale,
222
+ width=params.width,
223
+ height=params.height,
224
+ output_type="pil",
225
+ controlnet_conditioning_scale=params.controlnet_scale,
226
+ control_guidance_start=params.controlnet_start,
227
+ control_guidance_end=params.controlnet_end,
228
+ )
229
+
230
+ nsfw_content_detected = (
231
+ results.nsfw_content_detected[0]
232
+ if "nsfw_content_detected" in results
233
+ else False
234
+ )
235
+ if nsfw_content_detected:
236
+ return None
237
+ result_image = results.images[0]
238
+
239
+ return result_image