Spaces:
Build error
Build error
from pathlib import Path | |
from rembg import remove | |
import io | |
# Apply the transformations needed | |
from torch import autocast, nn | |
import torch | |
import torch.nn as nn | |
import torch | |
import torchvision.transforms as transforms | |
import torchvision.utils as utils | |
import torch.nn as nn | |
import pyrootutils | |
from PIL import Image | |
import numpy as np | |
from utils.photo_wct import PhotoWCT | |
from utils.photo_smooth import Propagator | |
#from utils.smooth_filter import smooth_filter | |
# Load models.
# Checkpoint paths are resolved relative to the current working directory.
root = Path.cwd()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stylization network. map_location lets a checkpoint that was saved from a
# GPU be loaded on a CPU-only host instead of failing during deserialization.
p_wct = PhotoWCT()
p_wct.load_state_dict(
    torch.load(root / "models/components/photo_wct.pth", map_location=device)
)

# Smoothing module; smoother() calls its process() method.
p_pro = Propagator()

# Module-level aliases read by style_transfer() and smoother().
stylization_module = p_wct
smoothing_module = p_pro
# Dependencies - to be installed:
#   !pip install replicate
# Token - required to authenticate with Replicate:
#   API token: <redacted - never commit a real token; keep it in the environment>
# To be declared as an environment variable:
#   export REPLICATE_API_TOKEN=<your-token>
import replicate | |
import os | |
import requests | |
def stableDiffusionAPICall(text_prompt):
    """Generate an image from a text prompt with Replicate's Stable Diffusion model.

    The Replicate API token is expected in the ``REPLICATE_API_TOKEN``
    environment variable; it is no longer hard-coded in this function.

    Args:
        text_prompt: Text passed to the model as its ``prompt`` argument.

    Returns:
        The first image produced by the model, opened as a PIL image.

    Raises:
        RuntimeError: If ``REPLICATE_API_TOKEN`` is not set.
        requests.HTTPError: If downloading the generated image fails.
    """
    # Credentials must come from the environment: a token committed to the
    # source is leaked to everyone who can read the file.
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it before calling "
            "stableDiffusionAPICall()."
        )

    model = replicate.models.get("stability-ai/stable-diffusion")
    # The first element of the prediction output is used as the image URL.
    image_url = model.predict(prompt=text_prompt)[0]

    # Bound the download time and fail loudly on HTTP errors rather than
    # handing an error page to the image decoder.
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()

    return Image.open(io.BytesIO(response.content))
def memory_limit_image_resize(cont_img):
    """Resize ``cont_img`` in place when it falls outside the size limits.

    The image is modified through its own ``thumbnail`` method, so the caller's
    object changes; nothing is copied. A line describing the size before and
    after is printed on every call.

    Args:
        cont_img: Image object exposing ``width``, ``height`` and
            ``thumbnail(size, resample)``.

    Returns:
        ``(width, height)`` of ``cont_img`` after the call.
    """
    min_side = 400
    max_side = 800
    width_before, height_before = cont_img.width, cont_img.height

    # Largest side below the minimum: request a size whose shorter side is
    # min_side, keeping the aspect ratio.
    # NOTE(review): PIL's thumbnail() is documented to only shrink, so this
    # branch presumably leaves the image unchanged - confirm.
    if max(cont_img.width, cont_img.height) < min_side:
        if cont_img.width > cont_img.height:
            target = (int(cont_img.width * 1.0 / cont_img.height * min_side), min_side)
        else:
            target = (min_side, int(cont_img.height * 1.0 / cont_img.width * min_side))
        cont_img.thumbnail(target, Image.BICUBIC)

    # Smallest side above the maximum: request a size whose longer side is
    # max_side, keeping the aspect ratio.
    if min(cont_img.width, cont_img.height) > max_side:
        if cont_img.width > cont_img.height:
            target = (max_side, int(cont_img.height * 1.0 / cont_img.width * max_side))
        else:
            target = (int(cont_img.width * 1.0 / cont_img.height * max_side), max_side)
        cont_img.thumbnail(target, Image.BICUBIC)

    print(
        "Resize image: (%d,%d)->(%d,%d)"
        % (width_before, height_before, cont_img.width, cont_img.height)
    )
    return cont_img.width, cont_img.height
