jjeamin commited on
Commit
146d3ac
1 Parent(s): 140cd2a
Files changed (2) hide show
  1. app.py +115 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+ os.system("pip -qq install facenet_pytorch")
4
+ from facenet_pytorch import MTCNN
5
+ from torchvision import transforms
6
+ import torch, PIL
7
+ from tqdm.notebook import tqdm
8
+ import gradio as gr
9
+ import torch
10
+
11
+ image_size = 512
12
+
13
+ means = [0.5, 0.5, 0.5]
14
+ stds = [0.5, 0.5, 0.5]
15
+
16
+ model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")
17
+ style_transfer = torch.jit.load(model_path).eval().cuda().half()
18
+
19
+ mtcnn = MTCNN(image_size=image_size, margin=80)
20
+
21
+ def detect(img):
22
+
23
+ # Detect faces
24
+ batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)
25
+ # Select faces
26
+ if not mtcnn.keep_all:
27
+ batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
28
+ batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
29
+ )
30
+
31
+ return batch_boxes, batch_points
32
+
33
+ def makeEven(_x):
34
+ return _x if (_x % 2 == 0) else _x+1
35
+
36
+ def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):
37
+
38
+ x, y = _img.size
39
+
40
+ ratio = 2 #initial ratio
41
+
42
+ #scale to desired face size
43
+ if (boxes is not None):
44
+ if len(boxes)>0:
45
+ ratio = target_face/max(boxes[0][2:]-boxes[0][:2]);
46
+ ratio = min(ratio, max_upscale)
47
+ if VERBOSE: print('up by', ratio)
48
+
49
+ if fixed_ratio>0:
50
+ if VERBOSE: print('fixed ratio')
51
+ ratio = fixed_ratio
52
+
53
+ x*=ratio
54
+ y*=ratio
55
+
56
+ #downscale to fit into max res
57
+ res = x*y
58
+ if res > max_res:
59
+ ratio = pow(res/max_res,1/2);
60
+ if VERBOSE: print(ratio)
61
+ x=int(x/ratio)
62
+ y=int(y/ratio)
63
+
64
+ #make dimensions even, because usually NNs fail on uneven dimensions due skip connection size mismatch
65
+ x = makeEven(int(x))
66
+ y = makeEven(int(y))
67
+
68
+ size = (x, y)
69
+
70
+ return _img.resize(size)
71
+
72
+ def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):
73
+ boxes = None
74
+ boxes, _ = detect(_img)
75
+ if VERBOSE: print('boxes',boxes)
76
+ img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)
77
+ return img_resized
78
+
79
+ t_stds = torch.tensor(stds).cuda().half()[:,None,None]
80
+ t_means = torch.tensor(means).cuda().half()[:,None,None]
81
+
82
+ img_transforms = transforms.Compose([
83
+ transforms.ToTensor(),
84
+ transforms.Normalize(means, stds)])
85
+
86
+ def tensor2im(var):
87
+ return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
88
+
89
+ def proc_pil_img(input_image):
90
+ transformed_image = img_transforms(input_image)[None,...].cuda().half()
91
+
92
+ with torch.no_grad():
93
+ result_image = style_transfer(transformed_image)[0]
94
+ output_image = tensor2im(result_image)
95
+ output_image = output_image.detach().cpu().numpy().astype('uint8')
96
+ output_image = PIL.Image.fromarray(output_image)
97
+ return output_image
98
+
99
+ def process(im):
100
+ im = scale_by_face_size(im, target_face=image_size, max_res=1_500_000, max_upscale=1)
101
+ res = proc_pil_img(im)
102
+ return res
103
+
104
+ gr.Interface(
105
+ process,
106
+ inputs=gr.inputs.Image(type="pil", label="Input", shape=(image_size, image_size)),
107
+ outputs=gr.outputs.Image(type="pil", label="Output"),
108
+ title="Arcane Style Transfer",
109
+ description="Gradio demo for Arcane Style Transfer",
110
+ article = "<p style='text-align: center'><a href='https://github.com/jjeamin/anime_style_transfer_pytorch' target='_blank'>Github Repo Pytorch by jjeamin</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=jjeamin_arcane_st' alt='visitor badge'></center></p>",
111
+ examples=[['billie.png'], ['gongyoo.jpeg'], ['tony.png'], ['will.png']],
112
+ enable_queue=True,
113
+ allow_flagging=False,
114
+ allow_screenshot=False
115
+ ).launch(enable_queue=True,cache_examples=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ gdown
5
+ numpy
6
+ scipy
7
+ cmake
8
+ onnxruntime-gpu
9
+ opencv-python-headless