Fixed live demo app, converted network for ONNX conversion, fixed code
- .gitignore +18 -0
- app.py +4 -49
- config/test.yaml +3 -2
- converter.py +35 -0
- live_demo.py +71 -0
- live_mp.py +112 -0
- output.mp4 +0 -0
- requirements.txt +3 -1
- tools/__init__.py +1 -1
- tools/engine.py +12 -0
- tools/inference.py +48 -9
- tools/model.py +14 -5
- tools/stylematte.py +1 -21
- tools/util.py +49 -6
.gitignore
CHANGED
@@ -173,3 +173,21 @@ __pycache__/*
 flagged/
 # assets/
 .DS_store
+
+
+config/*
+trainer/__pycache__/
+trainer/__pycache__/*
+__pycache__/*
+checkpoints/*.pth
+*/*.pth
+*/checkpoints/best_pure.pth
+checkpoints/best_pure.pth
+*.ipynb
+.ipynb_checkpoints/*
+flagged/
+assets/*
+*.html
+checkpoints/*.onnx
+*.avi
+*.onnx
app.py
CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from tools import Inference, Matting, log
+from tools import Inference, Matting, log, extract_matte, harmonize, css
 from omegaconf import OmegaConf
 import os
 import sys
@@ -9,56 +9,11 @@ from PIL import Image
 
 args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
 
-global_comp = None
-global_mask = None
-
 log("Model loading")
 phnet = Inference(**args)
 stylematte = Matting(**args)
 log("Model loaded")
 
-
-def harmonize(comp, mask):
-    log("Inference started")
-    if comp is None or mask is None:
-        log("Empty source")
-        return np.zeros((16, 16, 3))
-
-    comp = comp.convert('RGB')
-    mask = mask.convert('1')
-    in_shape = comp.size[::-1]
-
-    comp = tf.resize(comp, [args.image_size, args.image_size])
-    mask = tf.resize(mask, [args.image_size, args.image_size])
-
-    compt = tf.to_tensor(comp)
-    maskt = tf.to_tensor(mask)
-    res = phnet.harmonize(compt, maskt)
-    res = tf.resize(res, in_shape)
-
-    log("Inference finished")
-
-    return np.uint8((res*255)[0].permute(1, 2, 0).numpy())
-
-
-def extract_matte(img, back):
-    mask, fg = stylematte.extract(img)
-    fg_pil = Image.fromarray(np.uint8(fg))
-
-    composite = fg + (1 - mask[:, :, None]) * \
-        np.array(back.resize(mask.shape[::-1]))
-    composite_pil = Image.fromarray(np.uint8(composite))
-
-    global_comp = composite_pil
-    global_mask = mask
-
-    return [composite_pil, mask, fg_pil]
-
-
-def css(height=3, scale=2):
-    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"
-
-
 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -97,11 +52,11 @@ with gr.Blocks() as demo:
     harmonized_ui = gr.Image(
         type="pil", label='Harmonized composite', css=css(3, 3))
 
-    btn_compose.click(extract_matte, inputs=[input_ui, back_ui], outputs=[
+    btn_compose.click(lambda x, y: extract_matte(x, y, stylematte), inputs=[input_ui, back_ui], outputs=[
                       composite_ui, matte_ui, fg_ui])
-    btn_harmonize.click(harmonize, inputs=[
+    btn_harmonize.click(lambda x, y: harmonize(x, y, phnet), inputs=[
                         composite_ui, matte_ui], outputs=[harmonized_ui])
 
 
 log("Interface created")
-demo.launch(share=
+demo.launch(share=False)
config/test.yaml
CHANGED
@@ -5,7 +5,8 @@ init_value: 0.8
 skips: True
 device: 'cpu'
 checkpoint:
-
+  matting_onnx: "checkpoints/stylematte_720.onnx"
+  harmonizer: "checkpoints/ffhqh1024.pth"
   matting: "checkpoints/stylematte.pth"
-
+onnx: False
 image_size: 1024
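
Aside (not part of the commit): a minimal sketch of how these new keys are read via OmegaConf, assuming the config above.

from omegaconf import OmegaConf

args = OmegaConf.load("./config/test.yaml")
# onnx toggles which matting checkpoint gets used; both paths come from the YAML above
ckpt = args.checkpoint.matting_onnx if args.onnx else args.checkpoint.matting
print(args.onnx, ckpt)  # False checkpoints/stylematte.pth
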
converter.py
ADDED
@@ -0,0 +1,35 @@
+import torchvision
+import io
+import numpy as np
+import torch.onnx
+import onnx
+from tools import Inference, Matting, log, extract_matte, harmonize, css, execute_onnx_model
+from omegaconf import OmegaConf
+import os
+import sys
+import torch
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+from onnxruntime import InferenceSession
+
+args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+log("Model loading")
+phnet = Inference(**args)
+stylematte = Matting(**args)
+log("Model loaded")
+model = stylematte.model
+
+x = torch.randn((1, 3, 720, 1280))
+mask = torch.ones((1, 1, 512, 512))
+path = 'checkpoints/stylematte-test.onnx'
+
+# Export
+torch.onnx.export(model, x, path, opset_version=16)
+
+# Validation
+onnx_model = onnx.load(path)
+onnx.checker.check_model(onnx_model)
+# execute_onnx_model(x, onnx_model)
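
A quick numerical check of the export (not in this commit; a sketch assuming the model, x, and path defined in converter.py above, and that the exported graph has a single input and output):

import numpy as np
import torch
from onnxruntime import InferenceSession

sess = InferenceSession(path, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name  # avoids hard-coding the exported input name
with torch.no_grad():
    ref = model(x).cpu().numpy()        # PyTorch reference output
out = sess.run(None, {input_name: x.numpy().astype(np.float32)})[0]
print("max abs diff:", np.abs(ref - out).max())
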
live_demo.py
ADDED
@@ -0,0 +1,71 @@
+import gradio as gr
+from tools import Inference, Matting, log, extract_matte, harmonize, css, live_matting_step
+from omegaconf import OmegaConf
+import os
+import sys
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+import time
+import asyncio
+
+args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+log("Model loading")
+phnet = Inference(**args)
+stylematte = Matting(**args)
+log("Model loaded")
+
+
+async def show(queue):
+    while True:
+        log("SHOW FRAME")
+        frame = queue.get()
+        cv.imshow('Video', frame)
+        await asyncio.sleep(0.01)
+
+
+async def main(queue):
+    video = cv.VideoCapture(0)
+    fps = 10
+    counter = 0
+    frame_count = 0
+    if not video.isOpened():
+        raise Exception('Video is not opened!')
+    begin = time.time()
+    for i in range(300):
+        counter += 1
+        frame_count += 1
+        ret, frame = video.read()  # Capture frame-by-frame
+        inp = np.array(frame)
+        back = np.zeros_like(frame)
+        queue.put(inp)
+        # res = asyncio.ensure_future(
+        #     live_matting_step(inp, back, stylematte))
+        # res = await live_matting_step(inp, back, stylematte)
+        log(f"{i} await")
+        # Display the resulting frame
+
+        # blurred_frame = cv.blur(frame, (10, 10))
+        end = time.time()
+        log(f"frames: {frame_count}, time: {end - begin}, fps: {frame_count/(end - begin) }")
+
+        if cv.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    end = time.time()
+    log(f"OVERALL TIME CONSUMED: {end - begin}, frames: {frame_count}, fps: {frame_count/(end - begin) }")
+    # release the capture
+    video.release()
+    cv.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    queue = asyncio.Queue()
+    loop = asyncio.get_event_loop()
+    # asyncio.ensure_future(show(frame))  # Display the resulting frame
+
+    loop.run_until_complete(main(queue))
+    loop.run_until_complete(show(queue))
+    loop.run_forever()
live_mp.py
ADDED
@@ -0,0 +1,112 @@
+from multiprocessing import Process, Queue
+import gradio as gr
+from tools import Inference, Matting, log, extract_matte, harmonize, css, live_matting_step
+from omegaconf import OmegaConf
+import os
+import sys
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+import time
+import asyncio
+
+
+def show(queue, stack):
+    print(f"PROCESS {3}")
+    # while not queue.empty():
+    if stack.empty():
+        frame = queue.get()
+    else:
+        frame = stack.get(block=False)
+    cv.imshow('Video', np.uint8(frame))
+    log("PID: 3, SHOW FRAME")
+    print(frame.shape)
+    time.sleep(0.1)
+
+
+def extract(queue, stack, model):
+    '''
+    img: np.array,
+    back: np.array,
+    model: Matting instance
+    '''
+    print(f"PROCESS {2}")
+    img = queue.get()
+    back = np.zeros_like(img)
+    mask, fg = model.extract(img)
+    composite = fg + (1 - mask[:, :, None]) * \
+        back  # .resize(mask.shape[::-1])
+    stack.put(np.uint8(composite))
+    # time.sleep(0.1)
+    print("PID: 2, LIVE STEP")
+    # for i in range(10):
+    #     print(f"In live {i}")
+    # cv.imshow('Video', np.uint8(composite))
+    # return composite
+
+
+def main(queue):
+
+    log(f"PROCESS {1}")
+    video = cv.VideoCapture(0)
+    fps = 10
+    counter = 0
+    frame_count = 0
+    if not video.isOpened():
+        raise Exception('Video is not opened!')
+    begin = time.time()
+    # stack = Queue()
+    for i in range(10):
+        counter += 1
+        frame_count += 1
+        ret, frame = video.read()  # Capture frame-by-frame
+        inp = np.array(frame)
+        back = np.zeros_like(frame)
+        # res = asyncio.ensure_future(
+        #     live_matting_step(inp, back, stylematte))
+        # res = live_matting_step(inp, back, stylematte)
+        queue.put(inp)
+        mp.sleep(0.1)
+        # Display the resulting frame
+
+        # blurred_frame = cv.blur(frame, (10, 10))
+        counter = 0
+        end = time.time()
+        log(f"PID: 1, frames: {frame_count}, time: {end - begin}, fps: {frame_count/(end - begin) }")
+        # else:
+        #     show(queue)  # Display the resulting frame
+
+        if cv.waitKey(1) & 0xFF == ord('q'):
+            break
+    end = time.time()
+    log(f"OVERALL TIME CONSUMED: {end - begin}, frames: {frame_count}, fps: {frame_count/(end - begin) }")
+    # release the capture
+    video.release()
+    cv.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    queue = Queue()  # Create the channel
+    stack = Queue()  # Create the channel
+    # stack = Queue()  # Create the channel
+    args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+    log("Model loading")
+    phnet = Inference(**args)
+    stylematte = Matting(**args)
+    log("Model loaded")
+
+    p1 = Process(target=main, args=(queue,))  # Pass in the parameters
+
+    p2 = Process(target=extract, args=(
+        queue, stack, stylematte))  # Pass in the parameters
+    p3 = Process(target=show, args=(queue, stack))  # Pass in the parameters
+    # p2 = Process(target=test_2, args=("Пончик", queue,))  # Pass in the parameters
+
+    p1.start()
+    p2.start()
+    p3.start()
+    p3.join()
+    p2.join()
+    p1.join()
output.mp4
ADDED
Binary file (258 Bytes).
requirements.txt
CHANGED
@@ -35,4 +35,6 @@ torchaudio==0.11.0
 torchvision==0.12.0
 tornado==6.2
 tqdm==4.64.1
-transformers==4.28.1
+transformers==4.28.1
+onnx==1.14.1
+onnxruntime==1.16.0
tools/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .inference import Inference
 from .inference import Matting
-from .util import
+from .util import *
tools/engine.py
ADDED
@@ -0,0 +1,12 @@
+import onnx
+from onnxruntime import InferenceSession
+import numpy as np
+import torch
+
+
+def execute_onnx_model(x, onnx_model) -> None:
+    sess = InferenceSession(onnx_model.SerializeToString(), providers=[
+        'AzureExecutionProvider', 'CPUExecutionProvider'])
+    out = sess.run(None, {'input.1': x.numpy().astype(np.float32)})[0]
+
+    return out
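
For context, a minimal call sequence (a sketch; the checkpoint path comes from config/test.yaml, and the 'input.1' key inside execute_onnx_model assumes the default input name that torch.onnx.export assigned during export):

import onnx
import torch
from tools.engine import execute_onnx_model

onnx_model = onnx.load("checkpoints/stylematte_720.onnx")  # path from config/test.yaml
dummy = torch.randn(1, 3, 720, 1280)                       # NCHW input, as in converter.py
alpha = execute_onnx_model(dummy, onnx_model)              # numpy array from the first graph output
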
tools/inference.py
CHANGED
@@ -4,6 +4,11 @@ import torchvision.transforms.functional as tf
 from .util import inference_img, log
 from .stylematte import StyleMatte
 import numpy as np
+import onnx
+from .engine import execute_onnx_model
+import cv2
+from torchvision import transforms
+import time
 
 
 class Inference:
@@ -15,7 +20,6 @@ class Inference:
             grid_count=self.grid_counts,
             init_weights=self.init_weights,
             init_value=self.init_value)
-        log(f"checkpoint: {self.checkpoint.harmonizer}")
         state = torch.load(self.checkpoint.harmonizer,
                            map_location=self.device)
 
@@ -29,12 +33,13 @@ class Inference:
         mask = mask.unsqueeze(0)
         composite = tf.resize(composite, [self.image_size, self.image_size])
         mask = tf.resize(mask, [self.image_size, self.image_size])
+
        log(composite.shape, mask.shape)
        with torch.no_grad():
-            harmonized = self.model(composite, mask)['harmonized']
+            harmonized = self.model(composite, mask)  # ['harmonized']
 
        result = harmonized * mask + composite * (1-mask)
-
+
        return result
 
 
@@ -42,15 +47,49 @@ class Matting:
    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
-
-
-
-
-
+        if self.onnx:
+            self.model = onnx.load(self.checkpoint.matting_onnx)
+        else:
+            self.model = StyleMatte().to(self.device)
+            state = torch.load(self.checkpoint.matting,
+                               map_location=self.device)
+            self.model.load_state_dict(state, strict=True)
+            self.model.eval()
 
    def extract(self, inp):
-        mask = inference_img(self.model, inp, self.device)
+        mask = inference_img(self.model, inp, self.device, self.onnx)
        inp_np = np.array(inp)
        fg = mask[:, :, None]*inp_np
 
        return [mask, fg]
+
+
+def inference_img(model, img, device='cpu', onnx=True):
+    beg = time.time()
+    h, w, _ = img.shape
+    # print(img.shape)
+    if h % 8 != 0 or w % 8 != 0:
+        img = cv2.copyMakeBorder(img, 8-h % 8, 0, 8-w %
+                                 8, 0, cv2.BORDER_REFLECT)
+    # print(img.shape)
+
+    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
+    input_t = tensor_img/255.0
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+    input_t = normalize(input_t)
+    input_t = input_t.unsqueeze(0).float()
+    end_p = time.time()
+
+    if onnx:
+        out = execute_onnx_model(input_t, model)
+    else:
+        with torch.no_grad():
+            out = model(input_t).cpu().numpy()
+    end = time.time()
+    log(f"Inference time: {end-beg}, processing time: {end_p-beg}")
+    # print("out",out.shape)
+    result = out[0][:, -h:, -w:]
+    # print(result.shape)
+
+    return result[0]
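
One subtle point in the new inference_img: the reflect padding is applied only to the top and left edges, so the crop out[0][:, -h:, -w:] removes exactly that padding. A small sketch of the shape arithmetic (illustration only):

import numpy as np
import cv2

h, w = 719, 1281                      # arbitrary sizes not divisible by 8
img = np.zeros((h, w, 3), np.uint8)
pad_top, pad_left = 8 - h % 8, 8 - w % 8
padded = cv2.copyMakeBorder(img, pad_top, 0, pad_left, 0, cv2.BORDER_REFLECT)
print(padded.shape)                   # (720, 1288, 3) -- multiples of 8
print(padded[-h:, -w:].shape)         # (719, 1281, 3) -- original size restored
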
tools/model.py
CHANGED
@@ -133,6 +133,16 @@ class Inference_Data(Dataset):
         return self.data_len
 
 
+class MyAdaptiveMaxPool2d(nn.Module):
+    def __init__(self, sz=None):
+        super().__init__()
+
+    def forward(self, x):
+        inp_size = x.size()
+        return nn.functional.max_pool2d(input=x,
+                                        kernel_size=(inp_size[2], inp_size[3]))
+
+
 class SEBlock(nn.Module):
     def __init__(self, channel, reducation=8):
         super(SEBlock, self).__init__()
@@ -152,7 +162,8 @@ class SEBlock(nn.Module):
         y = self.fc(y1).view(b, c, 1, 1)
         r = x*y
         if aux_inp is not None:
-            aux_weitghts =
+            aux_weitghts = MyAdaptiveMaxPool2d(
+                aux_inp.shape[-1]//8)(aux_inp)
             aux_weitghts = nn.Sigmoid()(aux_weitghts.mean(1, keepdim=True))
             tmp = x*aux_weitghts
             tmp_img = (tmp - tmp.min()) / (tmp.max() - tmp.min())
@@ -283,11 +294,9 @@ class PHNet(nn.Module):
             x = self.skip[i](x)
             x = up_layer(x)
 
-
-        return {
-            "harmonized": relighted,  # target prediction
-        }
+        harmonized = F.sigmoid(x)
 
+        return harmonized
 
     def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"], value=False):
        for module in modules:
tools/stylematte.py
CHANGED
@@ -284,10 +284,6 @@ class CenterBlock(nn.Sequential):
 class SegForm(nn.Module):
     def __init__(self):
         super(SegForm, self).__init__()
-        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-        # configuration.num_labels = 1 ## set output as 1
-        # self.model = SegformerForSemanticSegmentation(config=configuration)
-
         self.model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
                                                                        )
 
@@ -303,22 +299,13 @@
 class StyleMatte(nn.Module):
     def __init__(self):
         super(StyleMatte, self).__init__()
-        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-        # configuration.num_labels = 1 ## set output as 1
         self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
         self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained(
             "facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
         self.fgf = FastGuidedFilter()
         self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
-        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
-        # self.register_buffer('image_net_mean', self.mean)
-        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
-        # self.register_buffer('image_net_std', self.std)
 
     def forward(self, image, normalize=False):
-        # if normalize:
-        #     image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))
-
         decoder_out = self.pixel_decoder(image)
         decoder_states = list(decoder_out.decoder_hidden_states)
         decoder_states.append(decoder_out.decoder_last_hidden_state)
@@ -331,18 +318,11 @@
             )
         out = self.conv(out_pure)
         out = self.fgf(image_lr, out, image.mean(
-            1, keepdim=True))
-        # out = nn.Sigmoid()(out)
-        # out = nn.functional.interpolate(out,
-        #                                 scale_factor=4,
-        #                                 mode='bicubic',
-        #                                 align_corners=True
-        #                                 )
+            1, keepdim=True))
 
         return torch.sigmoid(out)
 
     def get_training_params(self):
-        # +list(self.fgf.parameters())
         return list(self.fpn.parameters())+list(self.conv.parameters())
 
 
tools/util.py
CHANGED
@@ -6,13 +6,10 @@ import torch.nn as nn
 from torchvision.utils import make_grid
 import cv2
 from torchvision import transforms, models
+from PIL import Image
+import torchvision.transforms.functional as tf
 
-
-def log(msg, lvl='info'):
-    if lvl == 'info':
-        print(f"***********{msg}****************")
-    if lvl == 'error':
-        print(f"!!! Exception: {msg} !!!")
+# --------------------------------------------Metric tools-------------------------------------------- #
 
 
 def lab_shift(x, invert=False):
@@ -321,6 +318,7 @@ def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor:
     return rgb
 
 
+# --------------------------------------------Inference tools-------------------------------------------- #
 def inference_img(model, img, device='cpu'):
     h, w, _ = img.shape
     # print(img.shape)
@@ -343,3 +341,48 @@ def inference_img(model, img, device='cpu'):
     # print(result.shape)
 
     return result[0]
+
+
+def log(msg, lvl='info'):
+    if lvl == 'info':
+        print(f"***********{msg}****************")
+    if lvl == 'error':
+        print(f"!!! Exception: {msg} !!!")
+
+
+def harmonize(comp, mask, model):
+    log("Inference started")
+    if comp is None or mask is None:
+        log("Empty source")
+        return np.zeros((16, 16, 3))
+
+    comp = comp.convert('RGB')
+    mask = mask.convert('1')
+    in_shape = comp.size[::-1]
+
+    comp = tf.resize(comp, [model.image_size, model.image_size])
+    mask = tf.resize(mask, [model.image_size, model.image_size])
+
+    compt = tf.to_tensor(comp)
+    maskt = tf.to_tensor(mask)
+    res = model.harmonize(compt, maskt)
+    res = tf.resize(res, in_shape)
+
+    log("Inference finished")
+
+    return np.uint8((res*255)[0].permute(1, 2, 0).numpy())
+
+
+def extract_matte(img, back, model):
+    mask, fg = model.extract(img)
+    fg_pil = Image.fromarray(np.uint8(fg))
+
+    composite = fg + (1 - mask[:, :, None]) * \
+        np.array(back.resize(mask.shape[::-1]))
+    composite_pil = Image.fromarray(np.uint8(composite))
+
+    return [composite_pil, mask, fg_pil]
+
+
+def css(height=3, scale=2):
+    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"