import os

from huggingface_hub import hf_hub_download

# facenet_pytorch is not preinstalled, so install it at startup (a common Spaces hack).
os.system("pip -qq install facenet_pytorch")
from facenet_pytorch import MTCNN

import torch
import PIL.Image
from torchvision import transforms
import gradio as gr

# Use the GPU (with fp16) when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

image_size = 512

# Normalization constants that map image tensors into [-1, 1].
means = [0.5, 0.5, 0.5]
stds = [0.5, 0.5, 0.5]

# Download the TorchScript style-transfer model from the Hugging Face Hub.
model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")
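
# Note: hf_hub_download caches the file locally (by default under
# ~/.cache/huggingface) and returns the cached path, so later launches
# skip the download.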

# Load the scripted generator; reshape the normalization constants to
# (C, 1, 1) so they broadcast over CHW image tensors.
if 'cuda' in device:
    style_transfer = torch.jit.load(model_path).eval().cuda().half()
    t_stds = torch.tensor(stds).cuda().half()[:, None, None]
    t_means = torch.tensor(means).cuda().half()[:, None, None]
else:
    style_transfer = torch.jit.load(model_path).eval().cpu()
    t_stds = torch.tensor(stds).cpu()[:, None, None]
    t_means = torch.tensor(means).cpu()[:, None, None]

# MTCNN face detector, used only to locate faces for resizing.
mtcnn = MTCNN(image_size=image_size, margin=80)
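
# image_size/margin configure MTCNN's face *extraction*; detect() below only
# uses the raw bounding boxes and landmarks, not the cropped faces.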

def detect(img):
    # Detect faces (bounding boxes, probabilities, landmark points)
    batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)
    # Select face(s) according to the detector's selection method
    if not mtcnn.keep_all:
        batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
            batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
        )
    return batch_boxes, batch_points
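
# Example (hypothetical, for local testing; the app calls detect() indirectly
# through scale_by_face_size):
#   img = PIL.Image.open('billie.png').convert('RGB')
#   boxes, points = detect(img)
#   # boxes: one [x1, y1, x2, y2] array per face, or None if nothing is found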

def makeEven(_x):
    # Round up to the nearest even number, e.g. makeEven(511) -> 512.
    return _x if (_x % 2 == 0) else _x + 1

def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):
    x, y = _img.size

    ratio = 2  # default upscale ratio when no face is detected

    # Scale so the first detected face reaches the target size.
    if boxes is not None and len(boxes) > 0:
        ratio = target_face / max(boxes[0][2:] - boxes[0][:2])
        ratio = min(ratio, max_upscale)
        if VERBOSE: print('up by', ratio)

    if fixed_ratio > 0:
        if VERBOSE: print('fixed ratio')
        ratio = fixed_ratio

    x *= ratio
    y *= ratio

    # Downscale to fit within the maximum resolution.
    res = x * y
    if res > max_res:
        ratio = pow(res / max_res, 1 / 2)
        if VERBOSE: print(ratio)
        x = int(x / ratio)
        y = int(y / ratio)

    # Make dimensions even: NNs often fail on odd sizes due to
    # skip-connection size mismatches.
    x = makeEven(int(x))
    y = makeEven(int(y))

    return _img.resize((x, y))
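
# Worked example of the scaling above: a 200 px face with target_face=256
# gives ratio = 256/200 = 1.28, capped at max_upscale. Independently, a
# 2000x1500 result (3.0 MP, over the 1.5 MP budget) is shrunk by
# sqrt(3.0/1.5) ~= 1.414 to roughly 1414x1060, then rounded even via makeEven.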

def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):
    boxes, _ = detect(_img)
    if VERBOSE: print('boxes', boxes)
    return scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)


img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds)])

def tensor2im(var):
    return var.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)
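
# tensor2im inverts the Normalize(means, stds) in img_transforms: the model
# sees (x - mean) / std, so outputs are mapped back with y * std + mean before
# scaling to [0, 255] and reordering CHW -> HWC for PIL.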

def proc_pil_img(input_image):
    # Build a 1-image batch matching the model's device and precision.
    if 'cuda' in device:
        transformed_image = img_transforms(input_image)[None, ...].cuda().half()
    else:
        transformed_image = img_transforms(input_image)[None, ...].cpu()

    with torch.no_grad():
        result_image = style_transfer(transformed_image)[0]
        output_image = tensor2im(result_image)
        output_image = output_image.detach().cpu().numpy().astype('uint8')
        output_image = PIL.Image.fromarray(output_image)
    return output_image
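
# Example (hypothetical local test, not executed at startup):
#   stylized = proc_pil_img(PIL.Image.open('elon.png').convert('RGB'))
#   stylized.save('elon_arcane.png')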

def process(im):
    # Resize around the detected face, then stylize.
    im = scale_by_face_size(im, target_face=image_size, max_res=1_500_000, max_upscale=1)
    return proc_pil_img(im)
# Legacy Gradio (2.x-era) Interface API, as used by the original Space.
gr.Interface(
    process,
    inputs=gr.inputs.Image(type="pil", label="Input", shape=(image_size, image_size)),
    outputs=gr.outputs.Image(type="pil", label="Output"),
    title="Arcane Style Transfer",
    description="Gradio demo for Arcane Style Transfer",
    article="<p style='text-align: center'><a href='https://github.com/jjeamin/anime_style_transfer_pytorch' target='_blank'>Github Repo Pytorch by jjeamin</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=jjeamin_arcane_st' alt='visitor badge'></center>",
    examples=[['billie.png'], ['gongyoo.jpeg'], ['IU.png'], ['elon.png']],
    allow_flagging=False,
    allow_screenshot=False,
).launch(enable_queue=True, cache_examples=True)
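
# To run locally (assuming this file is saved as app.py and the example
# images sit next to it):
#   python app.py
# then open the local URL that Gradio prints (typically http://127.0.0.1:7860).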