Spaces:
Configuration error
Configuration error
folder gradio_app
Browse files
app.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
base_path = os.path.expanduser('~')
|
7 |
+
|
8 |
+
sys.path.append(os.path.join(base_path, 'Er0mangaSeg/'))
|
9 |
+
sys.path.append(os.path.join(base_path, 'Er0mangaSeg/demo'))
|
10 |
+
from image_demo_tta import init_seg_model, inference_tta
|
11 |
+
|
12 |
+
sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/'))
|
13 |
+
sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/bin'))
|
14 |
+
from uncen import init_inpaint_model, inpaint
|
15 |
+
|
16 |
+
|
17 |
+
import time
|
18 |
+
import numpy as np
|
19 |
+
import cv2
|
20 |
+
import shutil
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
print('GPU found!')
|
26 |
+
device = 'cuda:0'
|
27 |
+
else:
|
28 |
+
print('GPU not found! Using CPU')
|
29 |
+
device = 'cpu'
|
30 |
+
|
31 |
+
|
32 |
+
config = os.path.join(base_path, 'Er0mangaSeg/configs/convnext/convnext_h.py')
|
33 |
+
checkpoint = os.path.join(base_path, 'Er0mangaSeg/pretrained/convnext_1024_iter_400.pth')
|
34 |
+
model_seg = init_seg_model(config, checkpoint, device=device)
|
35 |
+
print('Segmentation initialized')
|
36 |
+
|
37 |
+
|
38 |
+
inp_model_path = os.path.join(base_path, 'Er0mangaInpaint/pretrained/00-30-09')
|
39 |
+
model_inp = init_inpaint_model(inp_model_path)
|
40 |
+
print('Inpainting initialized')
|
41 |
+
|
42 |
+
|
43 |
+
def proc(input_img):
|
44 |
+
|
45 |
+
try:
|
46 |
+
|
47 |
+
s = time.time()
|
48 |
+
|
49 |
+
out_mask, raw_mask = inference_tta(model_seg, input_img)
|
50 |
+
out_mask = np.dstack([out_mask, out_mask, out_mask])
|
51 |
+
raw_mask = np.dstack([raw_mask, raw_mask, raw_mask])
|
52 |
+
|
53 |
+
output_img, out_dbg = inpaint(model_inp, input_img, out_mask)
|
54 |
+
|
55 |
+
e = time.time()
|
56 |
+
print(f"proc_time: {e-s:.2f}")
|
57 |
+
|
58 |
+
return output_img#, raw_mask
|
59 |
+
|
60 |
+
except Exception as e:
|
61 |
+
raise gr.Error(e)
|
62 |
+
|
63 |
+
|
64 |
+
def proc_batch(batch):
|
65 |
+
|
66 |
+
res = []
|
67 |
+
try:
|
68 |
+
|
69 |
+
s = time.time()
|
70 |
+
|
71 |
+
out_p = os.path.dirname(batch[0][0])
|
72 |
+
salt = str(np.random.randint(1e10))
|
73 |
+
out_p_d = os.path.join(out_p, '__salt_img__'+salt)
|
74 |
+
out_p_m = os.path.join(out_p, '__salt_mask__'+salt)
|
75 |
+
os.mkdir(out_p_d)
|
76 |
+
os.mkdir(out_p_m)
|
77 |
+
|
78 |
+
for i in range(len(batch)):
|
79 |
+
input_path = batch[i][0]
|
80 |
+
inp_name = os.path.basename(input_path)
|
81 |
+
input_img = cv2.cvtColor(cv2.imread(input_path), cv2.COLOR_BGR2RGB)
|
82 |
+
|
83 |
+
out_mask, raw_mask = inference_tta(model_seg, input_img)
|
84 |
+
out_mask = np.dstack([out_mask, out_mask, out_mask])
|
85 |
+
raw_mask = np.dstack([raw_mask, raw_mask, raw_mask])
|
86 |
+
|
87 |
+
output_img, out_dbg = inpaint(model_inp, input_img, out_mask)
|
88 |
+
out_path_img = os.path.join(out_p_d, inp_name)
|
89 |
+
out_path_mask = os.path.join(out_p_m, inp_name+'.png')
|
90 |
+
cv2.imwrite(out_path_img, cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB))
|
91 |
+
cv2.imwrite(out_path_mask, raw_mask)
|
92 |
+
res.append(out_path_img)
|
93 |
+
|
94 |
+
ar_path = os.path.join(out_p, 'output')
|
95 |
+
shutil.make_archive(ar_path, 'zip', out_p_d)
|
96 |
+
|
97 |
+
ar_path_m = os.path.join(out_p, 'output_mask')
|
98 |
+
shutil.make_archive(ar_path_m, 'zip', out_p_m)
|
99 |
+
|
100 |
+
e = time.time()
|
101 |
+
print(f"batch proc_time: {e-s:.2f}")
|
102 |
+
|
103 |
+
return res, ar_path + '.zip', ar_path_m + '.zip'
|
104 |
+
|
105 |
+
except Exception as e:
|
106 |
+
raise gr.Error(e)
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
demo1 = gr.Interface(proc, gr.Image(), gr.Image(format='png'), delete_cache=(7200, 7200), allow_flagging='never')
|
111 |
+
demo2 = gr.Interface(proc_batch, gr.Gallery(), [gr.Gallery(value='str', format='png'), gr.File(), gr.File()], delete_cache=(7200, 7200), allow_flagging='never')
|
112 |
+
demo = gr.TabbedInterface([demo1, demo2], ["Single image processing", "Batch processing (experimental)"])
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
demo.launch(server_name='0.0.0.0', server_port=7860)
|
116 |
+
|
117 |
+
|