def superimpose(input_img, back_img):
    """Paste the rembg ``remove`` output of ``input_img`` onto ``back_img``.

    ``back_img`` is modified in place: the result of ``remove(input_img)`` is
    pasted at the top-left corner and is also used as the paste mask.

    Args:
        input_img: Foreground image handed to ``remove``.
        back_img: Background image that receives the paste.

    Returns:
        A ``(back_img, input_img)`` tuple: the modified background and the
        foreground as it was passed in.
    """
    cutout = remove(input_img)
    top_left = (0, 0)
    # The cut-out is passed twice: once as the pixels, once as the mask.
    back_img.paste(cutout, top_left, cutout)
    return back_img, input_img
def style_transfer(cont_img, styl_img):
    """Stylize ``cont_img`` with ``styl_img`` using the module-level stylization module.

    Both images are first passed to memory_limit_image_resize(), which may
    resize them in place. They are then converted to batched tensors and fed
    to ``stylization_module.transform`` with empty segmentation arrays. The
    whole computation runs under ``torch.no_grad()``.

    Args:
        cont_img: Content image (PIL image); may be resized in place.
        styl_img: Style image (PIL image); may be resized in place.

    Returns:
        The stylized result as a PIL image.
    """
    with torch.no_grad():
        new_cw, new_ch = memory_limit_image_resize(cont_img)
        memory_limit_image_resize(styl_img)
        # NOTE(review): cw/ch are read after the in-place resize, so they equal
        # new_cw/new_ch and the resize-back branch below cannot fire as
        # written. Confirm whether restoring the pre-resize size was intended.
        cw = cont_img.width
        ch = cont_img.height

        cont_tensor = transforms.ToTensor()(cont_img).unsqueeze(0)
        styl_tensor = transforms.ToTensor()(styl_img).unsqueeze(0)

        if device == 'cuda':
            cont_tensor = cont_tensor.to(device)
            styl_tensor = styl_tensor.to(device)
            stylization_module.to(device)

        # No segmentation maps are supplied; empty arrays are passed instead.
        cont_seg = np.asarray([])
        styl_seg = np.asarray([])
        stylized = stylization_module.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)

        if ch != new_ch or cw != new_cw:
            # interpolate() replaces the deprecated nn.functional.upsample().
            stylized = nn.functional.interpolate(stylized, size=(ch, cw), mode='bilinear')

        # Convert the batch to a single HxWxC uint8 array, then to a PIL image.
        grid = utils.make_grid(stylized.data, nrow=1, padding=0)
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(ndarr)
def smoother(stylized_img, over_img):
    """Run the module-level smoothing module on a stylized image.

    Args:
        stylized_img: First argument forwarded to ``smoothing_module.process``.
        over_img: Second argument forwarded to ``smoothing_module.process``.

    Returns:
        Whatever ``smoothing_module.process`` returns for the two images.
    """
    return smoothing_module.process(stylized_img, over_img)
if __name__ == "__main__":
    # Manual smoke run: composite a foreground picture onto a background and
    # write the result under notebooks/.
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    fg_path = root / "notebooks/profile_new.png"
    bg_path = root / "notebooks/back_img.png"
    # Not used by the current flow; kept for the MODNet-based variant.
    ckpt_path = root / "src/models/MODNet/pretrained/modnet_photographic_portrait_matting.ckpt"

    # Both images are forced to the same 800x800 size before compositing.
    fg_img = Image.open(fg_path).resize((800, 800))
    bg_img = Image.open(bg_path).resize((800, 800))

    # superimpose() returns (composited background, original foreground);
    # unpack it so save() is called on the image rather than on the tuple.
    overlay_img, _ = superimpose(fg_img, bg_img)
    overlay_img.save(root / "notebooks/overlay.png")