Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- app.py +28 -30
- inference.py +3 -5
- requirements.txt +0 -1
- utils.py +2 -5
app.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
from datetime import datetime, timezone, timedelta
|
4 |
-
from concurrent.futures import ThreadPoolExecutor
|
5 |
|
6 |
import spaces
|
7 |
import torch
|
8 |
import torch.optim as optim
|
9 |
-
import torchvision.models as models
|
10 |
import numpy as np
|
11 |
import gradio as gr
|
12 |
from safetensors.torch import load_file
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
|
15 |
-
from utils import preprocess_img,
|
16 |
from vgg.vgg19 import VGG_19
|
17 |
from u2net.model import U2Net
|
18 |
from inference import inference
|
@@ -20,8 +18,8 @@ from inference import inference
|
|
20 |
if torch.cuda.is_available(): device = 'cuda'
|
21 |
elif torch.backends.mps.is_available(): device = 'mps'
|
22 |
else: device = 'cpu'
|
23 |
-
print('
|
24 |
-
if device == 'cuda': print('
|
25 |
|
26 |
def load_model_without_module(model, model_path):
|
27 |
state_dict = load_file(model_path, device=device)
|
@@ -40,41 +38,42 @@ local_model_path = hf_hub_download(repo_id='jamino30/u2net-saliency', filename='
|
|
40 |
load_model_without_module(sod_model, local_model_path)
|
41 |
|
42 |
style_files = os.listdir('./style_images')
|
43 |
-
style_options = {
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
img_size = 512
|
46 |
|
47 |
-
# store style(s) features
|
48 |
cached_style_features = {}
|
49 |
for style_name, style_img_path in style_options.items():
|
50 |
-
style_img =
|
51 |
with torch.no_grad():
|
52 |
style_features = model(style_img)
|
53 |
cached_style_features[style_name] = style_features
|
54 |
|
55 |
@spaces.GPU(duration=30)
|
56 |
-
def run(content_image, style_name, style_strength=
|
57 |
yield None
|
58 |
content_img, original_size = preprocess_img(content_image, img_size)
|
59 |
content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
|
60 |
content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
|
61 |
style_features = cached_style_features[style_name]
|
62 |
|
63 |
-
if optim_name == '
|
64 |
-
optim_caller =
|
65 |
-
iterations = 101
|
66 |
-
elif optim_name == 'AdamW':
|
67 |
-
optim_caller = torch.optim.AdamW
|
68 |
-
iterations = 101
|
69 |
elif optim_name == 'L-BFGS':
|
70 |
-
optim_caller =
|
71 |
-
iterations = 20
|
72 |
|
73 |
-
print('-'*
|
74 |
-
print(
|
75 |
-
print('STYLE:', style_name)
|
76 |
-
print('CONTENT IMG SIZE:', original_size)
|
77 |
-
print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})')
|
78 |
|
79 |
st = time.time()
|
80 |
generated_img = inference(
|
@@ -85,11 +84,10 @@ def run(content_image, style_name, style_strength=10, optim_name='AdamW', apply_
|
|
85 |
style_features=style_features,
|
86 |
lr=lrs[style_strength-1],
|
87 |
apply_to_background=apply_to_background,
|
88 |
-
iterations=iterations,
|
89 |
optim_caller=optim_caller,
|
90 |
)
|
91 |
et = time.time()
|
92 |
-
print('
|
93 |
|
94 |
yield postprocess_img(generated_img, original_size)
|
95 |
|
@@ -107,14 +105,14 @@ with gr.Blocks(css=css) as demo:
|
|
107 |
gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation")
|
108 |
with gr.Row(elem_id='container'):
|
109 |
with gr.Column():
|
110 |
-
content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
|
111 |
-
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
|
112 |
with gr.Group():
|
113 |
-
|
114 |
with gr.Group():
|
115 |
-
|
|
|
|
|
116 |
with gr.Accordion(label='Advanced Options', open=False):
|
117 |
-
optim_dropdown = gr.Radio(choices=['
|
118 |
submit_button = gr.Button('Submit', variant='primary')
|
119 |
|
120 |
examples = gr.Examples(
|
|
|
1 |
import os
|
2 |
import time
|
3 |
from datetime import datetime, timezone, timedelta
|
|
|
4 |
|
5 |
import spaces
|
6 |
import torch
|
7 |
import torch.optim as optim
|
|
|
8 |
import numpy as np
|
9 |
import gradio as gr
|
10 |
from safetensors.torch import load_file
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
+
from utils import preprocess_img, postprocess_img
|
14 |
from vgg.vgg19 import VGG_19
|
15 |
from u2net.model import U2Net
|
16 |
from inference import inference
|
|
|
18 |
if torch.cuda.is_available(): device = 'cuda'
|
19 |
elif torch.backends.mps.is_available(): device = 'mps'
|
20 |
else: device = 'cpu'
|
21 |
+
print('Device:', device)
|
22 |
+
if device == 'cuda': print('Name:', torch.cuda.get_device_name())
|
23 |
|
24 |
def load_model_without_module(model, model_path):
|
25 |
state_dict = load_file(model_path, device=device)
|
|
|
38 |
load_model_without_module(sod_model, local_model_path)
|
39 |
|
40 |
style_files = os.listdir('./style_images')
|
41 |
+
style_options = {
|
42 |
+
'Starry Night': './style_images/Starry_Night.jpg',
|
43 |
+
'Starry Night (v2)': './style_images/Starry_Night_v2.jpg',
|
44 |
+
'Scream': './style_images/Scream.jpg',
|
45 |
+
'Great Wave': './style_images/Great_Wave.jpg',
|
46 |
+
'Oil Painting': './style_images/Oil_Painting.jpg',
|
47 |
+
'Watercolor': './style_images/Watercolor.jpg',
|
48 |
+
'Mosaic': './style_images/Mosaic.jpg',
|
49 |
+
'Lego Bricks': './style_images/Lego_Bricks.jpg',
|
50 |
+
'Bokeh': './style_images/Bokeh.jpg',
|
51 |
+
}
|
52 |
+
lrs = np.linspace(0.015, 0.075, 3).tolist()
|
53 |
img_size = 512
|
54 |
|
|
|
55 |
cached_style_features = {}
|
56 |
for style_name, style_img_path in style_options.items():
|
57 |
+
style_img = preprocess_img(style_img_path, img_size)[0].to(device)
|
58 |
with torch.no_grad():
|
59 |
style_features = model(style_img)
|
60 |
cached_style_features[style_name] = style_features
|
61 |
|
62 |
@spaces.GPU(duration=30)
|
63 |
+
def run(content_image, style_name, style_strength=len(lrs), optim_name='AdamW', apply_to_background=False):
|
64 |
yield None
|
65 |
content_img, original_size = preprocess_img(content_image, img_size)
|
66 |
content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
|
67 |
content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
|
68 |
style_features = cached_style_features[style_name]
|
69 |
|
70 |
+
if optim_name == 'AdamW':
|
71 |
+
optim_caller = optim.AdamW
|
|
|
|
|
|
|
|
|
72 |
elif optim_name == 'L-BFGS':
|
73 |
+
optim_caller = optim.LBFGS
|
|
|
74 |
|
75 |
+
print('-'*30)
|
76 |
+
print(datetime.now(timezone.utc) - timedelta(hours=5)) # EST
|
|
|
|
|
|
|
77 |
|
78 |
st = time.time()
|
79 |
generated_img = inference(
|
|
|
84 |
style_features=style_features,
|
85 |
lr=lrs[style_strength-1],
|
86 |
apply_to_background=apply_to_background,
|
|
|
87 |
optim_caller=optim_caller,
|
88 |
)
|
89 |
et = time.time()
|
90 |
+
print(f'{et-st:.2f}s')
|
91 |
|
92 |
yield postprocess_img(generated_img, original_size)
|
93 |
|
|
|
105 |
gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation")
|
106 |
with gr.Row(elem_id='container'):
|
107 |
with gr.Column():
|
|
|
|
|
108 |
with gr.Group():
|
109 |
+
content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
|
110 |
with gr.Group():
|
111 |
+
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
|
112 |
+
style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
|
113 |
+
apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
|
114 |
with gr.Accordion(label='Advanced Options', open=False):
|
115 |
+
optim_dropdown = gr.Radio(choices=['AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
|
116 |
submit_button = gr.Button('Submit', variant='primary')
|
117 |
|
118 |
examples = gr.Examples(
|
inference.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import torch
|
2 |
import torch.optim as optim
|
3 |
import torch.nn.functional as F
|
4 |
-
from torchvision.transforms.functional import gaussian_blur
|
5 |
-
from tqdm import tqdm
|
6 |
|
7 |
def gram_matrix(feature):
|
8 |
b, c, h, w = feature.size()
|
@@ -28,8 +26,8 @@ def inference(
|
|
28 |
content_image_norm,
|
29 |
style_features,
|
30 |
apply_to_background,
|
31 |
-
lr=5e-2,
|
32 |
-
iterations=
|
33 |
optim_caller=optim.AdamW,
|
34 |
alpha=1,
|
35 |
beta=1,
|
@@ -58,7 +56,7 @@ def inference(
|
|
58 |
total_loss.backward()
|
59 |
return total_loss
|
60 |
|
61 |
-
for _ in
|
62 |
optimizer.step(closure)
|
63 |
if apply_to_background:
|
64 |
with torch.no_grad():
|
|
|
1 |
import torch
|
2 |
import torch.optim as optim
|
3 |
import torch.nn.functional as F
|
|
|
|
|
4 |
|
5 |
def gram_matrix(feature):
|
6 |
b, c, h, w = feature.size()
|
|
|
26 |
content_image_norm,
|
27 |
style_features,
|
28 |
apply_to_background,
|
29 |
+
lr=1.5e-2,
|
30 |
+
iterations=51,
|
31 |
optim_caller=optim.AdamW,
|
32 |
alpha=1,
|
33 |
beta=1,
|
|
|
56 |
total_loss.backward()
|
57 |
return total_loss
|
58 |
|
59 |
+
for _ in range(iterations):
|
60 |
optimizer.step(closure)
|
61 |
if apply_to_background:
|
62 |
with torch.no_grad():
|
requirements.txt
CHANGED
@@ -6,5 +6,4 @@ huggingface_hub
|
|
6 |
pillow
|
7 |
gradio
|
8 |
spaces
|
9 |
-
tqdm
|
10 |
tensorboard
|
|
|
6 |
pillow
|
7 |
gradio
|
8 |
spaces
|
|
|
9 |
tensorboard
|
utils.py
CHANGED
@@ -3,11 +3,8 @@ from PIL import Image
|
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
5 |
|
6 |
-
def
|
7 |
-
img = Image.open(
|
8 |
-
return preprocess_img(img, img_size, normalize)
|
9 |
-
|
10 |
-
def preprocess_img(img: Image, img_size, normalize=False):
|
11 |
original_size = img.size
|
12 |
|
13 |
if normalize:
|
|
|
3 |
import torch
|
4 |
import torchvision.transforms as transforms
|
5 |
|
6 |
+
def preprocess_img(img, img_size, normalize=False):
|
7 |
+
if type(img) == str: img = Image.open(img)
|
|
|
|
|
|
|
8 |
original_size = img.size
|
9 |
|
10 |
if normalize:
|