Podtekatel commited on
Commit
2599800
2 Parent(s): c7ae016 cbd5cde

Update to V2

Browse files
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_url, cached_download
8
+
9
+ from inference.face_detector import StatRetinaFaceDetector
10
+ from inference.model_pipeline import VSNetModelPipeline
11
+ from inference.onnx_model import ONNXModel
12
+
13
+ logging.basicConfig(
14
+ format='%(asctime)s %(levelname)-8s %(message)s',
15
+ level=logging.INFO,
16
+ datefmt='%Y-%m-%d %H:%M:%S')
17
+
18
+ MODEL_IMG_SIZE = 512
19
+ usage_count = 0 # Based on hugging face logs
20
+ def load_model():
21
+ REPO_ID = "Podtekatel/ArcaneVSK2"
22
+ FILENAME_OLD = "arcane_exp_228_ep_159_512_res_V2.onnx"
23
+
24
+ global model_old
25
+ global pipeline_old
26
+
27
+ # Old model
28
+ model_path = cached_download(
29
+ hf_hub_url(REPO_ID, FILENAME_OLD), use_auth_token=os.getenv('HF_TOKEN')
30
+ )
31
+ model_old = ONNXModel(model_path)
32
+
33
+ pipeline_old = VSNetModelPipeline(model_old, StatRetinaFaceDetector(MODEL_IMG_SIZE), background_resize=1024, no_detected_resize=1024)
34
+
35
+ return model_old
36
+ load_model()
37
+
38
+ def inference(img):
39
+ img = np.array(img)
40
+ out_img = pipeline_old(img)
41
+
42
+ out_img = Image.fromarray(out_img)
43
+ global usage_count
44
+ usage_count += 1
45
+ logging.info(f'Usage count is {usage_count}')
46
+ return out_img
47
+
48
+
49
+ title = "ARCNStyleTransferV2"
50
+ description = "Gradio Demo for Arcane Season 1 style transfer. To use it, simply upload your image, or click one of the examples to load them. Press ❤️ if you like this space!"
51
+ article = "This is one of my successful experiments on style transfer. I've built my own pipeline, generator model and private dataset to train this model<br>" \
52
+ "" \
53
+ "" \
54
+ "" \
55
+ "Model pipeline which used in project is improved CartoonGAN.<br>" \
56
+ "This model was trained on RTX 2080 Ti 3 days with batch size 7.<br>" \
57
+ "Model weights 80 MB in ONNX fp32 format, infers 100 ms on GPU and 600 ms on CPU at 512x512 resolution.<br>" \
58
+ "If you want to use this app or integrate this model into yours, please contact me at email '[email protected]'."
59
+
60
+ imgs_folder = 'demo'
61
+ examples = [[os.path.join(imgs_folder, img_filename)] for img_filename in sorted(os.listdir(imgs_folder))]
62
+
63
+ demo = gr.Interface(
64
+ fn=inference,
65
+ inputs=[gr.inputs.Image(type="pil")],
66
+ outputs=gr.outputs.Image(type="pil"),
67
+ title=title,
68
+ description=description,
69
+ article=article,
70
+ examples=examples)
71
+ demo.queue(concurrency_count=1)
72
+ demo.launch()
demo/gates.png ADDED
demo/jack_sparrow.jpeg ADDED
demo/kianu.jpg ADDED
demo/squid_game.jpeg ADDED
hf_download.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from huggingface_hub import hf_hub_url, cached_download
3
+ import joblib
4
+
5
+ REPO_ID = "MalchuL/JJBAGAN"
6
+ FILENAME = "198_jjba_8_k_2_099_ep.onnx"
7
+
8
+ model = cached_download(
9
+ hf_hub_url(REPO_ID, FILENAME)
10
+ )
11
+ print(model)
12
+
13
+ import onnxruntime
14
+ ort_session = onnxruntime.InferenceSession(str(model))
15
+ input_name = ort_session.get_inputs()[0].name
16
+ ort_inputs = {input_name: np.random.randn(1, 3, 256, 256).astype(dtype=np.float32)}
17
+ ort_outs = ort_session.run(None, ort_inputs)
18
+ print(ort_outs)
inference/__init__.py ADDED
File without changes
inference/box_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def convert_to_square(bboxes):
5
+ """Convert bounding boxes to a square form.
6
+ Arguments:
7
+ bboxes: a float numpy array of shape [n, 4].
8
+ Returns:
9
+ a float numpy array of shape [4],
10
+ squared bounding boxes.
11
+ """
12
+
13
+ square_bboxes = np.zeros_like(bboxes)
14
+ x1, y1, x2, y2 = bboxes
15
+ h = y2 - y1 + 1.0
16
+ w = x2 - x1 + 1.0
17
+ max_side = np.maximum(h, w)
18
+ square_bboxes[0] = x1 + w * 0.5 - max_side * 0.5
19
+ square_bboxes[1] = y1 + h * 0.5 - max_side * 0.5
20
+ square_bboxes[2] = square_bboxes[0] + max_side - 1.0
21
+ square_bboxes[3] = square_bboxes[1] + max_side - 1.0
22
+ return square_bboxes
23
+
24
+
25
+ def scale_box(box, scale):
26
+ x1, y1, x2, y2 = box
27
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
28
+ w, h = x2 - x1, y2 - y1
29
+ new_w, new_h = w * scale, h * scale
30
+ y1, y2, x1, x2 = center_y - new_h / 2, center_y + new_h / 2, center_x - new_w / 2, center_x + new_w / 2,
31
+ return np.array((x1, y1, x2, y2))
inference/center_crop.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ # From albumentations
5
+ def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
6
+ height, width = img.shape[:2]
7
+ if height < crop_height or width < crop_width:
8
+ raise ValueError(
9
+ "Requested crop size ({crop_height}, {crop_width}) is "
10
+ "larger than the image size ({height}, {width})".format(
11
+ crop_height=crop_height, crop_width=crop_width, height=height, width=width
12
+ )
13
+ )
14
+ x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
15
+ img = img[y1:y2, x1:x2]
16
+ return img
17
+
18
+
19
+ def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
20
+ y1 = (height - crop_height) // 2
21
+ y2 = y1 + crop_height
22
+ x1 = (width - crop_width) // 2
23
+ x2 = x1 + crop_width
24
+ return x1, y1, x2, y2
inference/face_detector.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import List
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from retinaface import RetinaFace
8
+ from retinaface.model import retinaface_model
9
+
10
+ from .box_utils import convert_to_square
11
+
12
+
13
+ class FaceDetector(ABC):
14
+ def __init__(self, target_size):
15
+ self.target_size = target_size
16
+ @abstractmethod
17
+ def detect_crops(self, img, *args, **kwargs) -> List[np.ndarray]:
18
+ """
19
+ Img is a numpy ndarray in range [0..255], uint8 dtype, RGB type
20
+ Returns ndarray with [x1, y1, x2, y2] in row
21
+ """
22
+ pass
23
+
24
+ @abstractmethod
25
+ def postprocess_crops(self, crops, *args, **kwargs) -> List[np.ndarray]:
26
+ return crops
27
+
28
+ def sort_faces(self, crops):
29
+ sorted_faces = sorted(crops, key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))
30
+ sorted_faces = np.stack(sorted_faces, axis=0)
31
+ return sorted_faces
32
+
33
+ def fix_range_crops(self, img, crops):
34
+ H, W, _ = img.shape
35
+ final_crops = []
36
+ for crop in crops:
37
+ x1, y1, x2, y2 = crop
38
+ x1 = max(min(round(x1), W), 0)
39
+ y1 = max(min(round(y1), H), 0)
40
+ x2 = max(min(round(x2), W), 0)
41
+ y2 = max(min(round(y2), H), 0)
42
+ new_crop = [x1, y1, x2, y2]
43
+ final_crops.append(new_crop)
44
+ final_crops = np.array(final_crops, dtype=np.int)
45
+ return final_crops
46
+
47
+ def crop_faces(self, img, crops) -> List[np.ndarray]:
48
+ cropped_faces = []
49
+ for crop in crops:
50
+ x1, y1, x2, y2 = crop
51
+ face_crop = img[y1:y2, x1:x2, :]
52
+ cropped_faces.append(face_crop)
53
+ return cropped_faces
54
+
55
+ def unify_and_merge(self, cropped_images):
56
+ return cropped_images
57
+
58
+ def __call__(self, img):
59
+ return self.detect_faces(img)
60
+
61
+ def detect_faces(self, img):
62
+ crops = self.detect_crops(img)
63
+ if crops is None or len(crops) == 0:
64
+ return [], []
65
+ crops = self.sort_faces(crops)
66
+ updated_crops = self.postprocess_crops(crops)
67
+ updated_crops = self.fix_range_crops(img, updated_crops)
68
+ cropped_faces = self.crop_faces(img, updated_crops)
69
+ unified_faces = self.unify_and_merge(cropped_faces)
70
+ return unified_faces, updated_crops
71
+
72
+
73
+ class StatRetinaFaceDetector(FaceDetector):
74
+ def __init__(self, target_size=None):
75
+ super().__init__(target_size)
76
+ self.model = retinaface_model.build_model()
77
+ #self.relative_offsets = [0.3258, 0.5225, 0.3258, 0.1290]
78
+ self.relative_offsets = [0.3619, 0.5830, 0.3619, 0.1909]
79
+
80
+ def postprocess_crops(self, crops, *args, **kwargs) -> np.ndarray:
81
+ final_crops = []
82
+ x1_offset, y1_offset, x2_offset, y2_offset = self.relative_offsets
83
+ for crop in crops:
84
+ x1, y1, x2, y2 = crop
85
+ w, h = x2 - x1, y2 - y1
86
+ x1 -= w * x1_offset
87
+ y1 -= h * y1_offset
88
+ x2 += w * x2_offset
89
+ y2 += h * y2_offset
90
+ crop = np.array([x1, y1, x2, y2], dtype=crop.dtype)
91
+ crop = convert_to_square(crop)
92
+ final_crops.append(crop)
93
+ final_crops = np.stack(final_crops, axis=0)
94
+ return final_crops
95
+
96
+ def detect_crops(self, img, *args, **kwargs):
97
+ faces = RetinaFace.detect_faces(img, model=self.model)
98
+ crops = []
99
+ if isinstance(faces, tuple):
100
+ faces = {}
101
+ for name, face in faces.items():
102
+ x1, y1, x2, y2 = face['facial_area']
103
+ crop = np.array([x1, y1, x2, y2])
104
+ crops.append(crop)
105
+ if len(crops) > 0:
106
+ crops = np.stack(crops, axis=0)
107
+ return crops
108
+
109
+ def unify_and_merge(self, cropped_images):
110
+ if self.target_size is None:
111
+ return cropped_images
112
+ else:
113
+ resized_images = []
114
+ for cropped_image in cropped_images:
115
+ resized_image = cv2.resize(cropped_image, (self.target_size, self.target_size),
116
+ interpolation=cv2.INTER_LINEAR)
117
+ resized_images.append(resized_image)
118
+
119
+ resized_images = np.stack(resized_images, axis=0)
120
+ return resized_images
121
+
inference/model_pipeline.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from .center_crop import center_crop
8
+ from .face_detector import FaceDetector
9
+
10
+
11
+ class VSNetModelPipeline:
12
+ def __init__(self, model, face_detector: FaceDetector, background_resize=720, no_detected_resize=256, use_cloning=True):
13
+ self.background_resize = background_resize
14
+ self.no_detected_resize = no_detected_resize
15
+ self.model = model
16
+ self.face_detector = face_detector
17
+ self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size)
18
+ self.use_cloning = use_cloning
19
+
20
+ @staticmethod
21
+ def create_circular_mask(h, w, power=None, clipping_coef=0.85):
22
+ center = (int(w / 2), int(h / 2))
23
+
24
+ Y, X = np.ogrid[:h, :w]
25
+ dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
26
+ print(dist_from_center.max(), dist_from_center.min())
27
+ clipping_radius = min((h - center[0]), (w - center[1])) * clipping_coef
28
+ max_size = max((h - center[0]), (w - center[1]))
29
+ dist_from_center[dist_from_center < clipping_radius] = clipping_radius
30
+ dist_from_center[dist_from_center > max_size] = max_size
31
+ max_distance, min_distance = np.max(dist_from_center), np.min(dist_from_center)
32
+ dist_from_center = 1 - (dist_from_center - min_distance) / (max_distance - min_distance)
33
+ if power is not None:
34
+ dist_from_center = np.power(dist_from_center, power)
35
+ dist_from_center = np.stack([dist_from_center] * 3, axis=2)
36
+ # mask = dist_from_center <= radius
37
+ return dist_from_center
38
+
39
+
40
+ @staticmethod
41
+ def resize_size(image, size=720, always_apply=True):
42
+ h, w, c = np.shape(image)
43
+ if min(h, w) > size or always_apply:
44
+ if h < w:
45
+ h, w = int(size * h / w), size
46
+ else:
47
+ h, w = size, int(size * w / h)
48
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
49
+ return image
50
+
51
+ def normalize(self, img):
52
+ img = img.astype(np.float32) / 255 * 2 - 1
53
+ return img
54
+
55
+ def denormalize(self, img):
56
+ return (img + 1) / 2
57
+
58
+ def divide_crop(self, img, must_divided=32):
59
+ h, w, _ = img.shape
60
+ h = h // must_divided * must_divided
61
+ w = w // must_divided * must_divided
62
+
63
+ img = center_crop(img, h, w)
64
+ return img
65
+
66
+ def merge_crops(self, faces_imgs, crops, full_image):
67
+ for face, crop in zip(faces_imgs, crops):
68
+ x1, y1, x2, y2 = crop
69
+ W, H = x2 - x1, y2 - y1
70
+ result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
71
+ face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)
72
+ if self.use_cloning:
73
+ center = round((x2 + x1) / 2), round((y2 + y1) / 2)
74
+ full_image = cv2.seamlessClone(result_face, full_image, (face_mask > 0.0).astype(np.uint8) * 255, center, cv2.NORMAL_CLONE)
75
+ else:
76
+ input_face = full_image[y1: y2, x1: x2]
77
+ full_image[y1: y2, x1: x2] = (result_face * face_mask + input_face * (1 - face_mask)).astype(np.uint8)
78
+ return full_image
79
+
80
+ def __call__(self, img):
81
+ return self.process_image(img)
82
+
83
+ def process_image(self, img):
84
+ img = self.resize_size(img, size=self.background_resize)
85
+ img = self.divide_crop(img)
86
+
87
+ face_crops, coords = self.face_detector(img)
88
+
89
+ if len(face_crops) > 0:
90
+ start_time = time.time()
91
+ faces = self.normalize(face_crops)
92
+ faces = faces.transpose(0, 3, 1, 2)
93
+ out_faces = self.model(faces)
94
+ out_faces = self.denormalize(out_faces)
95
+ out_faces = out_faces.transpose(0, 2, 3, 1)
96
+ out_faces = np.clip(out_faces * 255, 0, 255).astype(np.uint8)
97
+ end_time = time.time()
98
+ logging.info(f'Face FPS {1 / (end_time - start_time)}')
99
+ else:
100
+ out_faces = []
101
+ img = self.resize_size(img, size=self.no_detected_resize)
102
+ img = self.divide_crop(img)
103
+
104
+ start_time = time.time()
105
+ full_image = self.normalize(img)
106
+ full_image = np.expand_dims(full_image, 0).transpose(0, 3, 1, 2)
107
+ full_image = self.model(full_image)
108
+ full_image = self.denormalize(full_image)
109
+ full_image = full_image.transpose(0, 2, 3, 1)
110
+ full_image = np.clip(full_image * 255, 0, 255).astype(np.uint8)
111
+ end_time = time.time()
112
+ logging.info(f'Background FPS {1 / (end_time - start_time)}')
113
+
114
+ result_image = self.merge_crops(out_faces, coords, full_image[0])
115
+ return result_image
inference/onnx_model.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import onnxruntime
3
+
4
+
5
+ class ONNXModel:
6
+ def __init__(self, onnx_mode_path):
7
+ self.path = onnx_mode_path
8
+ self.ort_session = onnxruntime.InferenceSession(str(self.path))
9
+ self.input_name = self.ort_session.get_inputs()[0].name
10
+
11
+ def __call__(self, img):
12
+ ort_inputs = {self.input_name: img.astype(dtype=np.float32)}
13
+ ort_outs = self.ort_session.run(None, ort_inputs)[0]
14
+ return ort_outs
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ huggingface_hub
2
+ onnxruntime
3
+ numpy
4
+ gradio
5
+ retina-face