File size: 4,607 Bytes
4158574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5842b4a
4158574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d8d63c
4158574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from pathlib import Path
from rembg import remove
import io

# Apply the transformations needed
from torch import autocast, nn
import torch
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.nn as nn
import pyrootutils
from PIL import Image
import numpy as np
from utils.photo_wct import PhotoWCT
from utils.photo_smooth import Propagator

# Load the models once at import time so every call below reuses them.
# NOTE: `root` is the current working directory, so the checkpoint path below is
# resolved relative to wherever the process was started, not to this file.
root = Path.cwd()
device = "cuda" if torch.cuda.is_available() else "cpu"
# Stylization network. map_location="cpu" lets a checkpoint that was saved from
# GPU tensors load on CPU-only hosts. The freshly constructed module has not been
# moved to any device yet, and load_state_dict copies the values into its own
# parameters, so the resulting module is the same either way.
p_wct = PhotoWCT()
p_wct.load_state_dict(
    torch.load(root / "models/components/photo_wct.pth", map_location="cpu")
)
p_pro = Propagator()
# Module-level aliases used by style_transfer() and smoother().
stylization_module = p_wct
smoothing_module = p_pro


# Dependencies - to be installed:
#   pip install replicate
# Authentication: the Replicate API token must be supplied through an
# environment variable. Never commit the token itself to source control.
#   export REPLICATE_API_TOKEN=<your-token>
import replicate
import os
import requests



def stableDiffusionAPICall(text_prompt, api_token=None):
    """Generate an image from a text prompt with Stable Diffusion on Replicate.

    Args:
        text_prompt: Prompt passed to the model,
            e.g. ``"photorealistic, elf fighting Sauron"``.
        api_token: Optional Replicate API token. When given, it is exported as
            ``REPLICATE_API_TOKEN`` before the replicate client is used.
            Otherwise that variable must already be set in the environment.

    Returns:
        The first generated image, decoded as a ``PIL.Image.Image``.

    Raises:
        RuntimeError: If no Replicate API token is available.
        requests.HTTPError: If downloading the generated image fails.
    """
    # Credentials come from the caller or the environment, never from source code.
    if api_token:
        os.environ["REPLICATE_API_TOKEN"] = api_token
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it or pass api_token=..."
        )

    model = replicate.models.get("stability-ai/stable-diffusion")
    # predict() yields the generated image URL(s); only the first one is used.
    gen_bg_img = model.predict(prompt=text_prompt)[0]

    # Bound the download time, and fail loudly instead of trying to decode an
    # HTTP error page as an image.
    response = requests.get(gen_bg_img, timeout=60)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))



def memory_limit_image_resize(cont_img, max_size=800):
    """Shrink *cont_img* in place when both of its sides exceed ``max_size``.

    If the shorter side is larger than ``max_size``, the image is resized with
    ``Image.thumbnail``. That call mutates ``cont_img`` and preserves the aspect
    ratio, and here it brings the longer side down to ``max_size``. Images with
    at least one side at or below ``max_size`` are left untouched.

    Images are never enlarged. ``Image.thumbnail`` cannot upscale, so an earlier
    attempt to grow images smaller than 400 px was a silent no-op and has been
    dropped. Enlarging would require returning a new image, which callers that
    rely on the in-place update do not expect.

    Args:
        cont_img: PIL image, modified in place.
        max_size: Size cap in pixels (default 800).

    Returns:
        ``(width, height)`` of the image after any resizing.
    """
    orig_width, orig_height = cont_img.width, cont_img.height
    if min(orig_width, orig_height) > max_size:
        # Scale so the longer side becomes max_size, keeping the aspect ratio.
        if orig_width > orig_height:
            target = (max_size, int(orig_height * 1.0 / orig_width * max_size))
        else:
            target = (int(orig_width * 1.0 / orig_height * max_size), max_size)
        cont_img.thumbnail(target, Image.BICUBIC)
    print("Resize image: (%d,%d)->(%d,%d)" % (orig_width, orig_height, cont_img.width, cont_img.height))
    return cont_img.width, cont_img.height





