d0tpy committed on
Commit
8182ca4
·
verified ·
1 Parent(s): 93eb6fd

Update image_enhancer.oy

Browse files
Files changed (1) hide show
  1. image_enhancer.oy +89 -92
image_enhancer.oy CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from gfpgan import GFPGANer
4
  from tqdm import tqdm
5
  import cv2
 
6
  from enum import Enum
7
 
8
  class EnhancementMethod(str, Enum):
@@ -13,112 +14,108 @@ class EnhancementMethod(str, Enum):
13
 
14
 
15
  class Enhancer:
16
- def __init__(self, method=EnhancementMethod, background_enhancement=True, upscale=2):
17
- # Set up RealESRGAN for background enhancement
18
- if background_enhancement:
19
- if upscale == 2:
20
- if not torch.cuda.is_available(): # CPU
21
- import warnings
22
- warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
23
- 'If you really want to use it, please modify the corresponding codes.')
24
- self.bg_upsampler = None
25
- else:
26
- from basicsr.archs.rrdbnet_arch import RRDBNet
27
- from realesrgan import RealESRGANer
28
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
29
- self.bg_upsampler = RealESRGANer(
30
- scale=2,
31
- model_path='https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth',
32
- model=model,
33
- tile=400,
34
- tile_pad=10,
35
- pre_pad=0,
36
- half=True) # need to set False in CPU mode
37
- elif upscale == 4:
38
- if not torch.cuda.is_available(): # CPU
39
- import warnings
40
- warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
41
- 'If you really want to use it, please modify the corresponding codes.')
42
- self.bg_upsampler = None
43
- else:
44
- from basicsr.archs.rrdbnet_arch import RRDBNet
45
- from realesrgan import RealESRGANer
46
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
47
- self.bg_upsampler = RealESRGANer(
48
- scale=4,
49
- model_path='https://huggingface.co/lllyasviel/Annotators/resolve/main/RealESRGAN_x4plus.pth',
50
- model=model,
51
- tile=400,
52
- tile_pad=10,
53
- pre_pad=0,
54
- half=True) # need to set False in CPU mode
55
- else:
56
- raise ValueError(f'Wrong upscale constant {upscale}.')
57
- else:
58
- self.bg_upsampler = None
59
 
60
- # Set up GPFGAN for face enhancement
61
- if method == 'gfpgan':
62
- self.arch = 'clean'
63
- self.channel_multiplier = 2
64
- self.model_name = 'GFPGANv1.4'
65
- self.url = 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
66
- elif method == 'RestoreFormer':
67
- self.arch = 'RestoreFormer'
68
- self.channel_multiplier = 2
69
- self.model_name = 'RestoreFormer'
70
- self.url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
71
- elif method == 'codeformer': # TODO:
72
- self.arch = 'CodeFormer'
73
- self.channel_multiplier = 2
74
- self.model_name = 'CodeFormer'
75
- self.url = 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
76
  else:
77
- raise ValueError(f'Wrong model version {method}.')
78
-
79
- # Determine the model path and if the model is not available, download it
80
- model_path = os.path.join('gfpgan/weights', self.model_name + '.pth')
81
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if not os.path.isfile(model_path):
83
- model_path = os.path.join('checkpoints', self.model_name + '.pth')
84
-
85
  if not os.path.isfile(model_path):
86
- # Download pre-trained models from url
87
- model_path = self.url
88
 
89
- self.restorer = GFPGANer(
90
  model_path=model_path,
91
- upscale=upscale,
92
- arch=self.arch,
93
- channel_multiplier=self.channel_multiplier,
94
  bg_upsampler=self.bg_upsampler)
95
-
96
 
97
- def check_image_dimensions(self, image):
98
- # Get the dimensions of the image
99
  height, width, _ = image.shape
100
- return True
101
-
102
- # Check if either dimension exceeds 2048 pixels :Todo
103
- # if width > 2048 or height > 2048:
104
- # return True
105
 
106
- # else:
107
- # print("Image dimensions are within the limit.")
108
- # return True
109
-
110
-
111
- def enhance(self, image):
112
  img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
113
- if self.check_image_dimensions(img):
114
- cropped_faces, restored_faces, r_img = self.restorer.enhance(
 
 
 
 
115
  img,
116
  has_aligned=False,
117
  only_center_face=False,
118
  paste_back=True)
