Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +0 -1
- carvekit/__init__.py +1 -0
- carvekit/__main__.py +149 -0
- carvekit/api/__init__.py +0 -0
- carvekit/api/high.py +100 -0
- carvekit/api/interface.py +77 -0
- carvekit/ml/__init__.py +4 -0
- carvekit/ml/arch/__init__.py +0 -0
- carvekit/ml/arch/basnet/__init__.py +0 -0
- carvekit/ml/arch/basnet/basnet.py +478 -0
- carvekit/ml/arch/fba_matting/__init__.py +0 -0
- carvekit/ml/arch/fba_matting/layers_WS.py +57 -0
- carvekit/ml/arch/fba_matting/models.py +341 -0
- carvekit/ml/arch/fba_matting/resnet_GN_WS.py +151 -0
- carvekit/ml/arch/fba_matting/resnet_bn.py +169 -0
- carvekit/ml/arch/fba_matting/transforms.py +45 -0
- carvekit/ml/arch/tracerb7/__init__.py +0 -0
- carvekit/ml/arch/tracerb7/att_modules.py +290 -0
- carvekit/ml/arch/tracerb7/conv_modules.py +88 -0
- carvekit/ml/arch/tracerb7/effi_utils.py +579 -0
- carvekit/ml/arch/tracerb7/efficientnet.py +325 -0
- carvekit/ml/arch/tracerb7/tracer.py +97 -0
- carvekit/ml/arch/u2net/__init__.py +0 -0
- carvekit/ml/arch/u2net/u2net.py +172 -0
- carvekit/ml/files/__init__.py +7 -0
- carvekit/ml/files/models_loc.py +70 -0
- carvekit/ml/wrap/__init__.py +0 -0
- carvekit/ml/wrap/basnet.py +141 -0
- carvekit/ml/wrap/deeplab_v3.py +150 -0
- carvekit/ml/wrap/fba_matting.py +224 -0
- carvekit/ml/wrap/tracer_b7.py +178 -0
- carvekit/ml/wrap/u2net.py +140 -0
- carvekit/pipelines/__init__.py +0 -0
- carvekit/pipelines/postprocessing.py +76 -0
- carvekit/pipelines/preprocessing.py +28 -0
- carvekit/trimap/__init__.py +0 -0
- carvekit/trimap/add_ops.py +91 -0
- carvekit/trimap/cv_gen.py +64 -0
- carvekit/trimap/generator.py +47 -0
- carvekit/utils/__init__.py +0 -0
- carvekit/utils/download_models.py +214 -0
- carvekit/utils/fs_utils.py +38 -0
- carvekit/utils/image_utils.py +150 -0
- carvekit/utils/mask_utils.py +85 -0
- carvekit/utils/models_utils.py +126 -0
- carvekit/utils/pool_utils.py +40 -0
- carvekit/web/__init__.py +0 -0
- carvekit/web/app.py +30 -0
- carvekit/web/deps.py +6 -0
- carvekit/web/handlers/__init__.py +0 -0
.gitattributes
CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
carvekit/__init__.py
ADDED
@@ -0,0 +1 @@
+version = "4.1.0"
carvekit/__main__.py
ADDED
@@ -0,0 +1,149 @@
+from pathlib import Path
+
+import click
+import tqdm
+
+from carvekit.utils.image_utils import ALLOWED_SUFFIXES
+from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+from carvekit.web.schemas.config import MLConfig
+from carvekit.web.utils.init_utils import init_interface
+from carvekit.utils.fs_utils import save_file
+
+
+@click.command(
+    "removebg",
+    help="Performs background removal on specified photos using console interface.",
+)
+@click.option("-i", required=True, type=str, help="Path to input file or dir")
+@click.option("-o", default="none", type=str, help="Path to output file or dir")
+@click.option("--pre", default="none", type=str, help="Preprocessing method")
+@click.option("--post", default="fba", type=str, help="Postprocessing method.")
+@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
+@click.option(
+    "--recursive",
+    default=False,
+    type=bool,
+    help="Enables recursive search for images in a folder",
+)
+@click.option(
+    "--batch_size",
+    default=10,
+    type=int,
+    help="Batch Size for list of images to be loaded to RAM",
+)
+@click.option(
+    "--batch_size_seg",
+    default=5,
+    type=int,
+    help="Batch size for list of images to be processed by segmentation " "network",
+)
+@click.option(
+    "--batch_size_mat",
+    default=1,
+    type=int,
+    help="Batch size for list of images to be processed by matting " "network",
+)
+@click.option(
+    "--seg_mask_size",
+    default=640,
+    type=int,
+    help="The size of the input image for the segmentation neural network.",
+)
+@click.option(
+    "--matting_mask_size",
+    default=2048,
+    type=int,
+    help="The size of the input image for the matting neural network.",
+)
+@click.option(
+    "--trimap_dilation",
+    default=30,
+    type=int,
+    help="The size of the offset radius from the object mask in "
+    "pixels when forming an unknown area",
+)
+@click.option(
+    "--trimap_erosion",
+    default=5,
+    type=int,
+    help="The number of iterations of erosion that the object's "
+    "mask will be subjected to before forming an unknown area",
+)
+@click.option(
+    "--trimap_prob_threshold",
+    default=231,
+    type=int,
+    help="Probability threshold at which the prob_filter "
+    "and prob_as_unknown_area operations will be "
+    "applied",
+)
+@click.option("--device", default="cpu", type=str, help="Processing Device.")
+@click.option(
+    "--fp16", default=False, type=bool, help="Enables mixed precision processing."
+)
+def removebg(
+    i: str,
+    o: str,
+    pre: str,
+    post: str,
+    net: str,
+    recursive: bool,
+    batch_size: int,
+    batch_size_seg: int,
+    batch_size_mat: int,
+    seg_mask_size: int,
+    matting_mask_size: int,
+    device: str,
+    fp16: bool,
+    trimap_dilation: int,
+    trimap_erosion: int,
+    trimap_prob_threshold: int,
+):
+    out_path = Path(o)
+    input_path = Path(i)
+    if input_path.is_dir():
+        if recursive:
+            all_images = input_path.rglob("*.*")
+        else:
+            all_images = input_path.glob("*.*")
+        all_images = [
+            i
+            for i in all_images
+            if i.suffix.lower() in ALLOWED_SUFFIXES and "_bg_removed" not in i.name
+        ]
+    else:
+        all_images = [input_path]
+
+    interface_config = MLConfig(
+        segmentation_network=net,
+        preprocessing_method=pre,
+        postprocessing_method=post,
+        device=device,
+        batch_size_seg=batch_size_seg,
+        batch_size_matting=batch_size_mat,
+        seg_mask_size=seg_mask_size,
+        matting_mask_size=matting_mask_size,
+        fp16=fp16,
+        trimap_dilation=trimap_dilation,
+        trimap_erosion=trimap_erosion,
+        trimap_prob_threshold=trimap_prob_threshold,
+    )
+
+    interface = init_interface(interface_config)
+
+    for image_batch in tqdm.tqdm(
+        batch_generator(all_images, n=batch_size),
+        total=int(len(all_images) / batch_size),
+        desc="Removing background",
+        unit=" image batch",
+        colour="blue",
+    ):
+        images_without_background = interface(image_batch)  # Remove background
+        thread_pool_processing(
+            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
+            range((len(image_batch))),
+        )  # Drop images to fs
+
+
+if __name__ == "__main__":
+    removebg()
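For reference, a minimal sketch of driving the command defined above programmatically through click's test runner; the input and output paths are hypothetical, and the equivalent shell call would be `python -m carvekit -i ./photos -o ./photos_no_bg --device cpu`.

# Sketch only: assumes the carvekit package (and its model weights) is installed.
from click.testing import CliRunner
from carvekit.__main__ import removebg

runner = CliRunner()
result = runner.invoke(
    removebg, ["-i", "./photos", "-o", "./photos_no_bg", "--device", "cpu"]
)
print(result.exit_code)  # 0 on success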
carvekit/api/__init__.py
ADDED
File without changes
carvekit/api/high.py
ADDED
@@ -0,0 +1,100 @@
+"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+import warnings
+
+from carvekit.api.interface import Interface
+from carvekit.ml.wrap.fba_matting import FBAMatting
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.pipelines.postprocessing import MattingMethod
+from carvekit.trimap.generator import TrimapGenerator
+
+
+class HiInterface(Interface):
+    def __init__(
+        self,
+        object_type: str = "object",
+        batch_size_seg=2,
+        batch_size_matting=1,
+        device="cpu",
+        seg_mask_size=640,
+        matting_mask_size=2048,
+        trimap_prob_threshold=231,
+        trimap_dilation=30,
+        trimap_erosion_iters=5,
+        fp16=False,
+    ):
+        """
+        Initializes High Level interface.
+
+        Args:
+            object_type: Interest object type. Can be "object" or "hairs-like".
+            matting_mask_size: The size of the input image for the matting neural network.
+            seg_mask_size: The size of the input image for the segmentation neural network.
+            batch_size_seg: Number of images processed per one segmentation neural network call.
+            batch_size_matting: Number of images processed per one matting neural network call.
+            device: Processing device
+            fp16: Use half precision. Reduce memory usage and increase speed. Experimental support
+            trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
+            trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area
+            trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area
+
+        Notes:
+            1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also
+            result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in
+            range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and
+            video memory consume. Also, you can change batch size to accelerate background removal, but it also causes
+            extra large video memory consume, if value is too big.
+
+            2. Changing trimap_prob_threshold, trimap_kernel_size, trimap_erosion_iters may improve object edge
+            refining quality,
+        """
+        if object_type == "object":
+            self.u2net = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        elif object_type == "hairs-like":
+            self.u2net = U2NET(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+        else:
+            warnings.warn(
+                f"Unknown object type: {object_type}. Using default object type: object"
+            )
+            self.u2net = TracerUniversalB7(
+                device=device,
+                batch_size=batch_size_seg,
+                input_image_size=seg_mask_size,
+                fp16=fp16,
+            )
+
+        self.fba = FBAMatting(
+            batch_size=batch_size_matting,
+            device=device,
+            input_tensor_size=matting_mask_size,
+            fp16=fp16,
+        )
+        self.trimap_generator = TrimapGenerator(
+            prob_threshold=trimap_prob_threshold,
+            kernel_size=trimap_dilation,
+            erosion_iters=trimap_erosion_iters,
+        )
+        super(HiInterface, self).__init__(
+            pre_pipe=None,
+            seg_pipe=self.u2net,
+            post_pipe=MattingMethod(
+                matting_module=self.fba,
+                trimap_generator=self.trimap_generator,
+                device=device,
+            ),
+            device=device,
+        )
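A minimal usage sketch for the HiInterface class above, assuming the carvekit package and its pretrained weights are available locally; the image file names are hypothetical.

from carvekit.api.high import HiInterface

# Defaults mirror the constructor signature above; "object" selects TracerUniversalB7.
interface = HiInterface(object_type="object", device="cpu", fp16=False)

# __call__ is inherited from Interface and returns PIL.Image.Image instances.
images_without_background = interface(["./photo.jpg"])
images_without_background[0].save("./photo_no_bg.png")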
carvekit/api/interface.py
ADDED
@@ -0,0 +1,77 @@
+"""
+Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: Apache License 2.0
+"""
+from pathlib import Path
+from typing import Union, List, Optional
+
+from PIL import Image
+
+from carvekit.ml.wrap.basnet import BASNET
+from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+from carvekit.ml.wrap.u2net import U2NET
+from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+from carvekit.pipelines.preprocessing import PreprocessingStub
+from carvekit.pipelines.postprocessing import MattingMethod
+from carvekit.utils.image_utils import load_image
+from carvekit.utils.mask_utils import apply_mask
+from carvekit.utils.pool_utils import thread_pool_processing
+
+
+class Interface:
+    def __init__(
+        self,
+        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
+        pre_pipe: Optional[Union[PreprocessingStub]] = None,
+        post_pipe: Optional[Union[MattingMethod]] = None,
+        device="cpu",
+    ):
+        """
+        Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+        Args:
+            pre_pipe: Initialized pre-processing pipeline object
+            seg_pipe: Initialized segmentation network object
+            post_pipe: Initialized postprocessing pipeline object
+            device: The processing device that will be used to apply the masks to the images.
+        """
+        self.device = device
+        self.preprocessing_pipeline = pre_pipe
+        self.segmentation_pipeline = seg_pipe
+        self.postprocessing_pipeline = post_pipe
+
+    def __call__(
+        self, images: List[Union[str, Path, Image.Image]]
+    ) -> List[Image.Image]:
+        """
+        Removes the background from the specified images.
+
+        Args:
+            images: list of input images
+
+        Returns:
+            List of images without background as PIL.Image.Image instances
+        """
+        images = thread_pool_processing(load_image, images)
+        if self.preprocessing_pipeline is not None:
+            masks: List[Image.Image] = self.preprocessing_pipeline(
+                interface=self, images=images
+            )
+        else:
+            masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+        if self.postprocessing_pipeline is not None:
+            images: List[Image.Image] = self.postprocessing_pipeline(
+                images=images, masks=masks
+            )
+        else:
+            images = list(
+                map(
+                    lambda x: apply_mask(
+                        image=images[x], mask=masks[x], device=self.device
+                    ),
+                    range(len(images)),
+                )
+            )
+        return images
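To illustrate the composition pattern above, a sketch of a segmentation-only pipeline (no pre- or post-processing); it assumes the U2NET wrapper accepts the same constructor arguments used in high.py, and the image path is hypothetical.

from carvekit.api.interface import Interface
from carvekit.ml.wrap.u2net import U2NET

seg_net = U2NET(device="cpu", batch_size=1, fp16=False)
interface = Interface(seg_pipe=seg_net, pre_pipe=None, post_pipe=None, device="cpu")

# With post_pipe=None the raw segmentation masks are applied via apply_mask.
result = interface(["./portrait.jpg"])[0]
result.save("./portrait_no_bg.png")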
carvekit/ml/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from carvekit.utils.models_utils import fix_seed, suppress_warnings
+
+fix_seed()
+suppress_warnings()
carvekit/ml/arch/__init__.py
ADDED
File without changes
carvekit/ml/arch/basnet/__init__.py
ADDED
File without changes
carvekit/ml/arch/basnet/basnet.py
ADDED
@@ -0,0 +1,478 @@
+"""
+Source url: https://github.com/NathanUA/BASNet
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockDe(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlockDe, self).__init__()
+
+        self.convRes = conv3x3(inplanes, planes, stride)
+        self.bnRes = nn.BatchNorm2d(planes)
+        self.reluRes = nn.ReLU(inplace=True)
+
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = self.convRes(x)
+        residual = self.bnRes(residual)
+        residual = self.reluRes(residual)
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class RefUnet(nn.Module):
+    def __init__(self, in_ch, inc_ch):
+        super(RefUnet, self).__init__()
+
+        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)
+
+        self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn4 = nn.BatchNorm2d(64)
+        self.relu4 = nn.ReLU(inplace=True)
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn5 = nn.BatchNorm2d(64)
+        self.relu5 = nn.ReLU(inplace=True)
+
+        self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d4 = nn.BatchNorm2d(64)
+        self.relu_d4 = nn.ReLU(inplace=True)
+
+        self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d3 = nn.BatchNorm2d(64)
+        self.relu_d3 = nn.ReLU(inplace=True)
+
+        self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d2 = nn.BatchNorm2d(64)
+        self.relu_d2 = nn.ReLU(inplace=True)
+
+        self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn_d1 = nn.BatchNorm2d(64)
+        self.relu_d1 = nn.ReLU(inplace=True)
+
+        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
+
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+    def forward(self, x):
+        hx = x
+        hx = self.conv0(hx)
+
+        hx1 = self.relu1(self.bn1(self.conv1(hx)))
+        hx = self.pool1(hx1)
+
+        hx2 = self.relu2(self.bn2(self.conv2(hx)))
+        hx = self.pool2(hx2)
+
+        hx3 = self.relu3(self.bn3(self.conv3(hx)))
+        hx = self.pool3(hx3)
+
+        hx4 = self.relu4(self.bn4(self.conv4(hx)))
+        hx = self.pool4(hx4)
+
+        hx5 = self.relu5(self.bn5(self.conv5(hx)))
+
+        hx = self.upscore2(hx5)
+
+        d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
+        hx = self.upscore2(d4)
+
+        d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
+        hx = self.upscore2(d3)
+
+        d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
+        hx = self.upscore2(d2)
+
+        d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
+
+        residual = self.conv_d0(d1)
+
+        return x + residual
+
+
+class BASNet(nn.Module):
+    def __init__(self, n_channels, n_classes):
+        super(BASNet, self).__init__()
+
+        resnet = models.resnet34(pretrained=False)
+
+        # -------------Encoder--------------
+
+        self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
+        self.inbn = nn.BatchNorm2d(64)
+        self.inrelu = nn.ReLU(inplace=True)
+
+        # stage 1
+        self.encoder1 = resnet.layer1  # 224
+        # stage 2
+        self.encoder2 = resnet.layer2  # 112
+        # stage 3
+        self.encoder3 = resnet.layer3  # 56
+        # stage 4
+        self.encoder4 = resnet.layer4  # 28
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 5
+        self.resb5_1 = BasicBlock(512, 512)
+        self.resb5_2 = BasicBlock(512, 512)
+        self.resb5_3 = BasicBlock(512, 512)  # 14
+
+        self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+        # stage 6
+        self.resb6_1 = BasicBlock(512, 512)
+        self.resb6_2 = BasicBlock(512, 512)
+        self.resb6_3 = BasicBlock(512, 512)  # 7
+
+        # -------------Bridge--------------
+
+        # stage Bridge
+        self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)  # 7
+        self.bnbg_1 = nn.BatchNorm2d(512)
+        self.relubg_1 = nn.ReLU(inplace=True)
+        self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_m = nn.BatchNorm2d(512)
+        self.relubg_m = nn.ReLU(inplace=True)
+        self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bnbg_2 = nn.BatchNorm2d(512)
+        self.relubg_2 = nn.ReLU(inplace=True)
+
+        # -------------Decoder--------------
+
+        # stage 6d
+        self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn6d_1 = nn.BatchNorm2d(512)
+        self.relu6d_1 = nn.ReLU(inplace=True)
+
+        self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_m = nn.BatchNorm2d(512)
+        self.relu6d_m = nn.ReLU(inplace=True)
+
+        self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
+        self.bn6d_2 = nn.BatchNorm2d(512)
+        self.relu6d_2 = nn.ReLU(inplace=True)
+
+        # stage 5d
+        self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
+        self.bn5d_1 = nn.BatchNorm2d(512)
+        self.relu5d_1 = nn.ReLU(inplace=True)
+
+        self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_m = nn.BatchNorm2d(512)
+        self.relu5d_m = nn.ReLU(inplace=True)
+
+        self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn5d_2 = nn.BatchNorm2d(512)
+        self.relu5d_2 = nn.ReLU(inplace=True)
+
+        # stage 4d
+        self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 32
+        self.bn4d_1 = nn.BatchNorm2d(512)
+        self.relu4d_1 = nn.ReLU(inplace=True)
+
+        self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1)
+        self.bn4d_m = nn.BatchNorm2d(512)
+        self.relu4d_m = nn.ReLU(inplace=True)
+
+        self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1)
+        self.bn4d_2 = nn.BatchNorm2d(256)
+        self.relu4d_2 = nn.ReLU(inplace=True)
+
+        # stage 3d
+        self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1)  # 64
+        self.bn3d_1 = nn.BatchNorm2d(256)
+        self.relu3d_1 = nn.ReLU(inplace=True)
+
+        self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1)
+        self.bn3d_m = nn.BatchNorm2d(256)
+        self.relu3d_m = nn.ReLU(inplace=True)
+
+        self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1)
+        self.bn3d_2 = nn.BatchNorm2d(128)
+        self.relu3d_2 = nn.ReLU(inplace=True)
+
+        # stage 2d
+
+        self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1)  # 128
+        self.bn2d_1 = nn.BatchNorm2d(128)
+        self.relu2d_1 = nn.ReLU(inplace=True)
+
+        self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1)
+        self.bn2d_m = nn.BatchNorm2d(128)
+        self.relu2d_m = nn.ReLU(inplace=True)
+
+        self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.bn2d_2 = nn.BatchNorm2d(64)
+        self.relu2d_2 = nn.ReLU(inplace=True)
+
+        # stage 1d
+        self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1)  # 256
+        self.bn1d_1 = nn.BatchNorm2d(64)
+        self.relu1d_1 = nn.ReLU(inplace=True)
+
+        self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_m = nn.BatchNorm2d(64)
+        self.relu1d_m = nn.ReLU(inplace=True)
+
+        self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn1d_2 = nn.BatchNorm2d(64)
+        self.relu1d_2 = nn.ReLU(inplace=True)
+
+        # -------------Bilinear Upsampling--------------
+        self.upscore6 = nn.Upsample(
+            scale_factor=32, mode="bilinear", align_corners=False
+        )
+        self.upscore5 = nn.Upsample(
+            scale_factor=16, mode="bilinear", align_corners=False
+        )
+        self.upscore4 = nn.Upsample(
+            scale_factor=8, mode="bilinear", align_corners=False
+        )
+        self.upscore3 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=False
+        )
+        self.upscore2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        # -------------Side Output--------------
+        self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
+        self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
+        self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
+        self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
+        self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)
+
+        # -------------Refine Module-------------
+        self.refunet = RefUnet(1, 64)
+
+    def forward(self, x):
+        hx = x
+
+        # -------------Encoder-------------
+        hx = self.inconv(hx)
+        hx = self.inbn(hx)
+        hx = self.inrelu(hx)
+
+        h1 = self.encoder1(hx)  # 256
+        h2 = self.encoder2(h1)  # 128
+        h3 = self.encoder3(h2)  # 64
+        h4 = self.encoder4(h3)  # 32
+
+        hx = self.pool4(h4)  # 16
+
+        hx = self.resb5_1(hx)
+        hx = self.resb5_2(hx)
+        h5 = self.resb5_3(hx)
+
+        hx = self.pool5(h5)  # 8
+
+        hx = self.resb6_1(hx)
+        hx = self.resb6_2(hx)
+        h6 = self.resb6_3(hx)
+
+        # -------------Bridge-------------
+        hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
+        hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
+        hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
+
+        # -------------Decoder-------------
+
+        hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
+        hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
+        hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
+
+        hx = self.upscore2(hd6)  # 8 -> 16
+
+        hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
+        hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
+        hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
+
+        hx = self.upscore2(hd5)  # 16 -> 32
+
+        hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
+        hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
+        hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
+
+        hx = self.upscore2(hd4)  # 32 -> 64
+
+        hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
+        hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
+        hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
+
+        hx = self.upscore2(hd3)  # 64 -> 128
+
+        hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
+        hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
+        hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
+
+        hx = self.upscore2(hd2)  # 128 -> 256
+
+        hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
+        hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
+        hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
+
+        # -------------Side Output-------------
+        db = self.outconvb(hbg)
+        db = self.upscore6(db)  # 8->256
+
+        d6 = self.outconv6(hd6)
+        d6 = self.upscore6(d6)  # 8->256
+
+        d5 = self.outconv5(hd5)
+        d5 = self.upscore5(d5)  # 16->256
+
+        d4 = self.outconv4(hd4)
+        d4 = self.upscore4(d4)  # 32->256
+
+        d3 = self.outconv3(hd3)
+        d3 = self.upscore3(d3)  # 64->256
+
+        d2 = self.outconv2(hd2)
+        d2 = self.upscore2(d2)  # 128->256
+
+        d1 = self.outconv1(hd1)  # 256
+
+        # -------------Refine Module-------------
+        dout = self.refunet(d1)  # 256
+
+        return (
+            torch.sigmoid(dout),
+            torch.sigmoid(d1),
+            torch.sigmoid(d2),
+            torch.sigmoid(d3),
+            torch.sigmoid(d4),
+            torch.sigmoid(d5),
+            torch.sigmoid(d6),
+            torch.sigmoid(db),
+        )
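A quick shape check for the BASNet module above: a dummy forward pass returns eight single-channel sigmoid maps (the refined output followed by seven side outputs), each at the input resolution. A minimal sketch, with no pretrained weights loaded:

import torch
from carvekit.ml.arch.basnet.basnet import BASNet

net = BASNet(n_channels=3, n_classes=1).eval()
with torch.no_grad():
    outputs = net(torch.randn(1, 3, 256, 256))

print(len(outputs), outputs[0].shape)  # 8 torch.Size([1, 1, 256, 256])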
carvekit/ml/arch/fba_matting/__init__.py
ADDED
File without changes
carvekit/ml/arch/fba_matting/layers_WS.py
ADDED
@@ -0,0 +1,57 @@
+"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class Conv2d(nn.Conv2d):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super(Conv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def forward(self, x):
+        # return super(Conv2d, self).forward(x)
+        weight = self.weight
+        weight_mean = (
+            weight.mean(dim=1, keepdim=True)
+            .mean(dim=2, keepdim=True)
+            .mean(dim=3, keepdim=True)
+        )
+        weight = weight - weight_mean
+        # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        std = (
+            torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(
+                -1, 1, 1, 1
+            )
+            + 1e-5
+        )
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(
+            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+        )
+
+
+def BatchNorm2d(num_features):
+    return nn.GroupNorm(num_channels=num_features, num_groups=32)
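A small sketch showing the weight-standardized convolution above used as a drop-in nn.Conv2d replacement, paired with the GroupNorm-backed BatchNorm2d helper:

import torch
from carvekit.ml.arch.fba_matting.layers_WS import Conv2d, BatchNorm2d

conv = Conv2d(3, 32, kernel_size=3, padding=1)   # weights standardized inside forward()
norm = BatchNorm2d(32)                           # actually GroupNorm with 32 groups, as defined above

x = torch.randn(1, 3, 64, 64)
y = norm(conv(x))
print(y.shape)  # torch.Size([1, 32, 64, 64])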
carvekit/ml/arch/fba_matting/models.py
ADDED
@@ -0,0 +1,341 @@
+"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch
+import torch.nn as nn
+import carvekit.ml.arch.fba_matting.resnet_GN_WS as resnet_GN_WS
+import carvekit.ml.arch.fba_matting.layers_WS as L
+import carvekit.ml.arch.fba_matting.resnet_bn as resnet_bn
+from functools import partial
+
+
+class FBA(nn.Module):
+    def __init__(self, encoder: str):
+        super(FBA, self).__init__()
+        self.encoder = build_encoder(arch=encoder)
+        self.decoder = fba_decoder(batch_norm=True if "BN" in encoder else False)
+
+    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
+        resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
+        conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
+        return self.decoder(conv_out, image, indices, two_chan_trimap)
+
+
+class ResnetDilatedBN(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilatedBN, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+
+class Resnet(nn.Module):
+    def __init__(self, orig_resnet):
+        super(Resnet, self).__init__()
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu1 = orig_resnet.relu1
+        self.conv2 = orig_resnet.conv2
+        self.bn2 = orig_resnet.bn2
+        self.relu2 = orig_resnet.relu2
+        self.conv3 = orig_resnet.conv3
+        self.bn3 = orig_resnet.bn3
+        self.relu3 = orig_resnet.relu3
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = []
+
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out
+        return [x]
+
+
+class ResnetDilated(nn.Module):
+    def __init__(self, orig_resnet, dilate_scale=8):
+        super(ResnetDilated, self).__init__()
+
+        if dilate_scale == 8:
+            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+        elif dilate_scale == 16:
+            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+
+        # take pretrained resnet, except AvgPool and FC
+        self.conv1 = orig_resnet.conv1
+        self.bn1 = orig_resnet.bn1
+        self.relu = orig_resnet.relu
+        self.maxpool = orig_resnet.maxpool
+        self.layer1 = orig_resnet.layer1
+        self.layer2 = orig_resnet.layer2
+        self.layer3 = orig_resnet.layer3
+        self.layer4 = orig_resnet.layer4
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            # the convolution with stride
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convoluions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def forward(self, x, return_feature_maps=False):
+        conv_out = [x]
+        x = self.relu(self.bn1(self.conv1(x)))
+        conv_out.append(x)
+        x, indices = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+        if return_feature_maps:
+            return conv_out, indices
+        return [x]
+
+
+def norm(dim, bn=False):
+    if bn is False:
+        return nn.GroupNorm(32, dim)
+    else:
+        return nn.BatchNorm2d(dim)
+
+
+def fba_fusion(alpha, img, F, B):
+    F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
+    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
+
+    F = torch.clamp(F, 0, 1)
+    B = torch.clamp(B, 0, 1)
+    la = 0.1
+    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
+        torch.sum((F - B) * (F - B), 1, keepdim=True) + la
+    )
+    alpha = torch.clamp(alpha, 0, 1)
+    return alpha, F, B
+
+
+class fba_decoder(nn.Module):
+    def __init__(self, batch_norm=False):
+        super(fba_decoder, self).__init__()
+        pool_scales = (1, 2, 3, 6)
+        self.batch_norm = batch_norm
+
+        self.ppm = []
+
+        for scale in pool_scales:
+            self.ppm.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(scale),
+                    L.Conv2d(2048, 256, kernel_size=1, bias=True),
+                    norm(256, self.batch_norm),
+                    nn.LeakyReLU(),
+                )
+            )
+        self.ppm = nn.ModuleList(self.ppm)
+
+        self.conv_up1 = nn.Sequential(
+            L.Conv2d(
+                2048 + len(pool_scales) * 256, 256, kernel_size=3, padding=1, bias=True
+            ),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+            L.Conv2d(256, 256, kernel_size=3, padding=1),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.conv_up2 = nn.Sequential(
+            L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
+            norm(256, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+        if self.batch_norm:
+            d_up3 = 128
+        else:
+            d_up3 = 64
+        self.conv_up3 = nn.Sequential(
+            L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True),
+            norm(64, self.batch_norm),
+            nn.LeakyReLU(),
+        )
+
+        self.unpool = nn.MaxUnpool2d(2, stride=2)
+
+        self.conv_up4 = nn.Sequential(
+            nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(),
+            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
+        )
+
+    def forward(self, conv_out, img, indices, two_chan_trimap):
+        conv5 = conv_out[-1]
+
+        input_size = conv5.size()
+        ppm_out = [conv5]
+        for pool_scale in self.ppm:
+            ppm_out.append(
+                nn.functional.interpolate(
+                    pool_scale(conv5),
+                    (input_size[2], input_size[3]),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+            )
+        ppm_out = torch.cat(ppm_out, 1)
+        x = self.conv_up1(ppm_out)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-4]), 1)
+
+        x = self.conv_up2(x)
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+
+        x = torch.cat((x, conv_out[-5]), 1)
+        x = self.conv_up3(x)
+
+        x = torch.nn.functional.interpolate(
+            x, scale_factor=2, mode="bilinear", align_corners=False
+        )
+        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
+
+        output = self.conv_up4(x)
+
+        alpha = torch.clamp(output[:, 0][:, None], 0, 1)
+        F = torch.sigmoid(output[:, 1:4])
+        B = torch.sigmoid(output[:, 4:7])
+
+        # FBA Fusion
+        alpha, F, B = fba_fusion(alpha, img, F, B)
+
+        output = torch.cat((alpha, F, B), 1)
+
+        return output
+
+
+def build_encoder(arch="resnet50_GN"):
+    if arch == "resnet50_GN_WS":
+        orig_resnet = resnet_GN_WS.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+    elif arch == "resnet50_BN":
+        orig_resnet = resnet_bn.__dict__["l_resnet50"]()
+        net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8)
+
+    else:
+        raise ValueError("Architecture undefined!")
+
+    num_channels = 3 + 6 + 2
+
+    if num_channels > 3:
+        net_encoder_sd = net_encoder.state_dict()
+        conv1_weights = net_encoder_sd["conv1.weight"]
+
+        c_out, c_in, h, w = conv1_weights.size()
+        conv1_mod = torch.zeros(c_out, num_channels, h, w)
+        conv1_mod[:, :3, :, :] = conv1_weights
+
+        conv1 = net_encoder.conv1
+        conv1.in_channels = num_channels
+        conv1.weight = torch.nn.Parameter(conv1_mod)
+
+        net_encoder.conv1 = conv1
+
+        net_encoder_sd["conv1.weight"] = conv1_mod
+
+        net_encoder.load_state_dict(net_encoder_sd)
+    return net_encoder
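A channel-count sketch for the FBA model above: the encoder consumes an 11-channel tensor (3 normalized image + 6 transformed trimap + 2 two-channel trimap, matching build_encoder's num_channels) and the decoder emits 7 channels (1 alpha, 3 foreground, 3 background). Dummy tensors only; no pretrained weights are loaded.

import torch
from carvekit.ml.arch.fba_matting.models import FBA

model = FBA(encoder="resnet50_GN_WS").eval()

image = torch.rand(1, 3, 256, 256)               # original image in [0, 1]
image_n = torch.rand(1, 3, 256, 256)             # network-normalized image
trimap_transformed = torch.rand(1, 6, 256, 256)
two_chan_trimap = torch.rand(1, 2, 256, 256)

with torch.no_grad():
    out = model(image, two_chan_trimap, image_n, trimap_transformed)
print(out.shape)  # torch.Size([1, 7, 256, 256]) -> alpha, F, B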
carvekit/ml/arch/fba_matting/resnet_GN_WS.py
ADDED
@@ -0,0 +1,151 @@
+"""
+Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+Source url: https://github.com/MarcoForte/FBA_Matting
+License: MIT License
+"""
+import torch.nn as nn
+import carvekit.ml.arch.fba_matting.layers_WS as L
+
+__all__ = ["ResNet", "l_resnet50"]
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return L.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = L.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = L.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = L.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = L.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = L.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = L.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, return_indices=True
+        )
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                L.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def l_resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    return model
carvekit/ml/arch/fba_matting/resnet_bn.py
ADDED
@@ -0,0 +1,169 @@
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/MarcoForte/FBA_Matting
License: MIT License
"""
import torch.nn as nn
import math
from torch.nn import BatchNorm2d

__all__ = ["ResNet"]


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x, indices = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def l_resnet50():
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    return model
carvekit/ml/arch/fba_matting/transforms.py
ADDED
@@ -0,0 +1,45 @@
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/MarcoForte/FBA_Matting
License: MIT License
"""
import cv2
import numpy as np

group_norm_std = [0.229, 0.224, 0.225]
group_norm_mean = [0.485, 0.456, 0.406]


def dt(a):
    return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)


def trimap_transform(trimap):
    h, w = trimap.shape[0], trimap.shape[1]

    clicks = np.zeros((h, w, 6))
    for k in range(2):
        if np.count_nonzero(trimap[:, :, k]) > 0:
            dt_mask = -dt(1 - trimap[:, :, k]) ** 2
            L = 320
            clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2)))
            clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2)))
            clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2)))

    return clicks


def groupnorm_normalise_image(img, format="nhwc"):
    """
    Accept rgb in range 0,1
    """
    if format == "nhwc":
        for i in range(3):
            img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
    else:
        for i in range(3):
            img[..., i, :, :] = (
                img[..., i, :, :] - group_norm_mean[i]
            ) / group_norm_std[i]

    return img
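For context, a small sketch of how these helpers could be combined on a dummy two-plane trimap and an RGB image in [0, 1]; the 64x64 size and the background/foreground plane layout are illustrative assumptions, not fixed by this file.

import numpy as np
from carvekit.ml.arch.fba_matting.transforms import (
    trimap_transform,
    groupnorm_normalise_image,
)

h, w = 64, 64
# Two-plane trimap: channel 0 marks known background, channel 1 known foreground.
trimap = np.zeros((h, w, 2), dtype=np.float32)
trimap[:8, :, 0] = 1.0    # top rows: definite background
trimap[-8:, :, 1] = 1.0   # bottom rows: definite foreground

clicks = trimap_transform(trimap)  # six distance-based guidance maps
image = np.random.rand(h, w, 3).astype(np.float32)  # RGB in [0, 1]
image = groupnorm_normalise_image(image.copy(), format="nhwc")

print(clicks.shape, image.shape)  # (64, 64, 6) (64, 64, 3)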
carvekit/ml/arch/tracerb7/__init__.py
ADDED
File without changes
carvekit/ml/arch/tracerb7/att_modules.py
ADDED
@@ -0,0 +1,290 @@
"""
Source url: https://github.com/Karel911/TRACER
Author: Min Seok Lee and Wooseok Shin
License: Apache License 2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv


class RFB_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(RFB_Block, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3),
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5),
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7),
        )
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x_cat = torch.cat((x0, x1, x2, x3), 1)
        x_cat = self.conv_cat(x_cat)

        x = self.relu(x_cat + self.conv_res(x))
        return x


class GlobalAvgPool(nn.Module):
    def __init__(self, flatten=False):
        super(GlobalAvgPool, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        if self.flatten:
            in_size = x.size()
            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
        else:
            return (
                x.view(x.size(0), x.size(1), -1)
                .mean(-1)
                .view(x.size(0), x.size(1), 1, 1)
            )


class UnionAttentionModule(nn.Module):
    def __init__(self, n_channels, only_channel_tracing=False):
        super(UnionAttentionModule, self).__init__()
        self.GAP = GlobalAvgPool()
        self.confidence_ratio = 0.1
        self.bn = nn.BatchNorm2d(n_channels)
        self.norm = nn.Sequential(
            nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
        )
        self.channel_q = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.channel_k = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.channel_v = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

        self.fc = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

        if only_channel_tracing is False:
            self.spatial_q = nn.Conv2d(
                in_channels=n_channels,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.spatial_k = nn.Conv2d(
                in_channels=n_channels,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.spatial_v = nn.Conv2d(
                in_channels=n_channels,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
        self.sigmoid = nn.Sigmoid()

    def masking(self, x, mask):
        mask = mask.squeeze(3).squeeze(2)
        threshold = torch.quantile(
            mask.float(), self.confidence_ratio, dim=-1, keepdim=True
        )
        mask[mask <= threshold] = 0.0
        mask = mask.unsqueeze(2).unsqueeze(3)
        mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
        masked_x = x * mask

        return masked_x

    def Channel_Tracer(self, x):
        avg_pool = self.GAP(x)
        x_norm = self.norm(avg_pool)

        q = self.channel_q(x_norm).squeeze(-1)
        k = self.channel_k(x_norm).squeeze(-1)
        v = self.channel_v(x_norm).squeeze(-1)

        # softmax(Q*K^T)
        QK_T = torch.matmul(q, k.transpose(1, 2))
        alpha = F.softmax(QK_T, dim=-1)

        # a*v
        att = torch.matmul(alpha, v).unsqueeze(-1)
        att = self.fc(att)
        att = self.sigmoid(att)

        output = (x * att) + x
        alpha_mask = att.clone()

        return output, alpha_mask

    def forward(self, x):
        X_c, alpha_mask = self.Channel_Tracer(x)
        X_c = self.bn(X_c)
        x_drop = self.masking(X_c, alpha_mask)

        q = self.spatial_q(x_drop).squeeze(1)
        k = self.spatial_k(x_drop).squeeze(1)
        v = self.spatial_v(x_drop).squeeze(1)

        # softmax(Q*K^T)
        QK_T = torch.matmul(q, k.transpose(1, 2))
        alpha = F.softmax(QK_T, dim=-1)

        output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)

        return output


class aggregation(nn.Module):
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1)
        self.conv_upsample5 = BasicConv2d(
            channel[2] + channel[1], channel[2] + channel[1], 3, padding=1
        )

        self.conv_concat2 = BasicConv2d(
            (channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1
        )
        self.conv_concat3 = BasicConv2d(
            (channel[0] + channel[1] + channel[2]),
            (channel[0] + channel[1] + channel[2]),
            3,
            padding=1,
        )

        self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2])

    def forward(self, e4, e3, e2):
        e4_1 = e4
        e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
        e2_1 = (
            self.conv_upsample2(self.upsample(self.upsample(e4)))
            * self.conv_upsample3(self.upsample(e3))
            * e2
        )

        e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
        e3_2 = self.conv_concat2(e3_2)

        e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
        x = self.conv_concat3(e2_2)

        output = self.UAM(x)

        return output


class ObjectAttention(nn.Module):
    def __init__(self, channel, kernel_size):
        super(ObjectAttention, self).__init__()
        self.channel = channel
        self.DWSConv = DWSConv(
            channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1
        )
        self.DWConv1 = nn.Sequential(
            DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1),
            BasicConv2d(channel // 2, channel // 8, 1),
        )
        self.DWConv2 = nn.Sequential(
            DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1),
            BasicConv2d(channel // 2, channel // 8, 1),
        )
        self.DWConv3 = nn.Sequential(
            DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3),
            BasicConv2d(channel // 2, channel // 8, 1),
        )
        self.DWConv4 = nn.Sequential(
            DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5),
            BasicConv2d(channel // 2, channel // 8, 1),
        )
        self.conv1 = BasicConv2d(channel // 2, 1, 1)

    def forward(self, decoder_map, encoder_map):
        """
        Args:
            decoder_map: decoder representation (B, 1, H, W).
            encoder_map: encoder block output (B, C, H, W).
        Returns:
            decoder representation: (B, 1, H, W)
        """
        mask_bg = -1 * torch.sigmoid(decoder_map) + 1  # Sigmoid & Reverse
        mask_ob = torch.sigmoid(decoder_map)  # object attention
        x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)

        edge = mask_bg.clone()
        edge[edge > 0.93] = 0
        x = x + (edge * encoder_map)

        x = self.DWSConv(x)
        skip = x.clone()
        x = (
            torch.cat(
                [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)],
                dim=1,
            )
            + skip
        )
        x = torch.relu(self.conv1(x))

        return x + decoder_map
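A minimal sketch of ObjectAttention in isolation, assuming a 48-channel encoder feature map (the first entry of the TRACER channel list) and an illustrative 80x80 resolution; the tensors are random stand-ins for real backbone outputs.

import torch
from carvekit.ml.arch.tracerb7.att_modules import ObjectAttention

# The decoder refines a 1-channel saliency map against an encoder feature map.
oa = ObjectAttention(channel=48, kernel_size=3).eval()
decoder_map = torch.rand(1, 1, 80, 80)   # coarse prediction (illustrative size)
encoder_map = torch.rand(1, 48, 80, 80)  # matching backbone features

with torch.no_grad():
    refined = oa(decoder_map, encoder_map)
print(refined.shape)  # torch.Size([1, 1, 80, 80])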
carvekit/ml/arch/tracerb7/conv_modules.py
ADDED
@@ -0,0 +1,88 @@
"""
Source url: https://github.com/Karel911/TRACER
Author: Min Seok Lee and Wooseok Shin
License: Apache License 2.0
"""
import torch.nn as nn


class BasicConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
    ):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.selu(x)

        return x


class DWConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel, dilation, padding):
        super(DWConv, self).__init__()
        self.out_channel = out_channel
        self.DWConv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        x = self.DWConv(x)
        out = self.selu(self.bn(x))

        return out


class DWSConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
        super(DWSConv, self).__init__()
        self.out_channel = out_channel
        self.DWConv = nn.Conv2d(
            in_channel,
            in_channel * kernels_per_layer,
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer)
        self.selu = nn.SELU()
        self.PWConv = nn.Conv2d(
            in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        x = self.DWConv(x)
        x = self.selu(self.bn(x))
        out = self.PWConv(x)
        out = self.selu(self.bn2(out))

        return out
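DWSConv factors a convolution into a per-channel (depthwise) filter followed by a 1x1 pointwise mix, which is where TRACER saves parameters over a plain convolution. A short sketch comparing the two at the same width; the 64-to-32 channel choice and the 40x40 input are illustrative only.

import torch
from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWSConv


def n_params(module):
    return sum(p.numel() for p in module.parameters())


# Depthwise-separable 3x3 block vs. a plain 3x3 convolution at the same width.
dws = DWSConv(64, 32, kernel=3, padding=1, kernels_per_layer=1)
full = BasicConv2d(64, 32, 3, padding=1)
print(n_params(dws), n_params(full))  # the separable block needs far fewer weights

x = torch.rand(2, 64, 40, 40)
print(dws(x).shape, full(x).shape)  # both: torch.Size([2, 32, 40, 40])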
carvekit/ml/arch/tracerb7/effi_utils.py
ADDED
@@ -0,0 +1,579 @@
"""
Original author: lukemelas (github username)
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
With adjustments and added comments by workingcoder (github username).
License: Apache License 2.0
Reimplemented: Min Seok Lee and Wooseok Shin
"""

import collections
import re
from functools import partial

import math
import torch
from torch import nn
from torch.nn import functional as F

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
    "GlobalParams",
    [
        "width_coefficient",
        "depth_coefficient",
        "image_size",
        "dropout_rate",
        "num_classes",
        "batch_norm_momentum",
        "batch_norm_epsilon",
        "drop_connect_rate",
        "depth_divisor",
        "min_depth",
        "include_top",
    ],
)

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    "BlockArgs",
    [
        "num_repeat",
        "kernel_size",
        "stride",
        "expand_ratio",
        "input_filters",
        "output_filters",
        "se_ratio",
        "id_skip",
    ],
)

# Set GlobalParams and BlockArgs's defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)


# An ordinary implementation of Swish function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


def round_filters(filters, global_params):
    """Calculate and round number of filters based on width multiplier.
    Use width_coefficient, depth_divisor and min_depth of global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: New filters number after calculating.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
    # follow the formula transferred from official TensorFlow implementation
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    """Calculate module's repeat number of a block based on depth multiplier.
    Use depth_coefficient of global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: New repeat number after calculating.
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    # follow the formula transferred from official TensorFlow implementation
    return int(math.ceil(multiplier * repeats))


def drop_connect(inputs, p, training):
    """Drop connect.

    Args:
        input (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection.
    """
    assert 0 <= p <= 1, "p must be in range of [0,1]"

    if not training:
        return inputs

    batch_size = inputs.shape[0]
    keep_prob = 1 - p

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
    random_tensor = keep_prob
    random_tensor += torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H,W).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, list) or isinstance(x, tuple):
        return x
    else:
        raise TypeError()


def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W].
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]


# Note:
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names ! ! !


def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
    The padding is operated in forward function by calculating dynamically.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
        )
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(
            iw / sw
        )  # change the output size according to stride ! ! !
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
    The padding mudule is calculated in construction function, then used in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    else:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
    The padding is operated in forward function by calculating dynamically.
    """

    def __init__(
        self,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
    ):
        super().__init__(
            kernel_size, stride, padding, dilation, return_indices, ceil_mode
        )
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = (
            [self.kernel_size] * 2
            if isinstance(self.kernel_size, int)
            else self.kernel_size
        )
        self.dilation = (
            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
        )

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )
        return F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
    The padding mudule is calculated in construction function, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = (
            [self.kernel_size] * 2
            if isinstance(self.kernel_size, int)
            else self.kernel_size
        )
        self.dilation = (
            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
        )

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
        return x


class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split("_")
        options = {}
        for op in ops:
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert ("s" in options and len(options["s"]) == 1) or (
            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
        )

        return BlockArgs(
            num_repeat=int(options["r"]),
            kernel_size=int(options["k"]),
            stride=[int(options["s"][0])],
            expand_ratio=int(options["e"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            se_ratio=float(options["se"]) if "se" in options else None,
            id_skip=("noskip" not in block_string),
        )

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (block.strides[0], block.strides[1]),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        if 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings


def create_block_args(
    width_coefficient=None,
    depth_coefficient=None,
    image_size=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    num_classes=1000,
    include_top=True,
):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)

        Meaning as the name suggests.

    Returns:
        blocks_args, global_params.
    """

    # Blocks args for the whole model(efficientnet-b0 by default)
    # It will be modified in the construction of EfficientNet Class according to model
    blocks_args = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params
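The helpers above drive model scaling from compact block strings. A short sketch of how they compose; the printed values follow from the rounding formulas above, and the coefficients match the B7 encoder defined later in this diff.

from carvekit.ml.arch.tracerb7.effi_utils import (
    BlockDecoder,
    create_block_args,
    round_filters,
    round_repeats,
)

# One block string packs repeats, kernel, stride, expansion, channels and SE ratio.
block = BlockDecoder.decode(["r1_k3_s11_e1_i32_o16_se0.25"])[0]
print(block.num_repeat, block.kernel_size, block.input_filters, block.se_ratio)
# -> 1 3 32 0.25

# The same scaling coefficients the B7 encoder uses (width 2.0, depth 3.1).
blocks_args, global_params = create_block_args(
    width_coefficient=2.0, depth_coefficient=3.1, dropout_rate=0.5, image_size=600
)
print(round_filters(32, global_params))                         # stem width: 32 -> 64
print(round_repeats(blocks_args[3].num_repeat, global_params))  # stage repeats: 3 -> 10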
carvekit/ml/arch/tracerb7/efficientnet.py
ADDED
@@ -0,0 +1,325 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/lukemelas/EfficientNet-PyTorch
|
3 |
+
Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
|
4 |
+
License: Apache License 2.0
|
5 |
+
Changes:
|
6 |
+
- Added support for extracting edge features
|
7 |
+
- Added support for extracting object features at different levels
|
8 |
+
- Refactored the code
|
9 |
+
"""
|
10 |
+
from typing import Any, List
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
from carvekit.ml.arch.tracerb7.effi_utils import (
|
17 |
+
get_same_padding_conv2d,
|
18 |
+
calculate_output_image_size,
|
19 |
+
MemoryEfficientSwish,
|
20 |
+
drop_connect,
|
21 |
+
round_filters,
|
22 |
+
round_repeats,
|
23 |
+
Swish,
|
24 |
+
create_block_args,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class MBConvBlock(nn.Module):
|
29 |
+
"""Mobile Inverted Residual Bottleneck Block.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
block_args (namedtuple): BlockArgs, defined in utils.py.
|
33 |
+
global_params (namedtuple): GlobalParam, defined in utils.py.
|
34 |
+
image_size (tuple or list): [image_height, image_width].
|
35 |
+
|
36 |
+
References:
|
37 |
+
[1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
|
38 |
+
[2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
|
39 |
+
[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, block_args, global_params, image_size=None):
|
43 |
+
super().__init__()
|
44 |
+
self._block_args = block_args
|
45 |
+
self._bn_mom = (
|
46 |
+
1 - global_params.batch_norm_momentum
|
47 |
+
) # pytorch's difference from tensorflow
|
48 |
+
self._bn_eps = global_params.batch_norm_epsilon
|
49 |
+
self.has_se = (self._block_args.se_ratio is not None) and (
|
50 |
+
0 < self._block_args.se_ratio <= 1
|
51 |
+
)
|
52 |
+
self.id_skip = (
|
53 |
+
block_args.id_skip
|
54 |
+
) # whether to use skip connection and drop connect
|
55 |
+
|
56 |
+
# Expansion phase (Inverted Bottleneck)
|
57 |
+
inp = self._block_args.input_filters # number of input channels
|
58 |
+
oup = (
|
59 |
+
self._block_args.input_filters * self._block_args.expand_ratio
|
60 |
+
) # number of output channels
|
61 |
+
if self._block_args.expand_ratio != 1:
|
62 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
63 |
+
self._expand_conv = Conv2d(
|
64 |
+
in_channels=inp, out_channels=oup, kernel_size=1, bias=False
|
65 |
+
)
|
66 |
+
self._bn0 = nn.BatchNorm2d(
|
67 |
+
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
|
68 |
+
)
|
69 |
+
# image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
|
70 |
+
|
71 |
+
# Depthwise convolution phase
|
72 |
+
k = self._block_args.kernel_size
|
73 |
+
s = self._block_args.stride
|
74 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
75 |
+
self._depthwise_conv = Conv2d(
|
76 |
+
in_channels=oup,
|
77 |
+
out_channels=oup,
|
78 |
+
groups=oup, # groups makes it depthwise
|
79 |
+
kernel_size=k,
|
80 |
+
stride=s,
|
81 |
+
bias=False,
|
82 |
+
)
|
83 |
+
self._bn1 = nn.BatchNorm2d(
|
84 |
+
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
|
85 |
+
)
|
86 |
+
image_size = calculate_output_image_size(image_size, s)
|
87 |
+
|
88 |
+
# Squeeze and Excitation layer, if desired
|
89 |
+
if self.has_se:
|
90 |
+
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
|
91 |
+
num_squeezed_channels = max(
|
92 |
+
1, int(self._block_args.input_filters * self._block_args.se_ratio)
|
93 |
+
)
|
94 |
+
self._se_reduce = Conv2d(
|
95 |
+
in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
|
96 |
+
)
|
97 |
+
self._se_expand = Conv2d(
|
98 |
+
in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
|
99 |
+
)
|
100 |
+
|
101 |
+
# Pointwise convolution phase
|
102 |
+
final_oup = self._block_args.output_filters
|
103 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
104 |
+
self._project_conv = Conv2d(
|
105 |
+
in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
|
106 |
+
)
|
107 |
+
self._bn2 = nn.BatchNorm2d(
|
108 |
+
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
|
109 |
+
)
|
110 |
+
self._swish = MemoryEfficientSwish()
|
111 |
+
|
112 |
+
def forward(self, inputs, drop_connect_rate=None):
|
113 |
+
"""MBConvBlock's forward function.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
inputs (tensor): Input tensor.
|
117 |
+
drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Output of this block after processing.
|
121 |
+
"""
|
122 |
+
|
123 |
+
# Expansion and Depthwise Convolution
|
124 |
+
x = inputs
|
125 |
+
if self._block_args.expand_ratio != 1:
|
126 |
+
x = self._expand_conv(inputs)
|
127 |
+
x = self._bn0(x)
|
128 |
+
x = self._swish(x)
|
129 |
+
|
130 |
+
x = self._depthwise_conv(x)
|
131 |
+
x = self._bn1(x)
|
132 |
+
x = self._swish(x)
|
133 |
+
|
134 |
+
# Squeeze and Excitation
|
135 |
+
if self.has_se:
|
136 |
+
x_squeezed = F.adaptive_avg_pool2d(x, 1)
|
137 |
+
x_squeezed = self._se_reduce(x_squeezed)
|
138 |
+
x_squeezed = self._swish(x_squeezed)
|
139 |
+
x_squeezed = self._se_expand(x_squeezed)
|
140 |
+
x = torch.sigmoid(x_squeezed) * x
|
141 |
+
|
142 |
+
# Pointwise Convolution
|
143 |
+
x = self._project_conv(x)
|
144 |
+
x = self._bn2(x)
|
145 |
+
|
146 |
+
# Skip connection and drop connect
|
147 |
+
input_filters, output_filters = (
|
148 |
+
self._block_args.input_filters,
|
149 |
+
self._block_args.output_filters,
|
150 |
+
)
|
151 |
+
if (
|
152 |
+
self.id_skip
|
153 |
+
and self._block_args.stride == 1
|
154 |
+
and input_filters == output_filters
|
155 |
+
):
|
156 |
+
# The combination of skip connection and drop connect brings about stochastic depth.
|
157 |
+
if drop_connect_rate:
|
158 |
+
x = drop_connect(x, p=drop_connect_rate, training=self.training)
|
159 |
+
x = x + inputs # skip connection
|
160 |
+
return x
|
161 |
+
|
162 |
+
def set_swish(self, memory_efficient=True):
|
163 |
+
"""Sets swish function as memory efficient (for training) or standard (for export).
|
164 |
+
|
165 |
+
Args:
|
166 |
+
memory_efficient (bool): Whether to use memory-efficient version of swish.
|
167 |
+
"""
|
168 |
+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
169 |
+
|
170 |
+
|
171 |
+
class EfficientNet(nn.Module):
|
172 |
+
def __init__(self, blocks_args=None, global_params=None):
|
173 |
+
super().__init__()
|
174 |
+
assert isinstance(blocks_args, list), "blocks_args should be a list"
|
175 |
+
assert len(blocks_args) > 0, "block args must be greater than 0"
|
176 |
+
self._global_params = global_params
|
177 |
+
self._blocks_args = blocks_args
|
178 |
+
|
179 |
+
# Batch norm parameters
|
180 |
+
bn_mom = 1 - self._global_params.batch_norm_momentum
|
181 |
+
bn_eps = self._global_params.batch_norm_epsilon
|
182 |
+
|
183 |
+
# Get stem static or dynamic convolution depending on image size
|
184 |
+
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.

        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(
                    self._blocks
                )  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x

        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )


class EfficientEncoderB7(EfficientNet):
    def __init__(self):
        super().__init__(
            *create_block_args(
                width_coefficient=2.0,
                depth_coefficient=3.1,
                dropout_rate=0.5,
                image_size=600,
            )
        )
        self._change_in_channels(3)
        self.block_idx = [10, 17, 37, 54]
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        return x

    def get_blocks(self, x, H, W, block_idx):
        features = []
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(
                    self._blocks
                )  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if idx == block_idx[0]:
                features.append(x.clone())
            if idx == block_idx[1]:
                features.append(x.clone())
            if idx == block_idx[2]:
                features.append(x.clone())
            if idx == block_idx[3]:
                features.append(x.clone())

        return features

    def forward(self, inputs: torch.Tensor) -> List[Any]:
        B, C, H, W = inputs.size()
        x = self.initial_conv(inputs)  # Prepare input for the backbone
        return self.get_blocks(
            x, H, W, block_idx=self.block_idx
        )  # Get backbone features and edge maps
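A minimal smoke-test sketch for the encoder above, assuming the rest of this module (create_block_args, MBConvBlock and friends) is importable; it feeds a dummy tensor through EfficientEncoderB7 and checks that the four collected feature maps carry the channel counts declared in self.channels. This is illustrative only, not part of the diff.

import torch
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7

encoder = EfficientEncoderB7().eval()
with torch.no_grad():
    features = encoder(torch.zeros(1, 3, 640, 640))  # taps blocks 10, 17, 37, 54
for feature_map, expected_channels in zip(features, encoder.channels):
    assert feature_map.shape[1] == expected_channels  # 48, 80, 224, 640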
carvekit/ml/arch/tracerb7/tracer.py
ADDED
@@ -0,0 +1,97 @@
"""
Source url: https://github.com/Karel911/TRACER
Author: Min Seok Lee and Wooseok Shin
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
Changes:
    - Refactored code
    - Removed unused code
    - Added comments
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple

from torch import Tensor

from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
from carvekit.ml.arch.tracerb7.att_modules import (
    RFB_Block,
    aggregation,
    ObjectAttention,
)


class TracerDecoder(nn.Module):
    """Tracer Decoder"""

    def __init__(
        self,
        encoder: EfficientEncoderB7,
        features_channels: Optional[List[int]] = None,
        rfb_channel: Optional[List[int]] = None,
    ):
        """
        Initialize the tracer decoder.

        Args:
            encoder: The encoder to use.
            features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640]
            rfb_channel: The channels of the RFB features. default: [32, 64, 128]
        """
        super().__init__()
        if rfb_channel is None:
            rfb_channel = [32, 64, 128]
        if features_channels is None:
            features_channels = [48, 80, 224, 640]
        self.encoder = encoder
        self.features_channels = features_channels

        # Receptive Field Blocks
        features_channels = rfb_channel
        self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0])
        self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1])
        self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2])

        # Multi-level aggregation
        self.agg = aggregation(features_channels)

        # Object Attention
        self.ObjectAttention2 = ObjectAttention(
            channel=self.features_channels[1], kernel_size=3
        )
        self.ObjectAttention1 = ObjectAttention(
            channel=self.features_channels[0], kernel_size=3
        )

    def forward(self, inputs: torch.Tensor) -> Tensor:
        """
        Forward pass of the tracer decoder.

        Args:
            inputs: Preprocessed images.

        Returns:
            Tensor of segmentation masks (sigmoid probabilities at input resolution).
        """
        features = self.encoder(inputs)
        x3_rfb = self.rfb2(features[1])
        x4_rfb = self.rfb3(features[2])
        x5_rfb = self.rfb4(features[3])

        D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)

        ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")

        D_1 = self.ObjectAttention2(D_0, features[1])
        ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")

        ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
        D_2 = self.ObjectAttention1(ds_map, features[0])
        ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")

        final_map = (ds_map2 + ds_map1 + ds_map0) / 3

        return torch.sigmoid(final_map)
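A minimal usage sketch for the decoder, assuming randomly initialized weights (the pretrained checkpoint is loaded by the TracerUniversalB7 wrapper added later in this diff):

import torch
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
from carvekit.ml.arch.tracerb7.tracer import TracerDecoder

model = TracerDecoder(encoder=EfficientEncoderB7()).eval()
with torch.no_grad():
    saliency = model(torch.zeros(2, 3, 640, 640))
# Expected: a (2, 1, 640, 640) tensor of sigmoid probabilities.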
carvekit/ml/arch/u2net/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/u2net/u2net.py
ADDED
@@ -0,0 +1,172 @@
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/xuebinqin/U-2-Net
License: Apache License 2.0
"""
from typing import Union

import torch
import torch.nn as nn

import math

__all__ = ["U2NETArchitecture"]


def _upsample_like(x, size):
    return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x)


def _size_map(x, height):
    # {height: size} for Upsample
    size = list(x.shape[-2:])
    sizes = {}
    for h in range(1, height):
        sizes[h] = size
        size = [math.ceil(w / 2) for w in size]
    return sizes


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))


class RSU(nn.Module):
    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super(RSU, self).__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)

        # U-Net like symmetric encoder-decoder structure
        def unet(x, height=1):
            if height < self.height:
                x1 = getattr(self, f"rebnconv{height}")(x)
                if not self.dilated and height < self.height - 1:
                    x2 = unet(getattr(self, "downsample")(x1), height + 1)
                else:
                    x2 = unet(x1, height + 1)

                x = getattr(self, f"rebnconv{height}d")(torch.cat((x2, x1), 1))
                return (
                    _upsample_like(x, sizes[height - 1])
                    if not self.dilated and height > 1
                    else x
                )
            else:
                return getattr(self, f"rebnconv{height}")(x)

        return x + unet(x)

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        self.add_module("rebnconvin", REBNCONV(in_ch, out_ch))
        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))

        self.add_module("rebnconv1", REBNCONV(out_ch, mid_ch))
        self.add_module("rebnconv1d", REBNCONV(mid_ch * 2, out_ch))

        for i in range(2, height):
            dilate = 1 if not dilated else 2 ** (i - 1)
            self.add_module(f"rebnconv{i}", REBNCONV(mid_ch, mid_ch, dilate=dilate))
            self.add_module(
                f"rebnconv{i}d", REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)
            )

        dilate = 2 if not dilated else 2 ** (height - 1)
        self.add_module(f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=dilate))


class U2NETArchitecture(nn.Module):
    def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
        super(U2NETArchitecture, self).__init__()
        if isinstance(cfg_type, str):
            if cfg_type == "full":
                layers_cfgs = {
                    # cfgs for building RSUs and sides
                    # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
                    "stage1": ["En_1", (7, 3, 32, 64), -1],
                    "stage2": ["En_2", (6, 64, 32, 128), -1],
                    "stage3": ["En_3", (5, 128, 64, 256), -1],
                    "stage4": ["En_4", (4, 256, 128, 512), -1],
                    "stage5": ["En_5", (4, 512, 256, 512, True), -1],
                    "stage6": ["En_6", (4, 512, 256, 512, True), 512],
                    "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
                    "stage4d": ["De_4", (4, 1024, 128, 256), 256],
                    "stage3d": ["De_3", (5, 512, 64, 128), 128],
                    "stage2d": ["De_2", (6, 256, 32, 64), 64],
                    "stage1d": ["De_1", (7, 128, 16, 64), 64],
                }
            else:
                raise ValueError("Unknown U^2-Net architecture conf. name")
        elif isinstance(cfg_type, dict):
            layers_cfgs = cfg_type
        else:
            raise ValueError("Unknown U^2-Net architecture conf. type")
        self.out_ch = out_ch
        self._make_layers(layers_cfgs)

    def forward(self, x):
        sizes = _size_map(x, self.height)
        maps = []  # storage for maps

        # side saliency map
        def unet(x, height=1):
            if height < 6:
                x1 = getattr(self, f"stage{height}")(x)
                x2 = unet(getattr(self, "downsample")(x1), height + 1)
                x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            else:
                x = getattr(self, f"stage{height}")(x)
                side(x, height)
                return _upsample_like(x, sizes[height - 1])

        def side(x, h):
            # side output saliency map (before sigmoid)
            x = getattr(self, f"side{h}")(x)
            x = _upsample_like(x, sizes[1])
            maps.append(x)

        def fuse():
            # fuse saliency probability maps
            maps.reverse()
            x = torch.cat(maps, 1)
            x = getattr(self, "outconv")(x)
            maps.insert(0, x)
            return [torch.sigmoid(x) for x in maps]

        unet(x)
        maps = fuse()
        return maps

    def _make_layers(self, cfgs):
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer
                self.add_module(
                    f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
                )
        # build fuse layer
        self.add_module(
            "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
        )
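For orientation, a short sketch (not from the repository) of what the forward pass above returns for the "full" configuration; shapes assume a 320x320 input:

import torch
from carvekit.ml.arch.u2net.u2net import U2NETArchitecture

net = U2NETArchitecture(cfg_type="full", out_ch=1).eval()
with torch.no_grad():
    maps = net(torch.zeros(1, 3, 320, 320))
# maps[0] is the fused saliency map, maps[1:] are the six side outputs;
# each has shape (1, 1, 320, 320) with values already passed through sigmoid.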
carvekit/ml/files/__init__.py
ADDED
@@ -0,0 +1,7 @@
from pathlib import Path

carvekit_dir = Path.home().joinpath(".cache/carvekit")

carvekit_dir.mkdir(parents=True, exist_ok=True)

checkpoints_dir = carvekit_dir.joinpath("checkpoints")
carvekit/ml/files/models_loc.py
ADDED
@@ -0,0 +1,70 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib
from carvekit.ml.files import checkpoints_dir
from carvekit.utils.download_models import downloader


def u2net_full_pretrained() -> pathlib.Path:
    """Returns u2net pretrained model location

    Returns:
        pathlib.Path to model location
    """
    return downloader("u2net.pth")


def basnet_pretrained() -> pathlib.Path:
    """Returns basnet pretrained model location

    Returns:
        pathlib.Path to model location
    """
    return downloader("basnet.pth")


def deeplab_pretrained() -> pathlib.Path:
    """Returns deeplabv3 pretrained model location

    Returns:
        pathlib.Path to model location
    """
    return downloader("deeplab.pth")


def fba_pretrained() -> pathlib.Path:
    """Returns fba matting pretrained model location

    Returns:
        pathlib.Path to model location
    """
    return downloader("fba_matting.pth")


def tracer_b7_pretrained() -> pathlib.Path:
    """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location

    Returns:
        pathlib.Path to model location
    """
    return downloader("tracer_b7.pth")


def tracer_hair_pretrained() -> pathlib.Path:
    """Returns TRACER with EfficientNet v1 b7 encoder model for hair segmentation location

    Returns:
        pathlib.Path to model location
    """
    return downloader("tracer_hair.pth")


def download_all():
    u2net_full_pretrained()
    fba_pretrained()
    deeplab_pretrained()
    basnet_pretrained()
    tracer_b7_pretrained()
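A usage sketch for these helpers; the downloader is assumed to fetch the checkpoint into the carvekit cache on first call and return the cached path afterwards:

import torch
from carvekit.ml.files.models_loc import tracer_b7_pretrained

weights_path = tracer_b7_pretrained()  # pathlib.Path to tracer_b7.pth
state_dict = torch.load(weights_path, map_location="cpu")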
carvekit/ml/wrap/__init__.py
ADDED
File without changes
|
carvekit/ml/wrap/basnet.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import pathlib
|
7 |
+
from typing import Union, List
|
8 |
+
|
9 |
+
import PIL
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from carvekit.ml.arch.basnet.basnet import BASNet
|
15 |
+
from carvekit.ml.files.models_loc import basnet_pretrained
|
16 |
+
from carvekit.utils.image_utils import convert_image, load_image
|
17 |
+
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
|
18 |
+
|
19 |
+
__all__ = ["BASNET"]
|
20 |
+
|
21 |
+
|
22 |
+
class BASNET(BASNet):
|
23 |
+
"""BASNet model interface"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
device="cpu",
|
28 |
+
input_image_size: Union[List[int], int] = 320,
|
29 |
+
batch_size: int = 10,
|
30 |
+
load_pretrained: bool = True,
|
31 |
+
fp16: bool = False,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Initialize the BASNET model
|
35 |
+
|
36 |
+
Args:
|
37 |
+
device: processing device
|
38 |
+
input_image_size: input image size
|
39 |
+
batch_size: the number of images that the neural network processes in one run
|
40 |
+
load_pretrained: loading pretrained model
|
41 |
+
fp16: use fp16 precision // not supported at this moment
|
42 |
+
|
43 |
+
"""
|
44 |
+
super(BASNET, self).__init__(n_channels=3, n_classes=1)
|
45 |
+
self.device = device
|
46 |
+
self.batch_size = batch_size
|
47 |
+
if isinstance(input_image_size, list):
|
48 |
+
self.input_image_size = input_image_size[:2]
|
49 |
+
else:
|
50 |
+
self.input_image_size = (input_image_size, input_image_size)
|
51 |
+
self.to(device)
|
52 |
+
if load_pretrained:
|
53 |
+
self.load_state_dict(
|
54 |
+
torch.load(basnet_pretrained(), map_location=self.device)
|
55 |
+
)
|
56 |
+
self.eval()
|
57 |
+
|
58 |
+
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
|
59 |
+
"""
|
60 |
+
Transform input image to suitable data format for neural network
|
61 |
+
|
62 |
+
Args:
|
63 |
+
data: input image
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
input for neural network
|
67 |
+
|
68 |
+
"""
|
69 |
+
resized = data.resize(self.input_image_size)
|
70 |
+
# noinspection PyTypeChecker
|
71 |
+
resized_arr = np.array(resized, dtype=np.float64)
|
72 |
+
temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
|
73 |
+
if np.max(resized_arr) != 0:
|
74 |
+
resized_arr /= np.max(resized_arr)
|
75 |
+
temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
|
76 |
+
temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
|
77 |
+
temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
|
78 |
+
temp_image = temp_image.transpose((2, 0, 1))
|
79 |
+
temp_image = np.expand_dims(temp_image, 0)
|
80 |
+
return torch.from_numpy(temp_image).type(torch.FloatTensor)
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def data_postprocessing(
|
84 |
+
data: torch.tensor, original_image: PIL.Image.Image
|
85 |
+
) -> PIL.Image.Image:
|
86 |
+
"""
|
87 |
+
Transforms output data from neural network to suitable data
|
88 |
+
format for using with other components of this framework.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
data: output data from neural network
|
92 |
+
original_image: input image which was used for predicted data
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Segmentation mask as PIL Image instance
|
96 |
+
|
97 |
+
"""
|
98 |
+
data = data.unsqueeze(0)
|
99 |
+
mask = data[:, 0, :, :]
|
100 |
+
ma = torch.max(mask) # Normalizes prediction
|
101 |
+
mi = torch.min(mask)
|
102 |
+
predict = ((mask - mi) / (ma - mi)).squeeze()
|
103 |
+
predict_np = predict.cpu().data.numpy() * 255
|
104 |
+
mask = Image.fromarray(predict_np).convert("L")
|
105 |
+
mask = mask.resize(original_image.size, resample=3)
|
106 |
+
return mask
|
107 |
+
|
108 |
+
def __call__(
|
109 |
+
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
|
110 |
+
) -> List[PIL.Image.Image]:
|
111 |
+
"""
|
112 |
+
Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
|
113 |
+
|
114 |
+
Args:
|
115 |
+
images: input images
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
segmentation masks as for input images, as PIL.Image.Image instances
|
119 |
+
|
120 |
+
"""
|
121 |
+
collect_masks = []
|
122 |
+
for image_batch in batch_generator(images, self.batch_size):
|
123 |
+
images = thread_pool_processing(
|
124 |
+
lambda x: convert_image(load_image(x)), image_batch
|
125 |
+
)
|
126 |
+
batches = torch.vstack(
|
127 |
+
thread_pool_processing(self.data_preprocessing, images)
|
128 |
+
)
|
129 |
+
with torch.no_grad():
|
130 |
+
batches = batches.to(self.device)
|
131 |
+
masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
|
132 |
+
batches
|
133 |
+
)
|
134 |
+
masks_cpu = masks.cpu()
|
135 |
+
del d2, d3, d4, d5, d6, d7, d8, batches, masks
|
136 |
+
masks = thread_pool_processing(
|
137 |
+
lambda x: self.data_postprocessing(masks_cpu[x], images[x]),
|
138 |
+
range(len(images)),
|
139 |
+
)
|
140 |
+
collect_masks += masks
|
141 |
+
return collect_masks
|
carvekit/ml/wrap/deeplab_v3.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import pathlib
|
7 |
+
from typing import List, Union
|
8 |
+
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from PIL import Image
|
12 |
+
from torchvision import transforms
|
13 |
+
from torchvision.models.segmentation import deeplabv3_resnet101
|
14 |
+
from carvekit.ml.files.models_loc import deeplab_pretrained
|
15 |
+
from carvekit.utils.image_utils import convert_image, load_image
|
16 |
+
from carvekit.utils.models_utils import get_precision_autocast, cast_network
|
17 |
+
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
|
18 |
+
|
19 |
+
__all__ = ["DeepLabV3"]
|
20 |
+
|
21 |
+
|
22 |
+
class DeepLabV3:
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
device="cpu",
|
26 |
+
batch_size: int = 10,
|
27 |
+
input_image_size: Union[List[int], int] = 1024,
|
28 |
+
load_pretrained: bool = True,
|
29 |
+
fp16: bool = False,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Initialize the DeepLabV3 model
|
33 |
+
|
34 |
+
Args:
|
35 |
+
device: processing device
|
36 |
+
input_image_size: input image size
|
37 |
+
batch_size: the number of images that the neural network processes in one run
|
38 |
+
load_pretrained: loading pretrained model
|
39 |
+
fp16: use half precision
|
40 |
+
|
41 |
+
"""
|
42 |
+
self.device = device
|
43 |
+
self.batch_size = batch_size
|
44 |
+
self.network = deeplabv3_resnet101(
|
45 |
+
pretrained=False, pretrained_backbone=False, aux_loss=True
|
46 |
+
)
|
47 |
+
self.network.to(self.device)
|
48 |
+
if load_pretrained:
|
49 |
+
self.network.load_state_dict(
|
50 |
+
torch.load(deeplab_pretrained(), map_location=self.device)
|
51 |
+
)
|
52 |
+
if isinstance(input_image_size, list):
|
53 |
+
self.input_image_size = input_image_size[:2]
|
54 |
+
else:
|
55 |
+
self.input_image_size = (input_image_size, input_image_size)
|
56 |
+
self.network.eval()
|
57 |
+
self.fp16 = fp16
|
58 |
+
self.transform = transforms.Compose(
|
59 |
+
[
|
60 |
+
transforms.ToTensor(),
|
61 |
+
transforms.Normalize(
|
62 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
63 |
+
),
|
64 |
+
]
|
65 |
+
)
|
66 |
+
|
67 |
+
def to(self, device: str):
|
68 |
+
"""
|
69 |
+
Moves neural network to specified processing device
|
70 |
+
|
71 |
+
Args:
|
72 |
+
device (:class:`torch.device`): the desired device.
|
73 |
+
Returns:
|
74 |
+
None
|
75 |
+
|
76 |
+
"""
|
77 |
+
self.network.to(device)
|
78 |
+
|
79 |
+
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
|
80 |
+
"""
|
81 |
+
Transform input image to suitable data format for neural network
|
82 |
+
|
83 |
+
Args:
|
84 |
+
data: input image
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
input for neural network
|
88 |
+
|
89 |
+
"""
|
90 |
+
copy = data.copy()
|
91 |
+
copy.thumbnail(self.input_image_size, resample=3)
|
92 |
+
return self.transform(copy)
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def data_postprocessing(
|
96 |
+
data: torch.tensor, original_image: PIL.Image.Image
|
97 |
+
) -> PIL.Image.Image:
|
98 |
+
"""
|
99 |
+
Transforms output data from neural network to suitable data
|
100 |
+
format for using with other components of this framework.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
data: output data from neural network
|
104 |
+
original_image: input image which was used for predicted data
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Segmentation mask as PIL Image instance
|
108 |
+
|
109 |
+
"""
|
110 |
+
return (
|
111 |
+
Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size)
|
112 |
+
)
|
113 |
+
|
114 |
+
def __call__(
|
115 |
+
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
|
116 |
+
) -> List[PIL.Image.Image]:
|
117 |
+
"""
|
118 |
+
Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
|
119 |
+
|
120 |
+
Args:
|
121 |
+
images: input images
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
segmentation masks as for input images, as PIL.Image.Image instances
|
125 |
+
|
126 |
+
"""
|
127 |
+
collect_masks = []
|
128 |
+
autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
|
129 |
+
with autocast:
|
130 |
+
cast_network(self.network, dtype)
|
131 |
+
for image_batch in batch_generator(images, self.batch_size):
|
132 |
+
images = thread_pool_processing(
|
133 |
+
lambda x: convert_image(load_image(x)), image_batch
|
134 |
+
)
|
135 |
+
batches = thread_pool_processing(self.data_preprocessing, images)
|
136 |
+
with torch.no_grad():
|
137 |
+
masks = [
|
138 |
+
self.network(i.to(self.device).unsqueeze(0))["out"][0]
|
139 |
+
.argmax(0)
|
140 |
+
.byte()
|
141 |
+
.cpu()
|
142 |
+
for i in batches
|
143 |
+
]
|
144 |
+
del batches
|
145 |
+
masks = thread_pool_processing(
|
146 |
+
lambda x: self.data_postprocessing(masks[x], images[x]),
|
147 |
+
range(len(images)),
|
148 |
+
)
|
149 |
+
collect_masks += masks
|
150 |
+
return collect_masks
|
carvekit/ml/wrap/fba_matting.py
ADDED
@@ -0,0 +1,224 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import pathlib
|
7 |
+
from typing import Union, List, Tuple
|
8 |
+
|
9 |
+
import PIL
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from carvekit.ml.arch.fba_matting.models import FBA
|
16 |
+
from carvekit.ml.arch.fba_matting.transforms import (
|
17 |
+
trimap_transform,
|
18 |
+
groupnorm_normalise_image,
|
19 |
+
)
|
20 |
+
from carvekit.ml.files.models_loc import fba_pretrained
|
21 |
+
from carvekit.utils.image_utils import convert_image, load_image
|
22 |
+
from carvekit.utils.models_utils import get_precision_autocast, cast_network
|
23 |
+
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
|
24 |
+
|
25 |
+
__all__ = ["FBAMatting"]
|
26 |
+
|
27 |
+
|
28 |
+
class FBAMatting(FBA):
|
29 |
+
"""
|
30 |
+
FBA Matting Neural Network to improve edges on image.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
device="cpu",
|
36 |
+
input_tensor_size: Union[List[int], int] = 2048,
|
37 |
+
batch_size: int = 2,
|
38 |
+
encoder="resnet50_GN_WS",
|
39 |
+
load_pretrained: bool = True,
|
40 |
+
fp16: bool = False,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
Initialize the FBAMatting model
|
44 |
+
|
45 |
+
Args:
|
46 |
+
device: processing device
|
47 |
+
input_tensor_size: input image size
|
48 |
+
batch_size: the number of images that the neural network processes in one run
|
49 |
+
encoder: neural network encoder head
|
50 |
+
load_pretrained: loading pretrained model
|
51 |
+
fp16: use half precision
|
52 |
+
|
53 |
+
"""
|
54 |
+
super(FBAMatting, self).__init__(encoder=encoder)
|
55 |
+
self.fp16 = fp16
|
56 |
+
self.device = device
|
57 |
+
self.batch_size = batch_size
|
58 |
+
if isinstance(input_tensor_size, list):
|
59 |
+
self.input_image_size = input_tensor_size[:2]
|
60 |
+
else:
|
61 |
+
self.input_image_size = (input_tensor_size, input_tensor_size)
|
62 |
+
self.to(device)
|
63 |
+
if load_pretrained:
|
64 |
+
self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
|
65 |
+
self.eval()
|
66 |
+
|
67 |
+
def data_preprocessing(
|
68 |
+
self, data: Union[PIL.Image.Image, np.ndarray]
|
69 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
70 |
+
"""
|
71 |
+
Transform input image to suitable data format for neural network
|
72 |
+
|
73 |
+
Args:
|
74 |
+
data: input image
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
input for neural network
|
78 |
+
|
79 |
+
"""
|
80 |
+
resized = data.copy()
|
81 |
+
if self.batch_size == 1:
|
82 |
+
resized.thumbnail(self.input_image_size, resample=3)
|
83 |
+
else:
|
84 |
+
resized = resized.resize(self.input_image_size, resample=3)
|
85 |
+
# noinspection PyTypeChecker
|
86 |
+
image = np.array(resized, dtype=np.float64)
|
87 |
+
image = image / 255.0 # Normalize image to [0, 1] values range
|
88 |
+
if resized.mode == "RGB":
|
89 |
+
image = image[:, :, ::-1]
|
90 |
+
elif resized.mode == "L":
|
91 |
+
image2 = np.copy(image)
|
92 |
+
h, w = image2.shape
|
93 |
+
image = np.zeros((h, w, 2)) # Transform trimap to binary data format
|
94 |
+
image[image2 == 1, 1] = 1
|
95 |
+
image[image2 == 0, 0] = 1
|
96 |
+
else:
|
97 |
+
raise ValueError("Incorrect color mode for image")
|
98 |
+
h, w = image.shape[:2]  # Scale input to a multiple of 8
|
99 |
+
h1 = int(np.ceil(1.0 * h / 8) * 8)
|
100 |
+
w1 = int(np.ceil(1.0 * w / 8) * 8)
|
101 |
+
x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
|
102 |
+
image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
|
103 |
+
if resized.mode == "RGB":
|
104 |
+
return image_tensor, groupnorm_normalise_image(
|
105 |
+
image_tensor.clone(), format="nchw"
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
return (
|
109 |
+
image_tensor,
|
110 |
+
torch.from_numpy(trimap_transform(x_scale))
|
111 |
+
.permute(2, 0, 1)[None, :, :, :]
|
112 |
+
.float(),
|
113 |
+
)
|
114 |
+
|
115 |
+
@staticmethod
|
116 |
+
def data_postprocessing(
|
117 |
+
data: torch.tensor, trimap: PIL.Image.Image
|
118 |
+
) -> PIL.Image.Image:
|
119 |
+
"""
|
120 |
+
Transforms output data from neural network to suitable data
|
121 |
+
format for using with other components of this framework.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
data: output data from neural network
|
125 |
+
trimap: Map with the area we need to refine
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Segmentation mask as PIL Image instance
|
129 |
+
|
130 |
+
"""
|
131 |
+
if trimap.mode != "L":
|
132 |
+
raise ValueError("Incorrect color mode for trimap")
|
133 |
+
pred = data.numpy().transpose((1, 2, 0))
|
134 |
+
pred = cv2.resize(pred, trimap.size, cv2.INTER_LANCZOS4)[:, :, 0]
|
135 |
+
# noinspection PyTypeChecker
|
136 |
+
# Clean mask by removing all false predictions outside trimap and already known area
|
137 |
+
trimap_arr = np.array(trimap.copy())
|
138 |
+
pred[trimap_arr[:, :] == 0] = 0
|
139 |
+
# pred[trimap_arr[:, :] == 255] = 1
|
140 |
+
pred[pred < 0.3] = 0
|
141 |
+
return Image.fromarray(pred * 255).convert("L")
|
142 |
+
|
143 |
+
def __call__(
|
144 |
+
self,
|
145 |
+
images: List[Union[str, pathlib.Path, PIL.Image.Image]],
|
146 |
+
trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
|
147 |
+
) -> List[PIL.Image.Image]:
|
148 |
+
"""
|
149 |
+
Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
|
150 |
+
|
151 |
+
Args:
|
152 |
+
images: input images
|
153 |
+
trimaps: Maps with the areas we need to refine
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
segmentation masks as for input images, as PIL.Image.Image instances
|
157 |
+
|
158 |
+
"""
|
159 |
+
|
160 |
+
if len(images) != len(trimaps):
|
161 |
+
raise ValueError(
|
162 |
+
"Len of specified arrays of images and trimaps should be equal!"
|
163 |
+
)
|
164 |
+
|
165 |
+
collect_masks = []
|
166 |
+
autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
|
167 |
+
with autocast:
|
168 |
+
cast_network(self, dtype)
|
169 |
+
for idx_batch in batch_generator(range(len(images)), self.batch_size):
|
170 |
+
inpt_images = thread_pool_processing(
|
171 |
+
lambda x: convert_image(load_image(images[x])), idx_batch
|
172 |
+
)
|
173 |
+
|
174 |
+
inpt_trimaps = thread_pool_processing(
|
175 |
+
lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
|
176 |
+
)
|
177 |
+
|
178 |
+
inpt_img_batches = thread_pool_processing(
|
179 |
+
self.data_preprocessing, inpt_images
|
180 |
+
)
|
181 |
+
inpt_trimaps_batches = thread_pool_processing(
|
182 |
+
self.data_preprocessing, inpt_trimaps
|
183 |
+
)
|
184 |
+
|
185 |
+
inpt_img_batches_transformed = torch.vstack(
|
186 |
+
[i[1] for i in inpt_img_batches]
|
187 |
+
)
|
188 |
+
inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])
|
189 |
+
|
190 |
+
inpt_trimaps_transformed = torch.vstack(
|
191 |
+
[i[1] for i in inpt_trimaps_batches]
|
192 |
+
)
|
193 |
+
inpt_trimaps_batches = torch.vstack(
|
194 |
+
[i[0] for i in inpt_trimaps_batches]
|
195 |
+
)
|
196 |
+
|
197 |
+
with torch.no_grad():
|
198 |
+
inpt_img_batches = inpt_img_batches.to(self.device)
|
199 |
+
inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
|
200 |
+
inpt_img_batches_transformed = inpt_img_batches_transformed.to(
|
201 |
+
self.device
|
202 |
+
)
|
203 |
+
inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)
|
204 |
+
|
205 |
+
output = super(FBAMatting, self).__call__(
|
206 |
+
inpt_img_batches,
|
207 |
+
inpt_trimaps_batches,
|
208 |
+
inpt_img_batches_transformed,
|
209 |
+
inpt_trimaps_transformed,
|
210 |
+
)
|
211 |
+
output_cpu = output.cpu()
|
212 |
+
del (
|
213 |
+
inpt_img_batches,
|
214 |
+
inpt_trimaps_batches,
|
215 |
+
inpt_img_batches_transformed,
|
216 |
+
inpt_trimaps_transformed,
|
217 |
+
output,
|
218 |
+
)
|
219 |
+
masks = thread_pool_processing(
|
220 |
+
lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
|
221 |
+
range(len(inpt_images)),
|
222 |
+
)
|
223 |
+
collect_masks += masks
|
224 |
+
return collect_masks
|
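A usage sketch for the matting wrapper; "image.png" and the grayscale "trimap.png" are hypothetical file names:

from carvekit.ml.wrap.fba_matting import FBAMatting

fba = FBAMatting(device="cpu", input_tensor_size=2048, batch_size=1)
alpha = fba(images=["image.png"], trimaps=["trimap.png"])[0]  # PIL.Image, mode "L"
alpha.save("alpha.png")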
carvekit/ml/wrap/tracer_b7.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import pathlib
|
7 |
+
import warnings
|
8 |
+
from typing import List, Union
|
9 |
+
import PIL.Image
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision.transforms as transforms
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from carvekit.ml.arch.tracerb7.tracer import TracerDecoder
|
16 |
+
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
|
17 |
+
from carvekit.ml.files.models_loc import tracer_b7_pretrained, tracer_hair_pretrained
|
18 |
+
from carvekit.utils.models_utils import get_precision_autocast, cast_network
|
19 |
+
from carvekit.utils.image_utils import load_image, convert_image
|
20 |
+
from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
|
21 |
+
|
22 |
+
__all__ = ["TracerUniversalB7"]
|
23 |
+
|
24 |
+
|
25 |
+
class TracerUniversalB7(TracerDecoder):
|
26 |
+
"""TRACER B7 model interface"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
device="cpu",
|
31 |
+
input_image_size: Union[List[int], int] = 640,
|
32 |
+
batch_size: int = 4,
|
33 |
+
load_pretrained: bool = True,
|
34 |
+
fp16: bool = False,
|
35 |
+
model_path: Union[str, pathlib.Path] = None,
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
Initialize the TRACER universal B7 model
|
39 |
+
|
40 |
+
Args:
|
41 |
+
model_path: optional path to the pretrained model weights
|
42 |
+
device: processing device
|
43 |
+
input_image_size: input image size
|
44 |
+
batch_size: the number of images that the neural network processes in one run
|
45 |
+
load_pretrained: loading pretrained model
|
46 |
+
fp16: use fp16 precision
|
47 |
+
|
48 |
+
"""
|
49 |
+
if model_path is None:
|
50 |
+
model_path = tracer_b7_pretrained()
|
51 |
+
super(TracerUniversalB7, self).__init__(
|
52 |
+
encoder=EfficientEncoderB7(),
|
53 |
+
rfb_channel=[32, 64, 128],
|
54 |
+
features_channels=[48, 80, 224, 640],
|
55 |
+
)
|
56 |
+
|
57 |
+
self.fp16 = fp16
|
58 |
+
self.device = device
|
59 |
+
self.batch_size = batch_size
|
60 |
+
if isinstance(input_image_size, list):
|
61 |
+
self.input_image_size = input_image_size[:2]
|
62 |
+
else:
|
63 |
+
self.input_image_size = (input_image_size, input_image_size)
|
64 |
+
|
65 |
+
self.transform = transforms.Compose(
|
66 |
+
[
|
67 |
+
transforms.ToTensor(),
|
68 |
+
transforms.Resize(self.input_image_size),
|
69 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
70 |
+
]
|
71 |
+
)
|
72 |
+
self.to(device)
|
73 |
+
if load_pretrained:
|
74 |
+
# TODO remove edge detector from weights. It doesn't work well with this model!
|
75 |
+
self.load_state_dict(
|
76 |
+
torch.load(model_path, map_location=self.device), strict=False
|
77 |
+
)
|
78 |
+
self.eval()
|
79 |
+
|
80 |
+
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
|
81 |
+
"""
|
82 |
+
Transform input image to suitable data format for neural network
|
83 |
+
|
84 |
+
Args:
|
85 |
+
data: input image
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
input for neural network
|
89 |
+
|
90 |
+
"""
|
91 |
+
|
92 |
+
return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def data_postprocessing(
|
96 |
+
data: torch.tensor, original_image: PIL.Image.Image
|
97 |
+
) -> PIL.Image.Image:
|
98 |
+
"""
|
99 |
+
Transforms output data from neural network to suitable data
|
100 |
+
format for using with other components of this framework.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
data: output data from neural network
|
104 |
+
original_image: input image which was used for predicted data
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Segmentation mask as PIL Image instance
|
108 |
+
|
109 |
+
"""
|
110 |
+
output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
|
111 |
+
np.uint8
|
112 |
+
)
|
113 |
+
output = output.squeeze(0)
|
114 |
+
mask = Image.fromarray(output).convert("L")
|
115 |
+
mask = mask.resize(original_image.size, resample=Image.BILINEAR)
|
116 |
+
return mask
|
117 |
+
|
118 |
+
def __call__(
|
119 |
+
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
|
120 |
+
) -> List[PIL.Image.Image]:
|
121 |
+
"""
|
122 |
+
Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
|
123 |
+
|
124 |
+
Args:
|
125 |
+
images: input images
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
segmentation masks as for input images, as PIL.Image.Image instances
|
129 |
+
|
130 |
+
"""
|
131 |
+
collect_masks = []
|
132 |
+
autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
|
133 |
+
with autocast:
|
134 |
+
cast_network(self, dtype)
|
135 |
+
for image_batch in batch_generator(images, self.batch_size):
|
136 |
+
images = thread_pool_processing(
|
137 |
+
lambda x: convert_image(load_image(x)), image_batch
|
138 |
+
)
|
139 |
+
batches = torch.vstack(
|
140 |
+
thread_pool_processing(self.data_preprocessing, images)
|
141 |
+
)
|
142 |
+
with torch.no_grad():
|
143 |
+
batches = batches.to(self.device)
|
144 |
+
masks = super(TracerDecoder, self).__call__(batches)
|
145 |
+
masks_cpu = masks.cpu()
|
146 |
+
del batches, masks
|
147 |
+
masks = thread_pool_processing(
|
148 |
+
lambda x: self.data_postprocessing(masks_cpu[x], images[x]),
|
149 |
+
range(len(images)),
|
150 |
+
)
|
151 |
+
collect_masks += masks
|
152 |
+
|
153 |
+
return collect_masks
|
154 |
+
|
155 |
+
|
156 |
+
class TracerHair(TracerUniversalB7):
|
157 |
+
"""TRACER HAIR model interface"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
device="cpu",
|
162 |
+
input_image_size: Union[List[int], int] = 640,
|
163 |
+
batch_size: int = 4,
|
164 |
+
load_pretrained: bool = True,
|
165 |
+
fp16: bool = False,
|
166 |
+
model_path: Union[str, pathlib.Path] = None,
|
167 |
+
):
|
168 |
+
if model_path is None:
|
169 |
+
model_path = tracer_hair_pretrained()
|
170 |
+
warnings.warn("TracerHair has no public model yet. Don't use it!", UserWarning)
|
171 |
+
super(TracerHair, self).__init__(
|
172 |
+
device=device,
|
173 |
+
input_image_size=input_image_size,
|
174 |
+
batch_size=batch_size,
|
175 |
+
load_pretrained=load_pretrained,
|
176 |
+
fp16=fp16,
|
177 |
+
model_path=model_path,
|
178 |
+
)
|
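A usage sketch for the wrapper; "photo.jpg" is a hypothetical file name:

from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7

tracer = TracerUniversalB7(device="cpu", batch_size=1)
mask = tracer(["photo.jpg"])[0]  # PIL.Image mask ("L"), same size as the input photo
mask.save("photo_mask.png")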
carvekit/ml/wrap/u2net.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import pathlib
|
7 |
+
from typing import List, Union
|
8 |
+
import PIL.Image
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from carvekit.ml.arch.u2net.u2net import U2NETArchitecture
|
14 |
+
from carvekit.ml.files.models_loc import u2net_full_pretrained
|
15 |
+
from carvekit.utils.image_utils import load_image, convert_image
|
16 |
+
from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
|
17 |
+
|
18 |
+
__all__ = ["U2NET"]
|
19 |
+
|
20 |
+
|
21 |
+
class U2NET(U2NETArchitecture):
|
22 |
+
"""U^2-Net model interface"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
layers_cfg="full",
|
27 |
+
device="cpu",
|
28 |
+
input_image_size: Union[List[int], int] = 320,
|
29 |
+
batch_size: int = 10,
|
30 |
+
load_pretrained: bool = True,
|
31 |
+
fp16: bool = False,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Initialize the U2NET model
|
35 |
+
|
36 |
+
Args:
|
37 |
+
layers_cfg: neural network layers configuration
|
38 |
+
device: processing device
|
39 |
+
input_image_size: input image size
|
40 |
+
batch_size: the number of images that the neural network processes in one run
|
41 |
+
load_pretrained: loading pretrained model
|
42 |
+
fp16: use fp16 precision // not supported at this moment.
|
43 |
+
|
44 |
+
"""
|
45 |
+
super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
|
46 |
+
self.device = device
|
47 |
+
self.batch_size = batch_size
|
48 |
+
if isinstance(input_image_size, list):
|
49 |
+
self.input_image_size = input_image_size[:2]
|
50 |
+
else:
|
51 |
+
self.input_image_size = (input_image_size, input_image_size)
|
52 |
+
self.to(device)
|
53 |
+
if load_pretrained:
|
54 |
+
self.load_state_dict(
|
55 |
+
torch.load(u2net_full_pretrained(), map_location=self.device)
|
56 |
+
)
|
57 |
+
self.eval()
|
58 |
+
|
59 |
+
def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
|
60 |
+
"""
|
61 |
+
Transform input image to suitable data format for neural network
|
62 |
+
|
63 |
+
Args:
|
64 |
+
data: input image
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
input for neural network
|
68 |
+
|
69 |
+
"""
|
70 |
+
resized = data.resize(self.input_image_size, resample=3)
|
71 |
+
# noinspection PyTypeChecker
|
72 |
+
resized_arr = np.array(resized, dtype=float)
|
73 |
+
temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
|
74 |
+
if np.max(resized_arr) != 0:
|
75 |
+
resized_arr /= np.max(resized_arr)
|
76 |
+
temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
|
77 |
+
temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
|
78 |
+
temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
|
79 |
+
temp_image = temp_image.transpose((2, 0, 1))
|
80 |
+
temp_image = np.expand_dims(temp_image, 0)
|
81 |
+
return torch.from_numpy(temp_image).type(torch.FloatTensor)
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def data_postprocessing(
|
85 |
+
data: torch.tensor, original_image: PIL.Image.Image
|
86 |
+
) -> PIL.Image.Image:
|
87 |
+
"""
|
88 |
+
Transforms output data from neural network to suitable data
|
89 |
+
format for using with other components of this framework.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
data: output data from neural network
|
93 |
+
original_image: input image which was used for predicted data
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Segmentation mask as PIL Image instance
|
97 |
+
|
98 |
+
"""
|
99 |
+
data = data.unsqueeze(0)
|
100 |
+
mask = data[:, 0, :, :]
|
101 |
+
ma = torch.max(mask) # Normalizes prediction
|
102 |
+
mi = torch.min(mask)
|
103 |
+
predict = ((mask - mi) / (ma - mi)).squeeze()
|
104 |
+
predict_np = predict.cpu().data.numpy() * 255
|
105 |
+
mask = Image.fromarray(predict_np).convert("L")
|
106 |
+
mask = mask.resize(original_image.size, resample=3)
|
107 |
+
return mask
|
108 |
+
|
109 |
+
def __call__(
|
110 |
+
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
|
111 |
+
) -> List[PIL.Image.Image]:
|
112 |
+
"""
|
113 |
+
Passes input images through the neural network and returns segmentation masks as PIL.Image.Image instances
|
114 |
+
|
115 |
+
Args:
|
116 |
+
images: input images
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
segmentation masks as for input images, as PIL.Image.Image instances
|
120 |
+
|
121 |
+
"""
|
122 |
+
collect_masks = []
|
123 |
+
for image_batch in batch_generator(images, self.batch_size):
|
124 |
+
images = thread_pool_processing(
|
125 |
+
lambda x: convert_image(load_image(x)), image_batch
|
126 |
+
)
|
127 |
+
batches = torch.vstack(
|
128 |
+
thread_pool_processing(self.data_preprocessing, images)
|
129 |
+
)
|
130 |
+
with torch.no_grad():
|
131 |
+
batches = batches.to(self.device)
|
132 |
+
masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
|
133 |
+
masks_cpu = masks.cpu()
|
134 |
+
del d2, d3, d4, d5, d6, d7, batches, masks
|
135 |
+
masks = thread_pool_processing(
|
136 |
+
lambda x: self.data_postprocessing(masks_cpu[x], images[x]),
|
137 |
+
range(len(images)),
|
138 |
+
)
|
139 |
+
collect_masks += masks
|
140 |
+
return collect_masks
|
carvekit/pipelines/__init__.py
ADDED
File without changes
|
carvekit/pipelines/postprocessing.py
ADDED
@@ -0,0 +1,76 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from carvekit.ml.wrap.fba_matting import FBAMatting
from typing import Union, List
from PIL import Image
from pathlib import Path
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
from carvekit.utils.image_utils import load_image, convert_image

__all__ = ["MattingMethod"]


class MattingMethod:
    """
    Improves the edges of the object mask using a matting neural network and a trimap-generation algorithm.
    The matting network refines object edges with the help of a trimap: a map that marks the known foreground,
    the known background, and an unknown border region that is scanned for the true object boundary."""

    def __init__(
        self,
        matting_module: Union[FBAMatting],
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Initializes Matting Method class.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device used for applying mask to image
        """
        self.device = device
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Passes data through the apply_mask function

        Args:
            images: list of images
            masks: list of masks

        Returns:
            list of images
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")
        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
        masks = thread_pool_processing(
            lambda x: convert_image(load_image(x), mode="L"), masks
        )
        trimaps = thread_pool_processing(
            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
            range(len(images)),
        )
        alpha = self.matting_module(images=images, trimaps=trimaps)
        return list(
            map(
                lambda x: apply_mask(
                    image=images[x], mask=alpha[x], device=self.device
                ),
                range(len(images)),
            )
        )
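A sketch of wiring the post-processing stage, assuming a segmentation mask has already been produced (file names are hypothetical) and that TrimapGenerator's default parameters are acceptable:

from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.trimap.generator import TrimapGenerator
from carvekit.pipelines.postprocessing import MattingMethod

postprocess = MattingMethod(
    matting_module=FBAMatting(device="cpu", batch_size=1),
    trimap_generator=TrimapGenerator(),
    device="cpu",
)
cutouts = postprocess(images=["photo.jpg"], masks=["photo_mask.png"])  # list of PIL images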
carvekit/pipelines/preprocessing.py
ADDED
@@ -0,0 +1,28 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Union, List

from PIL import Image

__all__ = ["PreprocessingStub"]


class PreprocessingStub:
    """Stub for future preprocessing methods"""

    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
        """
        Passes data through the interface.segmentation_pipeline() method

        Args:
            interface: Interface instance
            images: list of images

        Returns:
            the result of passing data through the segmentation_pipeline method of the interface
        """
        return interface.segmentation_pipeline(images=images)
carvekit/trimap/__init__.py
ADDED
File without changes
|
carvekit/trimap/add_ops.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image:
|
12 |
+
"""
|
13 |
+
Applies a filter to the mask by the probability of locating an object in the object area.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
prob_threshold: Threshold of probability for mark area as background.
|
17 |
+
mask: Predicted object mask
|
18 |
+
|
19 |
+
Raises:
|
20 |
+
ValueError if mask or trimap has wrong color mode
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Generated trimap for image.
|
24 |
+
"""
|
25 |
+
if mask.mode != "L":
|
26 |
+
raise ValueError("Input mask has wrong color mode.")
|
27 |
+
# noinspection PyTypeChecker
|
28 |
+
mask_array = np.array(mask)
|
29 |
+
mask_array[mask_array > prob_threshold] = 255 # Probability filter for mask
|
30 |
+
mask_array[mask_array <= prob_threshold] = 0
|
31 |
+
return Image.fromarray(mask_array).convert("L")
|
32 |
+
|
33 |
+
|
34 |
+
def prob_as_unknown_area(
|
35 |
+
trimap: Image.Image, mask: Image.Image, prob_threshold=255
|
36 |
+
) -> Image.Image:
|
37 |
+
"""
|
38 |
+
Marks any uncertainty in the seg mask as an unknown region.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
prob_threshold: Threshold of probability for mark area as unknown.
|
42 |
+
trimap: Generated trimap.
|
43 |
+
mask: Predicted object mask
|
44 |
+
|
45 |
+
Raises:
|
46 |
+
ValueError if mask or trimap has wrong color mode
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
Generated trimap for image.
|
50 |
+
"""
|
51 |
+
if mask.mode != "L" or trimap.mode != "L":
|
52 |
+
raise ValueError("Input mask has wrong color mode.")
|
53 |
+
# noinspection PyTypeChecker
|
54 |
+
mask_array = np.array(mask)
|
55 |
+
# noinspection PyTypeChecker
|
56 |
+
trimap_array = np.array(trimap)
|
57 |
+
trimap_array[np.logical_and(mask_array <= prob_threshold, mask_array > 0)] = 127
|
58 |
+
return Image.fromarray(trimap_array).convert("L")
|
59 |
+
|
60 |
+
|
61 |
+
def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image:
|
62 |
+
"""
|
63 |
+
Performs erosion on the mask and marks the resulting area as an unknown region.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
erosion_iters: The number of iterations of erosion that
|
67 |
+
the object's mask will be subjected to before forming an unknown area
|
68 |
+
trimap: Generated trimap.
|
69 |
+
mask: Predicted object mask
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
Generated trimap for image.
|
73 |
+
"""
|
74 |
+
if trimap.mode != "L":
|
75 |
+
raise ValueError("Input mask has wrong color mode.")
|
76 |
+
# noinspection PyTypeChecker
|
77 |
+
trimap_array = np.array(trimap)
|
78 |
+
if erosion_iters > 0:
|
79 |
+
without_unknown_area = trimap_array.copy()
|
80 |
+
without_unknown_area[without_unknown_area == 127] = 0
|
81 |
+
|
82 |
+
erosion_kernel = np.ones((3, 3), np.uint8)
|
83 |
+
erode = cv2.erode(
|
84 |
+
without_unknown_area, erosion_kernel, iterations=erosion_iters
|
85 |
+
)
|
86 |
+
erode = np.where(erode == 0, 0, without_unknown_area)
|
87 |
+
trimap_array[np.logical_and(erode == 0, without_unknown_area > 0)] = 127
|
88 |
+
erode = trimap_array.copy()
|
89 |
+
else:
|
90 |
+
erode = trimap_array.copy()
|
91 |
+
return Image.fromarray(erode).convert("L")
|
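Taken together, these helpers are meant to be chained: prob_filter hard-thresholds the mask, prob_as_unknown_area injects a gray (127) unknown band, and post_erosion widens it. The following is a minimal sketch of that chain (roughly what TrimapGenerator in carvekit/trimap/generator.py does, minus the dilation step); the file name is purely illustrative.

from PIL import Image

from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion

# "mask.png" is a placeholder for any 8-bit ("L" mode) segmentation mask.
mask = Image.open("mask.png").convert("L")

binary = prob_filter(mask, prob_threshold=231)  # confident pixels -> 255, rest -> 0
trimap = prob_as_unknown_area(trimap=binary, mask=mask, prob_threshold=231)  # low-confidence pixels -> 127
trimap = post_erosion(trimap, erosion_iters=5)  # widen the unknown band via erosion
trimap.save("trimap.png")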
carvekit/trimap/cv_gen.py
ADDED
@@ -0,0 +1,64 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import PIL.Image
import cv2
import numpy as np


class CV2TrimapGenerator:
    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
        """
        Initialize a new CV2TrimapGenerator instance.

        Args:
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap
            erosion_iters: The number of iterations of erosion that
                the object's mask will be subjected to before forming an unknown area
        """
        self.kernel_size = kernel_size
        self.erosion_iters = erosion_iters

    def __call__(
        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Generates a trimap based on the predicted object mask to refine object mask borders.
        Based on the cv2 erosion algorithm.

        Args:
            original_image: Original image
            mask: Predicted object mask

        Returns:
            Generated trimap for the image.
        """
        if mask.mode != "L":
            raise ValueError("Input mask has wrong color mode.")
        if mask.size != original_image.size:
            raise ValueError("Sizes of input image and predicted mask don't match")
        # noinspection PyTypeChecker
        mask_array = np.array(mask)
        pixels = 2 * self.kernel_size + 1
        kernel = np.ones((pixels, pixels), np.uint8)

        if self.erosion_iters > 0:
            erosion_kernel = np.ones((3, 3), np.uint8)
            erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters)
            erode = np.where(erode == 0, 0, mask_array)
        else:
            erode = mask_array.copy()

        dilation = cv2.dilate(erode, kernel, iterations=1)

        dilation = np.where(dilation == 255, 127, dilation)  # WHITE to GRAY
        trimap = np.where(erode > 127, 200, dilation)  # mark the eroded foreground inside the GRAY band
        trimap = np.where(trimap < 127, 0, trimap)  # anything below the unknown band becomes background
        trimap = np.where(trimap > 200, 0, trimap)  # stray values above the foreground marker become background
        trimap = np.where(trimap == 200, 255, trimap)  # GRAY foreground marker to WHITE

        return PIL.Image.fromarray(trimap).convert("L")
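A minimal usage sketch of CV2TrimapGenerator on its own; the file names are placeholders for any RGB image and a same-sized "L" mode mask.

from PIL import Image

from carvekit.trimap.cv_gen import CV2TrimapGenerator

original = Image.open("original.png").convert("RGB")  # placeholder input image
mask = Image.open("mask.png").convert("L")            # placeholder mask of the same size

generator = CV2TrimapGenerator(kernel_size=30, erosion_iters=1)
trimap = generator(original, mask)  # 0 = background, 127 = unknown band, 255 = foreground
trimap.save("trimap_cv2.png")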
carvekit/trimap/generator.py
ADDED
@@ -0,0 +1,47 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from PIL import Image
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion


class TrimapGenerator(CV2TrimapGenerator):
    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Initialize a TrimapGenerator instance.

        Args:
            prob_threshold: Probability threshold at which the
                prob_filter and prob_as_unknown_area operations will be applied
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap
            erosion_iters: The number of iterations of erosion that
                the object's mask will be subjected to before forming an unknown area
        """
        super().__init__(kernel_size, erosion_iters=0)
        self.prob_threshold = prob_threshold
        self.__erosion_iters = erosion_iters

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Generates a trimap based on the predicted object mask to refine object mask borders.
        Based on the cv2 erosion algorithm and additional probability filters.

        Args:
            original_image: Original image
            mask: Predicted object mask

        Returns:
            Generated trimap for the image.
        """
        filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
        trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask)
        new_trimap = prob_as_unknown_area(
            trimap=trimap, mask=mask, prob_threshold=self.prob_threshold
        )
        new_trimap = post_erosion(new_trimap, self.__erosion_iters)
        return new_trimap
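Usage is the same as for the base class; a short sketch with placeholder file names follows.

from PIL import Image

from carvekit.trimap.generator import TrimapGenerator

image = Image.open("photo.png").convert("RGB")  # placeholder photo
mask = Image.open("mask.png").convert("L")      # placeholder mask from a segmentation network

trimap_generator = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5)
trimap = trimap_generator(image, mask)
trimap.save("trimap.png")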
carvekit/utils/__init__.py
ADDED
File without changes
carvekit/utils/download_models.py
ADDED
@@ -0,0 +1,214 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import hashlib
import os
import warnings
from abc import ABCMeta, abstractmethod, ABC
from pathlib import Path
from typing import Optional

import carvekit
from carvekit.ml.files import checkpoints_dir

import requests
import tqdm

requests = requests.Session()
requests.headers.update({"User-Agent": f"Carvekit/{carvekit.version}"})

MODELS_URLS = {
    "basnet.pth": {
        "repository": "Carve/basnet-universal",
        "revision": "870becbdb364fda6d8fdb2c10b072542f8d08701",
        "filename": "basnet.pth",
    },
    "deeplab.pth": {
        "repository": "Carve/deeplabv3-resnet101",
        "revision": "d504005392fc877565afdf58aad0cd524682d2b0",
        "filename": "deeplab.pth",
    },
    "fba_matting.pth": {
        "repository": "Carve/fba",
        "revision": "a5d3457df0fb9c88ea19ed700d409756ca2069d1",
        "filename": "fba_matting.pth",
    },
    "u2net.pth": {
        "repository": "Carve/u2net-universal",
        "revision": "10305d785481cf4b2eee1d447c39cd6e5f43d74b",
        "filename": "full_weights.pth",
    },
    "tracer_b7.pth": {
        "repository": "Carve/tracer_b7",
        "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
        "filename": "tracer_b7.pth",
    },
    "tracer_hair.pth": {
        "repository": "Carve/tracer_b7",
        "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
        "filename": "tracer_b7.pth",  # TODO don't forget to change this link!!
    },
}

MODELS_CHECKSUMS = {
    "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33"
    "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad",
    "deeplab.pth": "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28"
    "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c",
    "fba_matting.pth": "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78"
    "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613",
    "u2net.pth": "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e"
    "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7",
    "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e"
    "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b",
    "tracer_hair.pth": "5c2fb9973fc42fa6208920ffa9ac233cc2ea9f770b24b4a96969d3449aed7ac89e6d37e"
    "e486a13e63be5499f2df6ccef1109e9e8797d1326207ac89b2f39a7cf",
}


def sha512_checksum_calc(file: Path) -> str:
    """
    Calculates the SHA512 hash digest of a file on the filesystem.

    Args:
        file: Path to the file

    Returns:
        SHA512 hash digest of the file.
    """
    dd = hashlib.sha512()
    with file.open("rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            dd.update(chunk)
    return dd.hexdigest()


class CachedDownloader:
    __metaclass__ = ABCMeta

    @property
    @abstractmethod
    def name(self) -> str:
        return self.__class__.__name__

    @property
    @abstractmethod
    def fallback_downloader(self) -> Optional["CachedDownloader"]:
        pass

    def download_model(self, file_name: str) -> Path:
        try:
            return self.download_model_base(file_name)
        except BaseException as e:
            if self.fallback_downloader is not None:
                warnings.warn(
                    f"Failed to download model from {self.name} downloader."
                    f" Trying to download from {self.fallback_downloader.name} downloader."
                )
                return self.fallback_downloader.download_model(file_name)
            else:
                warnings.warn(
                    f"Failed to download model from {self.name} downloader."
                    f" No fallback downloader available."
                )
                raise e

    @abstractmethod
    def download_model_base(self, file_name: str) -> Path:
        """Download model from any source if not cached. Returns path if cached."""

    def __call__(self, file_name: str):
        return self.download_model(file_name)


class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
    def __init__(
        self,
        name: str = "Huggingface.co",
        base_url: str = "https://huggingface.co",
        fb_downloader: Optional["CachedDownloader"] = None,
    ):
        self.cache_dir = checkpoints_dir
        self.base_url = base_url
        self._name = name
        self._fallback_downloader = fb_downloader

    @property
    def fallback_downloader(self) -> Optional["CachedDownloader"]:
        return self._fallback_downloader

    @property
    def name(self):
        return self._name

    def check_for_existence(self, file_name: str) -> Optional[Path]:
        if file_name not in MODELS_URLS.keys():
            raise FileNotFoundError("Unknown model!")
        path = (
            self.cache_dir
            / MODELS_URLS[file_name]["repository"].split("/")[1]
            / file_name
        )

        if not path.exists():
            return None

        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
            warnings.warn(
                f"Invalid checksum for model {path.name}. Downloading correct model!"
            )
            os.remove(path)
            return None
        return path

    def download_model_base(self, file_name: str) -> Path:
        cached_path = self.check_for_existence(file_name)
        if cached_path is not None:
            return cached_path
        else:
            cached_path = (
                self.cache_dir
                / MODELS_URLS[file_name]["repository"].split("/")[1]
                / file_name
            )
            cached_path.parent.mkdir(parents=True, exist_ok=True)
            url = MODELS_URLS[file_name]
            hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"

            try:
                r = requests.get(hugging_face_url, stream=True, timeout=10)
                if r.status_code < 400:
                    with open(cached_path, "wb") as f:
                        r.raw.decode_content = True
                        for chunk in tqdm.tqdm(
                            r,
                            desc="Downloading " + cached_path.name + " model",
                            colour="blue",
                        ):
                            f.write(chunk)
                else:
                    if r.status_code == 404:
                        raise FileNotFoundError(f"Model {file_name} not found!")
                    else:
                        raise ConnectionError(
                            f"Error {r.status_code} while downloading model {file_name}!"
                        )
            except BaseException as e:
                if cached_path.exists():
                    os.remove(cached_path)
                raise ConnectionError(
                    f"Exception caught when downloading model! "
                    f"Model name: {cached_path.name}. Exception: {str(e)}."
                )
            return cached_path


fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
downloader: CachedDownloader = HuggingFaceCompatibleDownloader(
    base_url="https://cdn.carve.photos",
    fb_downloader=fallback_downloader,
    name="Carve CDN",
)
downloader._fallback_downloader = fallback_downloader
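The module wires up a default downloader chain at import time: the Carve CDN mirror first, huggingface.co as fallback, with checksum-verified caching under checkpoints_dir. A minimal sketch of fetching a checkpoint through it (network access is needed on the first call):

from carvekit.utils.download_models import downloader

# Returns the cached checkpoint if its SHA512 checksum matches,
# otherwise downloads it (Carve CDN first, then huggingface.co).
checkpoint_path = downloader("tracer_b7.pth")
print(checkpoint_path)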
carvekit/utils/fs_utils.py
ADDED
@@ -0,0 +1,38 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from PIL import Image
import warnings
from typing import Optional


def save_file(output: Optional[Path], input_path: Path, image: Image.Image):
    """
    Saves an image to the file system.

    Args:
        output: Output path [dir or end file]
        input_path: Input path of the image
        image: Image to be saved.
    """
    if isinstance(output, Path) and str(output) != "none":
        if output.is_dir() and output.exists():
            image.save(output.joinpath(input_path.with_suffix(".png").name))
        elif output.suffix != "":
            if output.suffix != ".png":
                warnings.warn(
                    f"Only export with .png extension is supported! Your {output.suffix}"
                    f" extension will be ignored and replaced with .png!"
                )
            image.save(output.with_suffix(".png"))
        else:
            raise ValueError("Wrong output path!")
    elif output is None or str(output) == "none":
        image.save(
            input_path.with_name(
                input_path.stem.split(".")[0] + "_bg_removed"
            ).with_suffix(".png")
        )
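The three branches of save_file can be summarized with a short sketch; all paths below are hypothetical.

from pathlib import Path
from PIL import Image

from carvekit.utils.fs_utils import save_file

result = Image.new("RGBA", (256, 256))  # stand-in for a processed image

save_file(Path("out"), Path("cat.jpg"), result)         # -> out/cat.png (if "out" is an existing directory)
save_file(Path("result.jpg"), Path("cat.jpg"), result)  # warns about the suffix, saves result.png
save_file(None, Path("cat.jpg"), result)                # -> cat_bg_removed.png next to the input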
carvekit/utils/image_utils.py
ADDED
@@ -0,0 +1,150 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""

import pathlib
from typing import Union, Any, Tuple

import PIL.Image
import numpy as np
import torch

ALLOWED_SUFFIXES = [".jpg", ".jpeg", ".bmp", ".png", ".webp"]


def to_tensor(x: Any) -> torch.Tensor:
    """
    Returns a PIL.Image.Image as a torch tensor without swapping tensor dims.

    Args:
        x: PIL.Image.Image instance

    Returns:
        torch.Tensor instance
    """
    return torch.tensor(np.array(x, copy=True))


def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
    """Returns a PIL.Image.Image instance from a string path, a pathlib path or a PIL.Image.Image instance.

    Args:
        file: File path or PIL.Image.Image instance

    Returns:
        PIL.Image.Image instance

    Raises:
        ValueError: If the file doesn't exist, is a directory, isn't an image or isn't a correct PIL Image

    """
    if isinstance(file, str) and is_image_valid(pathlib.Path(file)):
        return PIL.Image.open(file)
    elif isinstance(file, PIL.Image.Image):
        return file
    elif isinstance(file, pathlib.Path) and is_image_valid(file):
        return PIL.Image.open(str(file))
    else:
        raise ValueError("Unknown input file type")


def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
    """Performs image conversion to the correct color mode.

    Args:
        image: PIL.Image.Image instance
        mode: Color mode to convert to

    Returns:
        PIL.Image.Image instance

    Raises:
        ValueError: If the image has no convertible color mode, or it is too small
    """
    if is_image_valid(image):
        return image.convert(mode)


def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
    """This function performs image validation.

    Args:
        image: Path to the image or PIL.Image.Image instance being checked.

    Returns:
        True if the image is valid

    Raises:
        ValueError: If the file is not a valid image path, the image has no convertible color mode, or it is too small

    """
    if isinstance(image, pathlib.Path):
        if not image.exists():
            raise ValueError("File does not exist")
        elif image.is_dir():
            raise ValueError("File is a directory")
        elif image.suffix.lower() not in ALLOWED_SUFFIXES:
            raise ValueError(
                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
            )
    elif isinstance(image, PIL.Image.Image):
        if not (image.size[0] > 32 and image.size[1] > 32):
            raise ValueError("Image should be bigger than (32x32) pixels.")
        elif image.mode not in ["RGB", "RGBA", "L"]:
            raise ValueError("Wrong image color mode.")
    else:
        raise ValueError("Unknown input file type")
    return True


def transparency_paste(
    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
) -> PIL.Image.Image:
    """
    Inserts an image into another image while maintaining transparency.

    Args:
        bg_img: background image
        fg_img: foreground image
        box: place to paste

    Returns:
        Background image with the foreground image pasted at the point or in the specified box
    """
    fg_img_trans = PIL.Image.new("RGBA", bg_img.size)
    fg_img_trans.paste(fg_img, box, mask=fg_img)
    new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans)
    return new_img


def add_margin(
    pil_img: PIL.Image.Image,
    top: int,
    right: int,
    bottom: int,
    left: int,
    color: Tuple[int, int, int, int],
) -> PIL.Image.Image:
    """
    Adds a margin to the image.

    Args:
        pil_img: Image that needs a margin added.
        top: pixel count at the top side
        right: pixel count at the right side
        bottom: pixel count at the bottom side
        left: pixel count at the left side
        color: color of the margin

    Returns:
        Image with margin.
    """
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    # noinspection PyTypeChecker
    result = PIL.Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result
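A short sketch of the loading/validation helpers together with add_margin; the file name is a placeholder.

from carvekit.utils.image_utils import load_image, convert_image, add_margin

# "photo.jpg" is a hypothetical path with one of the allowed suffixes.
img = convert_image(load_image("photo.jpg"), mode="RGBA")  # validates, opens and converts

# Pad the image by 10 px on every side with opaque black.
padded = add_margin(img, 10, 10, 10, 10, color=(0, 0, 0, 255))
print(padded.size)  # (width + 20, height + 20)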
carvekit/utils/mask_utils.py
ADDED
@@ -0,0 +1,85 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import PIL.Image
import torch
from carvekit.utils.image_utils import to_tensor


def composite(
    foreground: PIL.Image.Image,
    background: PIL.Image.Image,
    alpha: PIL.Image.Image,
    device="cpu",
):
    """
    Composites the foreground with the background following the
    https://pymatting.github.io/intro.html#alpha-matting math formula.

    Args:
        device: Processing device
        foreground: Image that will be pasted onto the background image with the following alpha mask.
        background: Background image
        alpha: Alpha Image

    Returns:
        Composited image as a PIL.Image instance.
    """

    foreground = foreground.convert("RGBA")
    background = background.convert("RGBA")
    alpha_rgba = alpha.convert("RGBA")
    alpha_l = alpha.convert("L")

    fg = to_tensor(foreground).to(device)
    alpha_rgba = to_tensor(alpha_rgba).to(device)
    alpha_l = to_tensor(alpha_l).to(device)
    bg = to_tensor(background).to(device)

    alpha_l = alpha_l / 255
    alpha_rgba = alpha_rgba / 255

    bg = torch.where(torch.logical_not(alpha_rgba >= 1), bg, fg)
    bg[:, :, 0] = alpha_l[:, :] * fg[:, :, 0] + (1 - alpha_l[:, :]) * bg[:, :, 0]
    bg[:, :, 1] = alpha_l[:, :] * fg[:, :, 1] + (1 - alpha_l[:, :]) * bg[:, :, 1]
    bg[:, :, 2] = alpha_l[:, :] * fg[:, :, 2] + (1 - alpha_l[:, :]) * bg[:, :, 2]
    bg[:, :, 3] = alpha_l[:, :] * 255

    del alpha_l, alpha_rgba, fg
    return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA")


def apply_mask(
    image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu"
) -> PIL.Image.Image:
    """
    Applies a mask to the foreground.

    Args:
        device: Processing device.
        image: Image with background.
        mask: Alpha channel mask for this image.

    Returns:
        Image without background, transparent where the mask was black.
    """
    background = PIL.Image.new("RGBA", image.size, color=(130, 130, 130, 0))
    return composite(image, background, mask, device=device).convert("RGBA")


def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Extracts the alpha channel from an RGBA image.

    Args:
        image: RGBA PIL image

    Returns:
        RGBA alpha channel image
    """
    alpha = image.split()[-1]
    bg = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255))
    bg.paste(alpha, mask=alpha)
    return bg.convert("RGBA")
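A minimal sketch of cutting out an object with apply_mask; the file names are placeholders for the photo and the alpha mask predicted for it.

from PIL import Image

from carvekit.utils.mask_utils import apply_mask, extract_alpha_channel

image = Image.open("photo.png")              # placeholder photo
mask = Image.open("alpha.png").convert("L")  # placeholder alpha mask of the same size

cutout = apply_mask(image=image, mask=mask, device="cpu")  # background becomes transparent
alpha = extract_alpha_channel(cutout)                      # recover the alpha channel as an image
cutout.save("cutout.png")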
carvekit/utils/models_utils.py
ADDED
@@ -0,0 +1,126 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""

import random
import warnings
from typing import Union, Tuple, Any

import torch
from torch import autocast


class EmptyAutocast(object):
    """
    Empty context manager for disabling any autocasting.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __call__(self, func):
        return


def get_precision_autocast(
    device="cpu", fp16=True, override_dtype=None
) -> Union[
    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
    Tuple[autocast, Union[torch.dtype, Any]],
]:
    """
    Returns precision and autocast settings for the given device and fp16 settings.

    Args:
        device: Device to get precision and autocast settings for.
        fp16: Whether to use fp16 precision.
        override_dtype: Override dtype for autocast.

    Returns:
        Autocast object, dtype
    """
    dtype = torch.float32
    cache_enabled = None

    if device == "cpu" and fp16:
        warnings.warn("FP16 is not supported on CPU. Using FP32 instead.")
        dtype = torch.float32

        # TODO: Implement BF16 on CPU. There are unexpected slowdowns on CPU in a clean environment.
        # warnings.warn(
        #     "BF16 precision has experimental support on the CPU. "
        #     "This may result in an unexpected reduction in quality."
        # )
        # dtype = (
        #     torch.bfloat16
        # )  # Using bfloat16 for CPU, since autocast is not supported for float16

    if "cuda" in device and fp16:
        dtype = torch.float16
        cache_enabled = True

    if override_dtype is not None:
        dtype = override_dtype

    if dtype == torch.float32 and device == "cpu":
        return EmptyAutocast(), dtype

    return (
        torch.autocast(
            device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled
        ),
        dtype,
    )


def cast_network(network: torch.nn.Module, dtype: torch.dtype):
    """Casts a network to the given dtype.

    Args:
        network: Network to be cast
        dtype: Dtype to cast the network to
    """
    if dtype == torch.float16:
        network.half()
    elif dtype == torch.bfloat16:
        network.bfloat16()
    elif dtype == torch.float32:
        network.float()
    else:
        raise ValueError(f"Unknown dtype {dtype}")


def fix_seed(seed=42):
    """Sets a fixed random seed.

    Args:
        seed: Random seed to be set
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.deterministic = True
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = False
    return True


def suppress_warnings():
    # Suppress the PyTorch 1.11.0 warning associated with the changed order of args in the nn.MaxPool2d layer,
    # since the source code is not affected by this issue and there is no other correct way to hide this message.
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="Note that order of the arguments: ceil_mode and "
        "return_indices will changeto match the args list "
        "in nn.MaxPool2d in a future release.",
        module="torch",
    )
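How the autocast helpers are meant to be combined, sketched on a toy module (the Linear layer is only a stand-in for a real network):

import torch

from carvekit.utils.models_utils import get_precision_autocast, cast_network, fix_seed

fix_seed(42)
net = torch.nn.Linear(4, 2)  # stand-in for a segmentation/matting network

# On CPU with fp16 requested this warns and falls back to float32 with EmptyAutocast;
# on a "cuda" device it would return a torch.autocast context and torch.float16.
autocast_ctx, dtype = get_precision_autocast(device="cpu", fp16=True)
cast_network(net, dtype)

with autocast_ctx:
    out = net(torch.zeros(1, 4))
print(out.dtype)  # torch.float32 on CPU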
carvekit/utils/pool_utils.py
ADDED
@@ -0,0 +1,40 @@
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterable


def thread_pool_processing(func: Any, data: Iterable, workers=18):
    """
    Passes all iterator data through the given function.

    Args:
        workers: Count of workers.
        func: function to pass data through
        data: input iterator

    Returns:
        list of function return values

    """
    with ThreadPoolExecutor(workers) as p:
        return list(p.map(func, data))


def batch_generator(iterable, n=1):
    """
    Splits any indexable iterable into n-size batches.

    Args:
        iterable: iterator
        n: size of batches

    Returns:
        next n-size batch
    """
    it = len(iterable)
    for ndx in range(0, it, n):
        yield iterable[ndx : min(ndx + n, it)]
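These two helpers are typically combined: split the workload into batches, then fan each batch out to a thread pool. A tiny self-contained sketch (the list items are placeholders):

from carvekit.utils.pool_utils import thread_pool_processing, batch_generator

paths = [f"image_{i}.png" for i in range(7)]  # placeholder items

for batch in batch_generator(paths, n=3):
    name_lengths = thread_pool_processing(len, batch, workers=4)
    print(batch, name_lengths)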
carvekit/web/__init__.py
ADDED
File without changes
carvekit/web/app.py
ADDED
@@ -0,0 +1,30 @@
from pathlib import Path

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.staticfiles import StaticFiles

from carvekit import version
from carvekit.web.deps import config
from carvekit.web.routers.api_router import api_router

app = FastAPI(title="CarveKit Web API", version=version)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(api_router, prefix="/api")
app.mount(
    "/",
    StaticFiles(directory=Path(__file__).parent.joinpath("static"), html=True),
    name="static",
)

if __name__ == "__main__":
    uvicorn.run(app, host=config.host, port=config.port)
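Besides the __main__ entry point above, the app can be served programmatically; the host and port below are illustrative values, not the project defaults taken from config:

import uvicorn

from carvekit.web.app import app

uvicorn.run(app, host="127.0.0.1", port=5000)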
carvekit/web/deps.py
ADDED
@@ -0,0 +1,6 @@
from carvekit.web.schemas.config import WebAPIConfig
from carvekit.web.utils.init_utils import init_config
from carvekit.web.utils.task_queue import MLProcessor

config: WebAPIConfig = init_config()
ml_processor = MLProcessor(api_config=config)
carvekit/web/handlers/__init__.py
ADDED
File without changes