d0tpy committed on
Commit
29b0bbc
·
verified ·
1 Parent(s): 8419f53
Files changed (2) hide show
  1. app.py +46 -2
  2. image_enhancer.py +123 -0
app.py CHANGED
@@ -1,7 +1,51 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
  @app.get("/")
6
  def greet_json():
7
- return {"Initializing GlamApp Enhancer"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from media_enhancer.image_enhancer import EnhancementMethod, Enhancer
4
+ from pydantic import BaseModel
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import base64
8
+ import numpy as np
9
+
10
class EnhancementRequest(BaseModel):
    """Options a client may supply to control an enhancement run.

    Every field has a default, so an empty request selects GFPGAN with
    background enhancement enabled and a 2x upscale factor.
    """

    # Which enhancement backend to use.
    method: EnhancementMethod = EnhancementMethod.gfpgan
    # Whether the non-face background should be upsampled as well.
    background_enhancement: bool = True
    # Integer upscale factor handed to the enhancer.
    upscale: int = 2
14
+
15
class _EnhanceBase(BaseModel):
    """Internal payload wrapping one or more base64-encoded images."""

    # The builtin generic is used instead of ``typing.List``: ``List`` is not
    # imported by this module, so the original annotation raised NameError
    # as soon as the module was imported.
    encoded_base_img: list[str]
17
+
18
+
19
 
20
  app = FastAPI()
21
 
22
@app.get("/")
def greet_json():
    """Root endpoint returning a fixed status message."""
    # NOTE(review): the braces below build a one-element *set*, not a dict.
    # Presumably a {"message": ...} mapping was intended -- confirm with API
    # consumers before changing the response shape.
    status = {"Initializing GlamApp Enhancer"}
    return status
25
+
26
@app.post("/enhance")
async def enhance_image(
    file: UploadFile = File(...),
    request: EnhancementRequest = EnhancementRequest()
):
    """Enhance an uploaded image and stream the result back as a PNG.

    Args:
        file: The uploaded image; its declared content type must start with
            ``image/``.
        request: Enhancement options (method, background enhancement,
            upscale factor).

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        enhanced image.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or its
            bytes cannot be decoded; 500 if enhancement itself fails.
    """
    # NOTE(review): ``request`` is a pydantic model declared next to a file
    # upload. Confirm FastAPI accepts this combination for the clients in use;
    # query/form parameters (e.g. via ``Depends()``) may be what is needed.

    # ``content_type`` can be None when the client omits the header, so guard
    # before calling ``startswith``. This check lives outside the ``try`` so
    # the 400 is not rewritten into a 500 by the generic handler below.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="Invalid file type")

    try:
        contents = await file.read()

        # ``Enhancer.enhance`` passes its argument straight to
        # ``cv2.cvtColor(..., COLOR_RGB2BGR)``, so it needs an RGB pixel
        # array, not a base64 wrapper object.
        try:
            source = np.array(Image.open(BytesIO(contents)).convert("RGB"))
        except OSError as exc:
            # Pillow signals unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass): that is a client error.
            raise HTTPException(status_code=400, detail="Invalid image data") from exc

        # NOTE(review): a new Enhancer (and therefore a model load) is created
        # for every request; consider caching instances per configuration.
        enhancer = Enhancer(request.method, request.background_enhancement, request.upscale)

        enhanced_img, _original_resolution, _enhanced_resolution = await enhancer.enhance(source)

        enhanced_image = Image.fromarray(enhanced_img)
        img_byte_arr = BytesIO()
        enhanced_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return StreamingResponse(img_byte_arr, media_type="image/png")

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
image_enhancer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import os
import warnings
from enum import Enum

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm
10
+
11
class EnhancementMethod(str, Enum):
    """Names of the enhancement backends that can be selected.

    Members are also ``str`` instances, so each compares equal to its plain
    string value.
    """

    # Face-restoration backends.
    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"

    # Whole-image super-resolution backend.
    realesrgan = "realesrgan"
16
+
17
+
18
class Enhancer:
    """Image enhancer backed by a GFPGAN-style face restorer or RealESRGAN.

    For ``EnhancementMethod.realesrgan`` a standalone RealESRGAN upsampler is
    built. For every other method a ``GFPGANer`` face enhancer is built,
    optionally with a RealESRGAN background upsampler attached.
    """

    def __init__(self, method: EnhancementMethod, background_enhancement=True, upscale=2):
        """Build the models required by ``method``.

        Args:
            method: Which enhancement backend to use.
            background_enhancement: For face methods, whether to also set up a
                RealESRGAN background upsampler.
            upscale: Upscale factor passed to the underlying models.

        Raises:
            ValueError: If ``method`` is RealESRGAN and CUDA is unavailable,
                or if no face-model configuration exists for ``method``.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
            return

        # The background upsampler has to exist *before* the face enhancer is
        # constructed, because it is handed to GFPGANer's constructor.
        # Building it afterwards left GFPGANer holding ``None`` and the
        # background was never enhanced.
        if self.background_enhancement:
            self.setup_background_enhancer()
        self.setup_face_enhancer()

    def _build_realesrgan(self):
        """Create a RealESRGANer for the configured upscale factor."""
        # NOTE(review): the weight URL is derived from ``self.upscale``;
        # confirm a matching ``RealESRGAN_x{N}plus.pth`` exists for every
        # factor callers are allowed to request.
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=self.upscale)
        model_path = f'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x{self.upscale}plus.pth'
        return RealESRGANer(
            scale=self.upscale,
            model_path=model_path,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def setup_background_enhancer(self):
        """Set ``self.bg_upsampler``; warn and leave it ``None`` without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return

        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Set ``self.realesrgan_enhancer``.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')

        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Set ``self.face_enhancer`` to a GFPGANer for the selected method.

        Weights are looked up in ``gfpgan/weights`` and then ``checkpoints``
        (relative to the working directory); if neither file exists the
        configured download URL is passed as the model path instead.

        Raises:
            ValueError: If no configuration exists for ``self.method``.
        """
        # NOTE(review): confirm the installed GFPGANer accepts
        # arch='CodeFormer'; this depends on the gfpgan build in use.
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
            },
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
            }
        }

        config = model_configs.get(self.method)
        if not config:
            raise ValueError(f'Wrong model version {self.method}')

        # Prefer a local copy of the weights; otherwise fall back to the URL.
        weights_file = config['model_name'] + '.pth'
        model_path = config['url']
        for directory in ('gfpgan/weights', 'checkpoints'):
            candidate = os.path.join(directory, weights_file)
            if os.path.isfile(candidate):
                model_path = candidate
                break

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` taken from ``image.shape``."""
        # Only the first two axes are read, so arrays without a channel axis
        # are handled as well.
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` and report the resolution before and after.

        Args:
            image: Pixel array treated as RGB (it is converted with
                ``cv2.COLOR_RGB2BGR`` before being handed to the model).

        Returns:
            A tuple ``(enhanced_img, (width, height),
            (enhanced_width, enhanced_height))`` where ``enhanced_img`` has
            been converted back with ``cv2.COLOR_BGR2RGB``.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        # Model inference is blocking, so it runs in a worker thread to keep
        # the event loop free.
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)

        return enhanced_img, (width, height), (enhanced_width, enhanced_height)