befozg committed on
Commit 0891b79 · 1 Parent(s): 0eff2bf

fixed live demo app, converted network for ONNX conversion, fixed code

Files changed (14)
  1. .gitignore +18 -0
  2. app.py +4 -49
  3. config/test.yaml +3 -2
  4. converter.py +35 -0
  5. live_demo.py +71 -0
  6. live_mp.py +112 -0
  7. output.mp4 +0 -0
  8. requirements.txt +3 -1
  9. tools/__init__.py +1 -1
  10. tools/engine.py +12 -0
  11. tools/inference.py +48 -9
  12. tools/model.py +14 -5
  13. tools/stylematte.py +1 -21
  14. tools/util.py +49 -6
.gitignore CHANGED
@@ -173,3 +173,21 @@ __pycache__/*
 flagged/
 # assets/
 .DS_store
+
+
+config/*
+trainer/__pycache__/
+trainer/__pycache__/*
+__pycache__/*
+checkpoints/*.pth
+*/*.pth
+*/checkpoints/best_pure.pth
+checkpoints/best_pure.pth
+*.ipynb
+.ipynb_checkpoints/*
+flagged/
+assets/*
+*.html
+checkpoints/*.onnx
+*.avi
+*.onnx
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from tools import Inference, Matting, log
+from tools import Inference, Matting, log, extract_matte, harmonize, css
 from omegaconf import OmegaConf
 import os
 import sys
@@ -9,56 +9,11 @@ from PIL import Image
 
 args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
 
-global_comp = None
-global_mask = None
-
 log("Model loading")
 phnet = Inference(**args)
 stylematte = Matting(**args)
 log("Model loaded")
 
-
-def harmonize(comp, mask):
-    log("Inference started")
-    if comp is None or mask is None:
-        log("Empty source")
-        return np.zeros((16, 16, 3))
-
-    comp = comp.convert('RGB')
-    mask = mask.convert('1')
-    in_shape = comp.size[::-1]
-
-    comp = tf.resize(comp, [args.image_size, args.image_size])
-    mask = tf.resize(mask, [args.image_size, args.image_size])
-
-    compt = tf.to_tensor(comp)
-    maskt = tf.to_tensor(mask)
-    res = phnet.harmonize(compt, maskt)
-    res = tf.resize(res, in_shape)
-
-    log("Inference finished")
-
-    return np.uint8((res*255)[0].permute(1, 2, 0).numpy())
-
-
-def extract_matte(img, back):
-    mask, fg = stylematte.extract(img)
-    fg_pil = Image.fromarray(np.uint8(fg))
-
-    composite = fg + (1 - mask[:, :, None]) * \
-        np.array(back.resize(mask.shape[::-1]))
-    composite_pil = Image.fromarray(np.uint8(composite))
-
-    global_comp = composite_pil
-    global_mask = mask
-
-    return [composite_pil, mask, fg_pil]
-
-
-def css(height=3, scale=2):
-    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"
-
-
 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -97,11 +52,11 @@ with gr.Blocks() as demo:
     harmonized_ui = gr.Image(
         type="pil", label='Harmonized composite', css=css(3, 3))
 
-    btn_compose.click(extract_matte, inputs=[input_ui, back_ui], outputs=[
+    btn_compose.click(lambda x, y: extract_matte(x, y, stylematte), inputs=[input_ui, back_ui], outputs=[
        composite_ui, matte_ui, fg_ui])
-    btn_harmonize.click(harmonize, inputs=[
+    btn_harmonize.click(lambda x, y: harmonize(x, y, phnet), inputs=[
        composite_ui, matte_ui], outputs=[harmonized_ui])
 
 
 log("Interface created")
-demo.launch(share=True)
+demo.launch(share=False)
config/test.yaml CHANGED
@@ -5,7 +5,8 @@ init_value: 0.8
 skips: True
 device: 'cpu'
 checkpoint:
-  harmonizer: "checkpoints/best_pure.pth"
+  matting_onnx: "checkpoints/stylematte_720.onnx"
+  harmonizer: "checkpoints/ffhqh1024.pth"
   matting: "checkpoints/stylematte.pth"
-
+onnx: False
 image_size: 1024
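
For reference, the new keys read back through OmegaConf exactly as app.py and converter.py load them; a minimal sketch (the values are the ones committed above):

```python
# Minimal sketch: reading the config keys added in this commit.
from omegaconf import OmegaConf

args = OmegaConf.load("./config/test.yaml")
print(args.onnx)                      # False -> Matting keeps the PyTorch checkpoint
print(args.checkpoint.matting_onnx)   # "checkpoints/stylematte_720.onnx"
print(args.checkpoint.harmonizer)     # "checkpoints/ffhqh1024.pth"
```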
converter.py ADDED
@@ -0,0 +1,35 @@
+import torchvision
+import io
+import numpy as np
+import torch.onnx
+import onnx
+from tools import Inference, Matting, log, extract_matte, harmonize, css, execute_onnx_model
+from omegaconf import OmegaConf
+import os
+import sys
+import torch
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+from onnxruntime import InferenceSession
+
+args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+log("Model loading")
+phnet = Inference(**args)
+stylematte = Matting(**args)
+log("Model loaded")
+model = stylematte.model
+
+x = torch.randn((1, 3, 720, 1280))
+mask = torch.ones((1, 1, 512, 512))
+path = 'checkpoints/stylematte-test.onnx'
+
+# Export
+torch.onnx.export(model, x, path, opset_version=16)
+
+# Validation
+onnx_model = onnx.load(path)
+onnx.checker.check_model(onnx_model)
+# execute_onnx_model(x, onnx_model)
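
converter.py stops at the ONNX checker; a natural follow-up is to compare the exported graph against the PyTorch network on the same input. A minimal sketch, not part of the commit, continuing from converter.py's `model`, `x`, and `path`, assuming `onnx: False` in the config (so `stylematte.model` is the PyTorch network) and a single graph input; tolerances are illustrative:

```python
# Hypothetical numerical check of the ONNX export against the PyTorch model.
import numpy as np
import torch
from onnxruntime import InferenceSession

sess = InferenceSession(path, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name  # avoid hard-coding 'input.1'

with torch.no_grad():
    torch_out = model(x).cpu().numpy()

onnx_out = sess.run(None, {input_name: x.numpy().astype(np.float32)})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-2, atol=1e-3)
```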
live_demo.py ADDED
@@ -0,0 +1,71 @@
+import gradio as gr
+from tools import Inference, Matting, log, extract_matte, harmonize, css, live_matting_step
+from omegaconf import OmegaConf
+import os
+import sys
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+import time
+import asyncio
+
+args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+log("Model loading")
+phnet = Inference(**args)
+stylematte = Matting(**args)
+log("Model loaded")
+
+
+async def show(queue):
+    while True:
+        log("SHOW FRAME")
+        frame = queue.get()
+        cv.imshow('Video', frame)
+        await asyncio.sleep(0.01)
+
+
+async def main(queue):
+    video = cv.VideoCapture(0)
+    fps = 10
+    counter = 0
+    frame_count = 0
+    if not video.isOpened():
+        raise Exception('Video is not opened!')
+    begin = time.time()
+    for i in range(300):
+        counter += 1
+        frame_count += 1
+        ret, frame = video.read()  # Capture frame-by-frame
+        inp = np.array(frame)
+        back = np.zeros_like(frame)
+        queue.put(inp)
+        # res = asyncio.ensure_future(
+        #     live_matting_step(inp, back, stylematte))
+        # res = await live_matting_step(inp, back, stylematte)
+        log(f"{i} await")
+        # Display the resulting frame
+
+        # blurred_frame = cv.blur(frame, (10, 10))
+        end = time.time()
+        log(f"frames: {frame_count}, time: {end - begin}, fps: {frame_count/(end - begin) }")
+
+        if cv.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    end = time.time()
+    log(f"OVERALL TIME CONSUMED: {end - begin}, frames: {frame_count}, fps: {frame_count/(end - begin) }")
+    # release the capture
+    video.release()
+    cv.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    queue = asyncio.Queue()
+    loop = asyncio.get_event_loop()
+    # asyncio.ensure_future(show(frame))  # Display the resulting frame
+
+    loop.run_until_complete(main(queue))
+    loop.run_until_complete(show(queue))
+    loop.run_forever()
live_mp.py ADDED
@@ -0,0 +1,112 @@
+from multiprocessing import Process, Queue
+import gradio as gr
+from tools import Inference, Matting, log, extract_matte, harmonize, css, live_matting_step
+from omegaconf import OmegaConf
+import os
+import sys
+import numpy as np
+import torchvision.transforms.functional as tf
+from PIL import Image
+import cv2 as cv
+import time
+import asyncio
+
+
+def show(queue, stack):
+    print(f"PROCESS {3}")
+    # while not queue.empty():
+    if stack.empty():
+        frame = queue.get()
+    else:
+        frame = stack.get(block=False)
+    cv.imshow('Video', np.uint8(frame))
+    log("PID: 3, SHOW FRAME")
+    print(frame.shape)
+    time.sleep(0.1)
+
+
+def extract(queue, stack, model):
+    '''
+    img: np.array,
+    back: np.array,
+    model: Matting instance
+    '''
+    print(f"PROCESS {2}")
+    img = queue.get()
+    back = np.zeros_like(img)
+    mask, fg = model.extract(img)
+    composite = fg + (1 - mask[:, :, None]) * \
+        back  # .resize(mask.shape[::-1])
+    stack.put(np.uint8(composite))
+    # time.sleep(0.1)
+    print("PID: 2, LIVE STEP")
+    # for i in range(10):
+    #     print(f"In live {i}")
+    # cv.imshow('Video', np.uint8(composite))
+    # return composite
+
+
+def main(queue):
+
+    log(f"PROCESS {1}")
+    video = cv.VideoCapture(0)
+    fps = 10
+    counter = 0
+    frame_count = 0
+    if not video.isOpened():
+        raise Exception('Video is not opened!')
+    begin = time.time()
+    # stack = Queue()
+    for i in range(10):
+        counter += 1
+        frame_count += 1
+        ret, frame = video.read()  # Capture frame-by-frame
+        inp = np.array(frame)
+        back = np.zeros_like(frame)
+        # res = asyncio.ensure_future(
+        #     live_matting_step(inp, back, stylematte))
+        # res = live_matting_step(inp, back, stylematte)
+        queue.put(inp)
+        mp.sleep(0.1)
+        # Display the resulting frame
+
+        # blurred_frame = cv.blur(frame, (10, 10))
+        counter = 0
+        end = time.time()
+        log(f"PID: 1, frames: {frame_count}, time: {end - begin}, fps: {frame_count/(end - begin) }")
+        # else:
+        #     show(queue)  # Display the resulting frame
+
+        if cv.waitKey(1) & 0xFF == ord('q'):
+            break
+    end = time.time()
+    log(f"OVERALL TIME CONSUMED: {end - begin}, frames: {frame_count}, fps: {frame_count/(end - begin) }")
+    # release the capture
+    video.release()
+    cv.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    queue = Queue()  # Create the queue
+    stack = Queue()  # Create the queue
+    # stack = Queue()  # Create the queue
+    args = OmegaConf.load(os.path.join(f"./config/test.yaml"))
+
+    log("Model loading")
+    phnet = Inference(**args)
+    stylematte = Matting(**args)
+    log("Model loaded")
+
+    p1 = Process(target=main, args=(queue,))  # Pass in the parameters
+
+    p2 = Process(target=extract, args=(
+        queue, stack, stylematte))  # Pass in the parameters
+    p3 = Process(target=show, args=(queue, stack))  # Pass in the parameters
+    # p2 = Process(target=test_2, args=("Пончик", queue,))  # Pass in the parameters
+
+    p1.start()
+    p2.start()
+    p3.start()
+    p3.join()
+    p2.join()
+    p1.join()
output.mp4 ADDED
Binary file (258 Bytes)
 
requirements.txt CHANGED
@@ -35,4 +35,6 @@ torchaudio==0.11.0
 torchvision==0.12.0
 tornado==6.2
 tqdm==4.64.1
-transformers==4.28.1
+transformers==4.28.1
+onnx==1.14.1
+onnxruntime==1.16.0
tools/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .inference import Inference
 from .inference import Matting
-from .util import log
+from .util import *
tools/engine.py ADDED
@@ -0,0 +1,12 @@
+import onnx
+from onnxruntime import InferenceSession
+import numpy as np
+import torch
+
+
+def execute_onnx_model(x, onnx_model) -> None:
+    sess = InferenceSession(onnx_model.SerializeToString(), providers=[
+        'AzureExecutionProvider', 'CPUExecutionProvider'])
+    out = sess.run(None, {'input.1': x.numpy().astype(np.float32)})[0]
+
+    return out
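
A usage sketch for the helper above, assuming the exported checkpoint referenced in config/test.yaml exists and exposes its input under the name 'input.1' that execute_onnx_model hard-codes:

```python
# Sketch: running the exported matting network through tools/engine.py.
import onnx
import torch
from tools.engine import execute_onnx_model

onnx_model = onnx.load("checkpoints/stylematte_720.onnx")
frame = torch.randn(1, 3, 720, 1280)          # NCHW, the shape used for export in converter.py
mask = execute_onnx_model(frame, onnx_model)  # first session output as a numpy array
print(mask.shape)
```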
tools/inference.py CHANGED
@@ -4,6 +4,11 @@ import torchvision.transforms.functional as tf
 from .util import inference_img, log
 from .stylematte import StyleMatte
 import numpy as np
+import onnx
+from .engine import execute_onnx_model
+import cv2
+from torchvision import transforms
+import time
 
 
 class Inference:
@@ -15,7 +20,6 @@
                               grid_count=self.grid_counts,
                               init_weights=self.init_weights,
                               init_value=self.init_value)
-        log(f"checkpoint: {self.checkpoint.harmonizer}")
         state = torch.load(self.checkpoint.harmonizer,
                            map_location=self.device)
 
@@ -29,12 +33,13 @@
         mask = mask.unsqueeze(0)
         composite = tf.resize(composite, [self.image_size, self.image_size])
         mask = tf.resize(mask, [self.image_size, self.image_size])
+
         log(composite.shape, mask.shape)
         with torch.no_grad():
-            harmonized = self.model(composite, mask)['harmonized']
+            harmonized = self.model(composite, mask)  # ['harmonized']
 
         result = harmonized * mask + composite * (1-mask)
-        print(result.shape)
+
         return result
 
 
@@ -42,15 +47,49 @@
     def __init__(self, **kwargs):
         self.rank = 0
         self.__dict__.update(kwargs)
-        self.model = StyleMatte().to(self.device)
-        log(f"checkpoint: {self.checkpoint.matting}")
-        state = torch.load(self.checkpoint.matting, map_location=self.device)
-        self.model.load_state_dict(state, strict=True)
-        self.model.eval()
+        if self.onnx:
+            self.model = onnx.load(self.checkpoint.matting_onnx)
+        else:
+            self.model = StyleMatte().to(self.device)
+            state = torch.load(self.checkpoint.matting,
+                               map_location=self.device)
+            self.model.load_state_dict(state, strict=True)
+            self.model.eval()
 
     def extract(self, inp):
-        mask = inference_img(self.model, inp, self.device)
+        mask = inference_img(self.model, inp, self.device, self.onnx)
         inp_np = np.array(inp)
         fg = mask[:, :, None]*inp_np
 
         return [mask, fg]
+
+
+def inference_img(model, img, device='cpu', onnx=True):
+    beg = time.time()
+    h, w, _ = img.shape
+    # print(img.shape)
+    if h % 8 != 0 or w % 8 != 0:
+        img = cv2.copyMakeBorder(img, 8-h % 8, 0, 8-w %
+                                 8, 0, cv2.BORDER_REFLECT)
+    # print(img.shape)
+
+    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
+    input_t = tensor_img/255.0
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+    input_t = normalize(input_t)
+    input_t = input_t.unsqueeze(0).float()
+    end_p = time.time()
+
+    if onnx:
+        out = execute_onnx_model(input_t, model)
+    else:
+        with torch.no_grad():
+            out = model(input_t).cpu().numpy()
+    end = time.time()
+    log(f"Inference time: {end-beg}, processing time: {end_p-beg}")
+    # print("out",out.shape)
+    result = out[0][:, -h:, -w:]
+    # print(result.shape)
+
+    return result[0]
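
With the flag wired through Matting.__init__ and inference_img, switching between the PyTorch and onnxruntime paths is a config toggle. A minimal sketch, assuming the checkpoints referenced in config/test.yaml are present, the ONNX graph was exported for 720x1280 inputs, and its input name matches what execute_onnx_model expects:

```python
# Sketch: toggling the new ONNX inference path from the config.
import numpy as np
from omegaconf import OmegaConf
from tools import Matting

args = OmegaConf.load("./config/test.yaml")
args.onnx = True                                  # route inference_img through onnxruntime
stylematte = Matting(**args)

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # matches the fixed export resolution
mask, fg = stylematte.extract(frame)              # alpha matte and masked foreground
```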
tools/model.py CHANGED
@@ -133,6 +133,16 @@ class Inference_Data(Dataset):
         return self.data_len
 
 
+class MyAdaptiveMaxPool2d(nn.Module):
+    def __init__(self, sz=None):
+        super().__init__()
+
+    def forward(self, x):
+        inp_size = x.size()
+        return nn.functional.max_pool2d(input=x,
+                                        kernel_size=(inp_size[2], inp_size[3]))
+
+
 class SEBlock(nn.Module):
     def __init__(self, channel, reducation=8):
         super(SEBlock, self).__init__()
@@ -152,7 +162,8 @@
         y = self.fc(y1).view(b, c, 1, 1)
         r = x*y
         if aux_inp is not None:
-            aux_weitghts = nn.AdaptiveAvgPool2d(aux_inp.shape[-1]//8)(aux_inp)
+            aux_weitghts = MyAdaptiveMaxPool2d(
+                aux_inp.shape[-1]//8)(aux_inp)
             aux_weitghts = nn.Sigmoid()(aux_weitghts.mean(1, keepdim=True))
             tmp = x*aux_weitghts
             tmp_img = (tmp - tmp.min()) / (tmp.max() - tmp.min())
@@ -283,11 +294,9 @@
             x = self.skip[i](x)
             x = up_layer(x)
 
-        relighted = F.sigmoid(x)
+        harmonized = F.sigmoid(x)
 
-        return {
-            "harmonized": relighted,  # target prediction
-        }
+        return harmonized
 
     def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"], value=False):
         for module in modules:
tools/stylematte.py CHANGED
@@ -284,10 +284,6 @@ class CenterBlock(nn.Sequential):
 class SegForm(nn.Module):
     def __init__(self):
         super(SegForm, self).__init__()
-        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-        # configuration.num_labels = 1 ## set output as 1
-        # self.model = SegformerForSemanticSegmentation(config=configuration)
-
         self.model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
                                                                        )
 
@@ -303,22 +299,13 @@
 class StyleMatte(nn.Module):
     def __init__(self):
         super(StyleMatte, self).__init__()
-        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
-        # configuration.num_labels = 1 ## set output as 1
         self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
         self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained(
             "facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
         self.fgf = FastGuidedFilter()
         self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
-        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
-        # self.register_buffer('image_net_mean', self.mean)
-        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
-        # self.register_buffer('image_net_std', self.std)
 
     def forward(self, image, normalize=False):
-        # if normalize:
-        # image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))
-
         decoder_out = self.pixel_decoder(image)
         decoder_states = list(decoder_out.decoder_hidden_states)
         decoder_states.append(decoder_out.decoder_last_hidden_state)
@@ -331,18 +318,11 @@
         )
         out = self.conv(out_pure)
         out = self.fgf(image_lr, out, image.mean(
-            1, keepdim=True))  # .clip(0,1)
-        # out = nn.Sigmoid()(out)
-        # out = nn.functional.interpolate(out,
-        #                                 scale_factor=4,
-        #                                 mode='bicubic',
-        #                                 align_corners=True
-        #                                 )
+            1, keepdim=True))
 
         return torch.sigmoid(out)
 
     def get_training_params(self):
-        # +list(self.fgf.parameters())
         return list(self.fpn.parameters())+list(self.conv.parameters())
 
 
tools/util.py CHANGED
@@ -6,13 +6,10 @@ import torch.nn as nn
 from torchvision.utils import make_grid
 import cv2
 from torchvision import transforms, models
+from PIL import Image
+import torchvision.transforms.functional as tf
 
-
-def log(msg, lvl='info'):
-    if lvl == 'info':
-        print(f"***********{msg}****************")
-    if lvl == 'error':
-        print(f"!!! Exception: {msg} !!!")
+# --------------------------------------------Metric tools-------------------------------------------- #
 
 
 def lab_shift(x, invert=False):
@@ -321,6 +318,7 @@ def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor:
     return rgb
 
 
+# --------------------------------------------Inference tools-------------------------------------------- #
 def inference_img(model, img, device='cpu'):
     h, w, _ = img.shape
     # print(img.shape)
@@ -343,3 +341,48 @@
     # print(result.shape)
 
     return result[0]
+
+
+def log(msg, lvl='info'):
+    if lvl == 'info':
+        print(f"***********{msg}****************")
+    if lvl == 'error':
+        print(f"!!! Exception: {msg} !!!")
+
+
+def harmonize(comp, mask, model):
+    log("Inference started")
+    if comp is None or mask is None:
+        log("Empty source")
+        return np.zeros((16, 16, 3))
+
+    comp = comp.convert('RGB')
+    mask = mask.convert('1')
+    in_shape = comp.size[::-1]
+
+    comp = tf.resize(comp, [model.image_size, model.image_size])
+    mask = tf.resize(mask, [model.image_size, model.image_size])
+
+    compt = tf.to_tensor(comp)
+    maskt = tf.to_tensor(mask)
+    res = model.harmonize(compt, maskt)
+    res = tf.resize(res, in_shape)
+
+    log("Inference finished")
+
+    return np.uint8((res*255)[0].permute(1, 2, 0).numpy())
+
+
+def extract_matte(img, back, model):
+    mask, fg = model.extract(img)
+    fg_pil = Image.fromarray(np.uint8(fg))
+
+    composite = fg + (1 - mask[:, :, None]) * \
+        np.array(back.resize(mask.shape[::-1]))
+    composite_pil = Image.fromarray(np.uint8(composite))
+
+    return [composite_pil, mask, fg_pil]
+
+
+def css(height=3, scale=2):
+    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"