jhj0517 commited on
Commit
bb7ed78
·
1 Parent(s): 12a48eb

Add RESRGAN inferencer

Browse files
modules/image_restoration/real_esrgan_inferencer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from typing import Optional
7
+ from RealESRGAN import RealESRGAN
8
+
9
+ from modules.utils.paths import *
10
+ from .model_downloader import *
11
+
12
+
13
+ class RealESRGANInferencer:
14
+ def __init__(self,
15
+ model_dir: str = MODELS_REAL_ESRGAN_DIR,
16
+ output_dir: str = OUTPUTS_DIR):
17
+ self.model_dir = model_dir
18
+ self.output_dir = output_dir
19
+ self.device = self.get_device()
20
+ self.model = None
21
+ self.available_models = list(MODELS_REALESRGAN_URL.keys())
22
+
23
+ def load_model(self,
24
+ model_name: Optional[str] = None,
25
+ scale: int = 1,
26
+ progress: gr.Progress = gr.Progress()):
27
+ if model_name is None:
28
+ model_name = "realesr-general-x4v3"
29
+ if not model_name.endswith(".pth"):
30
+ model_name += ".pth"
31
+ model_path = os.path.join(self.model_dir, model_name)
32
+
33
+ if not os.path.exists(model_path):
34
+ progress(0, f"Downloading RealESRGAN model to : {model_path}")
35
+ name, ext = os.path.splitext(model_name)
36
+ download_resrgan_model(model_path, MODELS_REALESRGAN_URL[name])
37
+
38
+ if self.model is None:
39
+ self.model = RealESRGAN(self.device, scale=scale)
40
+ self.model.load_weights(model_path=model_path, download=False)
41
+
42
+ def restore_image(self,
43
+ img_path: str,
44
+ overwrite: bool = True):
45
+ if self.model is None:
46
+ self.load_model()
47
+
48
+ try:
49
+ img = Image.open(img_path).convert('RGB')
50
+ sr_img = self.model.predict(img)
51
+ if overwrite:
52
+ output_path = img_path
53
+ else:
54
+ output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
55
+ sr_img.save(output_path)
56
+ except Exception as e:
57
+ raise
58
+
59
+ @staticmethod
60
+ def get_device():
61
+ if torch.cuda.is_available():
62
+ return "cuda"
63
+ elif torch.backends.mps.is_available():
64
+ return "mps"
65
+ else:
66
+ return "cpu"