119
- else:
120
- r_img = img
121
 
122
- r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
 
123
 
124
- return r_img
 
3
  from gfpgan import GFPGANer
4
  from tqdm import tqdm
5
  import cv2
6
+ import warnings
7
  from enum import Enum
8
 
9
  class EnhancementMethod(str, Enum):
 
14
 
15
 
16
  class Enhancer:
17
+ def __init__(self, method: EnhancementMethod, background_enhancement=True, upscale=2):
18
+ self.method = method
19
+ self.background_enhancement = background_enhancement
20
+ self.upscale = upscale
21
+ self.bg_upsampler = None
22
+ self.realesrgan_enhancer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ if self.method != EnhancementMethod.realesrgan:
25
+ self.setup_face_enhancer()
26
+ if self.background_enhancement:
27
+ self.setup_background_enhancer()
 
 
 
 
 
 
 
 
 
 
 
 
28
  else:
29
+ self.setup_realesrgan_enhancer()
30
+
31
+ def setup_background_enhancer(self):
32
+ if not torch.cuda.is_available():
33
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
34
+ return
35
+
36
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=self.upscale)
37
+ model_path = f'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x{self.upscale}plus.pth'
38
+ self.bg_upsampler = RealESRGANer(
39
+ scale=self.upscale,
40
+ model_path=model_path,
41
+ model=model,
42
+ tile=400,
43
+ tile_pad=10,
44
+ pre_pad=0,
45
+ half=True)
46
+
47
+ def setup_realesrgan_enhancer(self):
48
+ if not torch.cuda.is_available():
49
+ raise ValueError('CUDA is not available for RealESRGAN')
50
+
51
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=self.upscale)
52
+ model_path = f'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x{self.upscale}plus.pth'
53
+ self.realesrgan_enhancer = RealESRGANer(
54
+ scale=self.upscale,
55
+ model_path=model_path,
56
+ model=model,
57
+ tile=400,
58
+ tile_pad=10,
59
+ pre_pad=0,
60
+ half=True)
61
+
62
+ def setup_face_enhancer(self):
63
+ model_configs = {
64
+ EnhancementMethod.gfpgan: {
65
+ 'arch': 'clean',
66
+ 'channel_multiplier': 2,
67
+ 'model_name': 'GFPGANv1.4',
68
+ 'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
69
+ },
70
+ EnhancementMethod.RestoreFormer: {
71
+ 'arch': 'RestoreFormer',
72
+ 'channel_multiplier': 2,
73
+ 'model_name': 'RestoreFormer',
74
+ 'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
75
+ },
76
+ EnhancementMethod.codeformer: {
77
+ 'arch': 'CodeFormer',
78
+ 'channel_multiplier': 2,
79
+ 'model_name': 'CodeFormer',
80
+ 'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
81
+ }
82
+ }
83
+
84
+ config = model_configs.get(self.method)
85
+ if not config:
86
+ raise ValueError(f'Wrong model version {self.method}')
87
+
88
+ model_path = os.path.join('gfpgan/weights', config['model_name'] + '.pth')
89
  if not os.path.isfile(model_path):
90
+ model_path = os.path.join('checkpoints', config['model_name'] + '.pth')
 
91
  if not os.path.isfile(model_path):
92
+ model_path = config['url']
 
93
 
94
+ self.face_enhancer = GFPGANer(
95
  model_path=model_path,
96
+ upscale=self.upscale,
97
+ arch=config['arch'],
98
+ channel_multiplier=config['channel_multiplier'],
99
  bg_upsampler=self.bg_upsampler)
 
100
 
101
+ def check_image_resolution(self, image):
 
102
  height, width, _ = image.shape
103
+ return width, height
 
 
 
 
104
 
105
+ async def enhance(self, image):
 
 
 
 
 
106
  img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
107
+ width, height = self.check_image_resolution(img)
108
+
109
+ if self.method == EnhancementMethod.realesrgan:
110
+ enhanced_img, _ = await asyncio.to_thread(self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
111
+ else:
112
+ _, _, enhanced_img = await asyncio.to_thread(self.face_enhancer.enhance,
113
  img,
114
  has_aligned=False,
115
  only_center_face=False,
116
  paste_back=True)
 
 
117
 
118
+ enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
119
+ enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
120
 
121
+ return enhanced_img, (width, height), (enhanced_width, enhanced_height)