radames commited on
Commit
8a9145c
·
1 Parent(s): a96a8c6

add extra options to img2img

Browse files
Files changed (2) hide show
  1. app-img2img.py +9 -12
  2. img2img/index.html +103 -22
app-img2img.py CHANGED
@@ -76,9 +76,9 @@ pipe.unet.to(memory_format=torch.channels_last)
76
  if psutil.virtual_memory().total < 64 * 1024**3:
77
  pipe.enable_attention_slicing()
78
 
79
- if not mps_available and not xpu_available:
80
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
81
- pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])
82
 
83
  compel_proc = Compel(
84
  tokenizer=pipe.tokenizer,
@@ -89,30 +89,27 @@ user_queue_map = {}
89
 
90
 
91
  class InputParams(BaseModel):
92
- prompt: str
93
  seed: int = 2159232
 
94
  guidance_scale: float = 8.0
95
  strength: float = 0.5
 
 
96
  width: int = WIDTH
97
  height: int = HEIGHT
98
 
99
-
100
- def predict(
101
- input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
102
- ):
103
  generator = torch.manual_seed(params.seed)
104
- # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
105
- num_inference_steps = 4
106
  results = pipe(
107
  prompt_embeds=prompt_embeds,
108
  generator=generator,
109
  image=input_image,
110
  strength=params.strength,
111
- num_inference_steps=num_inference_steps,
112
  guidance_scale=params.guidance_scale,
113
  width=params.width,
114
  height=params.height,
115
- original_inference_steps=50,
116
  output_type="pil",
117
  )
118
  nsfw_content_detected = (
 
76
  if psutil.virtual_memory().total < 64 * 1024**3:
77
  pipe.enable_attention_slicing()
78
 
79
+ # if not mps_available and not xpu_available:
80
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
81
+ # pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])
82
 
83
  compel_proc = Compel(
84
  tokenizer=pipe.tokenizer,
 
89
 
90
 
91
  class InputParams(BaseModel):
 
92
  seed: int = 2159232
93
+ prompt: str
94
  guidance_scale: float = 8.0
95
  strength: float = 0.5
96
+ steps: int = 4
97
+ lcm_steps: int = 50
98
  width: int = WIDTH
99
  height: int = HEIGHT
100
 
101
+ def predict(input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None):
 
 
 
102
  generator = torch.manual_seed(params.seed)
 
 
103
  results = pipe(
104
  prompt_embeds=prompt_embeds,
105
  generator=generator,
106
  image=input_image,
107
  strength=params.strength,
108
+ num_inference_steps=params.steps,
109
  guidance_scale=params.guidance_scale,
110
  width=params.width,
111
  height=params.height,
112
+ original_inference_steps=params.lcm_steps,
113
  output_type="pil",
114
  )
115
  nsfw_content_detected = (
img2img/index.html CHANGED
@@ -15,13 +15,13 @@
15
  }
16
  </style>
17
  <script type="module">
18
- // you can change the size of the input image to 768x768 if you have a powerful GPU
19
- const WIDTH = 512;
20
- const HEIGHT = 512;
21
- const seedEl = document.querySelector("#seed");
22
- const promptEl = document.querySelector("#prompt");
23
- const guidanceEl = document.querySelector("#guidance-scale");
24
- const strengthEl = document.querySelector("#strength");
25
  const startBtn = document.querySelector("#start");
26
  const stopBtn = document.querySelector("#stop");
27
  const videoEl = document.querySelector("#webcam");
@@ -29,8 +29,9 @@
29
  const queueSizeEl = document.querySelector("#queue_size");
30
  const errorEl = document.querySelector("#error");
31
  const snapBtn = document.querySelector("#snap");
 
32
 
33
- function LCMLive(webcamVideo, liveImage, seedEl, promptEl, guidanceEl, strengthEl) {
34
  let websocket;
35
 
36
  async function start() {
@@ -72,30 +73,73 @@
72
  websocket = socket;
73
  })
74
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  async function videoTimeUpdateHandler() {
 
 
 
77
  const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
78
  const videoW = webcamVideo.videoWidth;
79
  const videoH = webcamVideo.videoHeight;
 
80
 
81
  const ctx = canvas.getContext("2d");
82
- // grap square from center
83
- ctx.drawImage(webcamVideo, videoW / 2 - WIDTH / 2, videoH / 2 - HEIGHT / 2, WIDTH, HEIGHT, 0, 0, canvas.width, canvas.height);
84
  const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
85
  websocket.send(blob);
86
  websocket.send(JSON.stringify({
87
- "seed": seedEl.value,
88
- "prompt": promptEl.value,
89
- "guidance_scale": guidanceEl.value,
90
- "strength": strengthEl.value
 
 
 
 
91
  }));
92
  }
93
-
94
- function initVideoStream(userId) {
95
  liveImage.src = `/stream/${userId}`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  const constraints = {
97
  audio: false,
98
- video: { width: WIDTH, height: HEIGHT },
99
  };
100
  navigator.mediaDevices
101
  .getUserMedia(constraints)
@@ -117,6 +161,7 @@
117
  mediaStream.getTracks().forEach((track) => track.stop());
118
  });
119
  webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
 
120
  webcamVideo.srcObject = null;
121
  }
122
  return {
@@ -147,7 +192,7 @@
147
  const exif = {};
148
  const gps = {};
149
  zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image";
150
- zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${promptEl.value} | seed: ${seedEl.value} | guidance_scale: ${guidanceEl.value} | strength: ${strengthEl.value}`;
151
  zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
152
 
153
  exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
@@ -173,7 +218,7 @@
173
  }
174
 
175
 
176
- const lcmLive = LCMLive(videoEl, imageEl, seedEl, promptEl, guidanceEl, strengthEl);
177
  startBtn.addEventListener("click", async () => {
178
  try {
179
  startBtn.disabled = true;
@@ -249,18 +294,38 @@
249
  <div class="">
250
  <details>
251
  <summary class="font-medium cursor-pointer">Advanced Options</summary>
252
- <div class="grid grid-cols-3 max-w-md items-center gap-3 py-3">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
254
  </label>
255
- <input type="range" id="guidance-scale" name="guidance-scale" min="1" max="30" step="0.001"
256
  value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
257
  <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
258
  8.0</output>
 
259
  <label class="text-sm font-medium" for="strength">Strength</label>
260
- <input type="range" id="strength" name="strength" min="0.02" max="1" step="0.001" value="0.50"
261
  oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
262
  <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
263
  0.5</output>
 
264
  <label class="text-sm font-medium" for="seed">Seed</label>
265
  <input type="number" id="seed" name="seed" value="299792458"
266
  class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
@@ -269,6 +334,22 @@
269
  class="button">
270
  Rand
271
  </button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  </div>
273
  </details>
274
  </div>
 
15
  }
16
  </style>
17
  <script type="module">
18
+ const getValue = (id) => {
19
+ const el = document.querySelector(`${id}`)
20
+ if (el.type === "checkbox")
21
+ return el.checked;
22
+ return el.value;
23
+ }
24
+
25
  const startBtn = document.querySelector("#start");
26
  const stopBtn = document.querySelector("#stop");
27
  const videoEl = document.querySelector("#webcam");
 
29
  const queueSizeEl = document.querySelector("#queue_size");
30
  const errorEl = document.querySelector("#error");
31
  const snapBtn = document.querySelector("#snap");
32
+ const webcamsEl = document.querySelector("#webcams");
33
 
34
+ function LCMLive(webcamVideo, liveImage) {
35
  let websocket;
36
 
37
  async function start() {
 
73
  websocket = socket;
74
  })
75
  }
76
+ function switchCamera() {
77
+ const constraints = {
78
+ audio: false,
79
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[webcamsEl.value].deviceId }
80
+ };
81
+ navigator.mediaDevices
82
+ .getUserMedia(constraints)
83
+ .then((mediaStream) => {
84
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
85
+ webcamVideo.srcObject = mediaStream;
86
+ webcamVideo.onloadedmetadata = () => {
87
+ webcamVideo.play();
88
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
89
+ };
90
+ })
91
+ .catch((err) => {
92
+ console.error(`${err.name}: ${err.message}`);
93
+ });
94
+ }
95
 
96
  async function videoTimeUpdateHandler() {
97
+ const dimension = getValue("input[name=dimension]:checked");
98
+ const [WIDTH, HEIGHT] = JSON.parse(dimension);
99
+
100
  const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
101
  const videoW = webcamVideo.videoWidth;
102
  const videoH = webcamVideo.videoHeight;
103
+ const aspectRatio = WIDTH / HEIGHT;
104
 
105
  const ctx = canvas.getContext("2d");
106
+ ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
 
107
  const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
108
  websocket.send(blob);
109
  websocket.send(JSON.stringify({
110
+ "seed": getValue("#seed"),
111
+ "prompt": getValue("#prompt"),
112
+ "guidance_scale": getValue("#guidance-scale"),
113
+ "strength": getValue("#strength"),
114
+ "steps": getValue("#steps"),
115
+ "lcm_steps": getValue("#lcm_steps"),
116
+ "width": WIDTH,
117
+ "height": HEIGHT,
118
  }));
119
  }
120
+ let mediaDevices = [];
121
+ async function initVideoStream(userId) {
122
  liveImage.src = `/stream/${userId}`;
123
+ await navigator.mediaDevices.enumerateDevices()
124
+ .then(devices => {
125
+ const cameras = devices.filter(device => device.kind === 'videoinput');
126
+ mediaDevices = cameras;
127
+ webcamsEl.innerHTML = "";
128
+ cameras.forEach((camera, index) => {
129
+ const option = document.createElement("option");
130
+ option.value = index;
131
+ option.innerText = camera.label;
132
+ webcamsEl.appendChild(option);
133
+ option.selected = index === 0;
134
+ });
135
+ webcamsEl.addEventListener("change", switchCamera);
136
+ })
137
+ .catch(err => {
138
+ console.error(err);
139
+ });
140
  const constraints = {
141
  audio: false,
142
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[0].deviceId }
143
  };
144
  navigator.mediaDevices
145
  .getUserMedia(constraints)
 
161
  mediaStream.getTracks().forEach((track) => track.stop());
162
  });
163
  webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
164
+ webcamsEl.removeEventListener("change", switchCamera);
165
  webcamVideo.srcObject = null;
166
  }
167
  return {
 
192
  const exif = {};
193
  const gps = {};
194
  zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image";
195
+ zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | strength: ${getValue("#strength")} | lcm_steps: ${getValue("#lcm_steps")} | steps: ${getValue("#steps")}`;
196
  zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
197
 
198
  exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
 
218
  }
219
 
220
 
221
+ const lcmLive = LCMLive(videoEl, imageEl);
222
  startBtn.addEventListener("click", async () => {
223
  try {
224
  startBtn.disabled = true;
 
294
  <div class="">
295
  <details>
296
  <summary class="font-medium cursor-pointer">Advanced Options</summary>
297
+ <div class="grid grid-cols-3 sm:grid-cols-6 items-center gap-3 py-3">
298
+ <label for="webcams" class="text-sm font-medium">Camera Options: </label>
299
+ <select id="webcams" class="text-sm border-2 border-gray-500 rounded-md font-light dark:text-black">
300
+ </select>
301
+ <div></div>
302
+ <label class="text-sm font-medium " for="steps">Inference Steps
303
+ </label>
304
+ <input type="range" id="steps" name="steps" min="1" max="20" value="4"
305
+ oninput="this.nextElementSibling.value = Number(this.value)">
306
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
307
+ 4</output>
308
+ <!-- -->
309
+ <label class="text-sm font-medium" for="lcm_steps">LCM Inference Steps
310
+ </label>
311
+ <input type="range" id="lcm_steps" name="lcm_steps" min="2" max="60" value="50"
312
+ oninput="this.nextElementSibling.value = Number(this.value)">
313
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
314
+ 50</output>
315
+ <!-- -->
316
  <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
317
  </label>
318
+ <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="30" step="0.001"
319
  value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
320
  <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
321
  8.0</output>
322
+ <!-- -->
323
  <label class="text-sm font-medium" for="strength">Strength</label>
324
+ <input type="range" id="strength" name="strength" min="0.1" max="1" step="0.001" value="0.50"
325
  oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
326
  <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
327
  0.5</output>
328
+ <!-- -->
329
  <label class="text-sm font-medium" for="seed">Seed</label>
330
  <input type="number" id="seed" name="seed" value="299792458"
331
  class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
 
334
  class="button">
335
  Rand
336
  </button>
337
+ <!-- -->
338
+ <!-- -->
339
+ <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
340
+ <div class="col-span-2 flex gap-2">
341
+ <div class="flex gap-1">
342
+ <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
343
+ class="cursor-pointer">
344
+ <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
345
+ </div>
346
+ <div class="flex gap-1">
347
+ <input type="radio" id="dimension768" name="dimension" value="[768,768]"
348
+ lass="cursor-pointer">
349
+ <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
350
+ </div>
351
+ </div>
352
+ <!-- -->
353
  </div>
354
  </details>
355
  </div>