def superimpose(input_img, back_img):
    """Cut the subject out of *input_img* and paste it onto *back_img*.

    The background is removed with ``rembg.remove``. The cutout is pasted at
    the top-left corner, using the cutout itself as the paste mask so that only
    its non-transparent pixels land on the background.

    Side effect: ``back_img`` is modified in place.

    Args:
        input_img: Foreground PIL image containing the subject.
        back_img: Background PIL image that receives the cutout.

    Returns:
        A tuple ``(back_img, input_img)``: the composited background (the
        same object that was passed in) and the untouched foreground.
    """
    cutout = remove(input_img)
    back_img.paste(cutout, (0, 0), mask=cutout)
    return back_img, input_img



def style_transfer(cont_img, styl_img):
    """Restyle *cont_img* with the appearance of *styl_img* using PhotoWCT.

    Runs the module-level ``stylization_module`` without segmentation masks.

    Side effect: both input images are first passed through
    ``memory_limit_image_resize``, which may shrink them *in place*.

    Args:
        cont_img: Content PIL image.
        styl_img: Style PIL image.

    Returns:
        The stylized result as a ``PIL.Image.Image`` with the (possibly
        resized) dimensions of ``cont_img``.
    """
    with torch.no_grad():
        # Both calls may resize their argument in place. The style image's new
        # size is not needed afterwards.
        new_cw, new_ch = memory_limit_image_resize(cont_img)
        memory_limit_image_resize(styl_img)
        # NOTE(review): measured *after* the in-place resize, so this always
        # equals (new_cw, new_ch) and the restore-size branch below never runs.
        # Left as-is because callers pair the output with the resized content image.
        cw, ch = cont_img.width, cont_img.height

        to_tensor = transforms.ToTensor()
        # Coerce to 3-channel RGB so RGBA/palette/greyscale inputs do not reach
        # the network with the wrong channel count. The encoder presumably
        # expects RGB (TODO confirm against PhotoWCT). For RGB inputs this is
        # only a copy.
        cont_tensor = to_tensor(cont_img.convert("RGB")).unsqueeze(0)
        styl_tensor = to_tensor(styl_img.convert("RGB")).unsqueeze(0)

        # No segmentation masks are supplied; empty arrays presumably mean
        # whole-image transfer in PhotoWCT (confirm).
        cont_seg = np.asarray([])
        styl_seg = np.asarray([])

        if device == "cuda":
            cont_tensor = cont_tensor.to(device)
            styl_tensor = styl_tensor.to(device)
            stylization_module.to(device)

        stylized = stylization_module.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)
        if ch != new_ch or cw != new_cw:
            # interpolate() replaces the deprecated upsample(). For bilinear
            # mode, align_corners=False is the default upsample used implicitly.
            stylized = nn.functional.interpolate(
                stylized, size=(ch, cw), mode="bilinear", align_corners=False
            )
        grid = utils.make_grid(stylized.data, nrow=1, padding=0)
        # Convert the (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(ndarr)

def smoother(stylized_img, over_img):
    """Post-process *stylized_img* with the module-level smoothing propagator.

    Args:
        stylized_img: Image to smooth, e.g. the output of ``style_transfer``.
        over_img: Reference image passed as the second argument to
            ``Propagator.process``. It is presumably the content image the
            stylization was derived from; confirm with the caller.

    Returns:
        Whatever ``smoothing_module.process`` returns.
    """
    return smoothing_module.process(stylized_img, over_img)


if __name__ == "__main__":
    # Resolve the project root (and put it on sys.path) so the paths below work
    # regardless of the current working directory.
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    fg_path = root / "notebooks/profile_new.png"
    bg_path = root / "notebooks/back_img.png"

    # To generate a background instead of loading one:
    #   bg_img = stableDiffusionAPICall("Photorealistic scenery of a concert")
    fg_img = Image.open(fg_path).resize((800, 800))
    bg_img = Image.open(bg_path).resize((800, 800))
    # superimpose() returns (composited background, original foreground).
    # Calling .save() on the whole tuple raised AttributeError, so unpack it
    # and keep only the composite.
    img, _ = superimpose(fg_img, bg_img)
    img.save(root / "notebooks/overlay.png")