jamino30 committed
Commit a3814f8 · verified · 1 Parent(s): 2eb84e6

Upload folder using huggingface_hub

Files changed (4)
  1. app.py +28 -30
  2. inference.py +3 -5
  3. requirements.txt +0 -1
  4. utils.py +2 -5
app.py CHANGED
@@ -1,18 +1,16 @@
 import os
 import time
 from datetime import datetime, timezone, timedelta
-from concurrent.futures import ThreadPoolExecutor
 
 import spaces
 import torch
 import torch.optim as optim
-import torchvision.models as models
 import numpy as np
 import gradio as gr
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 
-from utils import preprocess_img, preprocess_img_from_path, postprocess_img
+from utils import preprocess_img, postprocess_img
 from vgg.vgg19 import VGG_19
 from u2net.model import U2Net
 from inference import inference
@@ -20,8 +18,8 @@ from inference import inference
 if torch.cuda.is_available(): device = 'cuda'
 elif torch.backends.mps.is_available(): device = 'mps'
 else: device = 'cpu'
-print('DEVICE:', device)
-if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
+print('Device:', device)
+if device == 'cuda': print('Name:', torch.cuda.get_device_name())
 
 def load_model_without_module(model, model_path):
     state_dict = load_file(model_path, device=device)
@@ -40,41 +38,42 @@ local_model_path = hf_hub_download(repo_id='jamino30/u2net-saliency', filename='
 load_model_without_module(sod_model, local_model_path)
 
 style_files = os.listdir('./style_images')
-style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
-lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
+style_options = {
+    'Starry Night': './style_images/Starry_Night.jpg',
+    'Starry Night (v2)': './style_images/Starry_Night_v2.jpg',
+    'Scream': './style_images/Scream.jpg',
+    'Great Wave': './style_images/Great_Wave.jpg',
+    'Oil Painting': './style_images/Oil_Painting.jpg',
+    'Watercolor': './style_images/Watercolor.jpg',
+    'Mosaic': './style_images/Mosaic.jpg',
+    'Lego Bricks': './style_images/Lego_Bricks.jpg',
+    'Bokeh': './style_images/Bokeh.jpg',
+}
+lrs = np.linspace(0.015, 0.075, 3).tolist()
 img_size = 512
 
-# store style(s) features
 cached_style_features = {}
 for style_name, style_img_path in style_options.items():
-    style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device)
+    style_img = preprocess_img(style_img_path, img_size)[0].to(device)
     with torch.no_grad():
         style_features = model(style_img)
     cached_style_features[style_name] = style_features
 
 @spaces.GPU(duration=30)
-def run(content_image, style_name, style_strength=10, optim_name='AdamW', apply_to_background=False):
+def run(content_image, style_name, style_strength=len(lrs), optim_name='AdamW', apply_to_background=False):
     yield None
     content_img, original_size = preprocess_img(content_image, img_size)
     content_img_normalized, _ = preprocess_img(content_image, img_size, normalize=True)
     content_img, content_img_normalized = content_img.to(device), content_img_normalized.to(device)
     style_features = cached_style_features[style_name]
 
-    if optim_name == 'Adam':
-        optim_caller = torch.optim.Adam
-        iterations = 101
-    elif optim_name == 'AdamW':
-        optim_caller = torch.optim.AdamW
-        iterations = 101
+    if optim_name == 'AdamW':
+        optim_caller = optim.AdamW
     elif optim_name == 'L-BFGS':
-        optim_caller = torch.optim.LBFGS
-        iterations = 20
+        optim_caller = optim.LBFGS
 
-    print('-'*15)
-    print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4)) # est
-    print('STYLE:', style_name)
-    print('CONTENT IMG SIZE:', original_size)
-    print('STYLE STRENGTH:', style_strength, f'(lr={lrs[style_strength-1]:.3f})')
+    print('-'*30)
+    print(datetime.now(timezone.utc) - timedelta(hours=5)) # EST
 
     st = time.time()
     generated_img = inference(
@@ -85,11 +84,10 @@ def run(content_image, style_name, style_strength=10, optim_name='AdamW', apply_
         style_features=style_features,
         lr=lrs[style_strength-1],
         apply_to_background=apply_to_background,
-        iterations=iterations,
         optim_caller=optim_caller,
     )
     et = time.time()
-    print('TIME TAKEN:', et-st)
+    print(f'{et-st:.2f}s')
 
     yield postprocess_img(generated_img, original_size)
@@ -107,14 +105,14 @@ with gr.Blocks(css=css) as demo:
     gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation")
     with gr.Row(elem_id='container'):
         with gr.Column():
-            content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
-            style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
             with gr.Group():
-                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
+                content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
             with gr.Group():
-                apply_to_background_checkbox = gr.Checkbox(label='Apply styling to background only', value=False)
+                style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
+                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=len(lrs), step=1, value=len(lrs))
+                apply_to_background_checkbox = gr.Checkbox(label='Apply style transfer exclusively to the background', value=False)
             with gr.Accordion(label='Advanced Options', open=False):
-                optim_dropdown = gr.Radio(choices=['Adam', 'AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
+                optim_dropdown = gr.Radio(choices=['AdamW', 'L-BFGS'], label='Optimizer', value='AdamW', type='value')
             submit_button = gr.Button('Submit', variant='primary')
 
     examples = gr.Examples(
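Note on the app.py changes: the 10-point log-spaced learning-rate scale is replaced by a 3-point linear one, and the Style Strength slider is now sized from len(lrs) so the two stay in sync. A minimal sketch of the new strength-to-lr mapping, using only values visible in the diff above:

import numpy as np

# Replaces np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist():
# three evenly spaced learning rates, one per slider step.
lrs = np.linspace(0.015, 0.075, 3).tolist()  # [0.015, 0.045, 0.075]

# run() receives the 1-based slider value and indexes lrs[style_strength - 1]:
for style_strength in range(1, len(lrs) + 1):
    print(f'strength {style_strength} -> lr {lrs[style_strength - 1]:.3f}')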
inference.py CHANGED
@@ -1,8 +1,6 @@
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from torchvision.transforms.functional import gaussian_blur
-from tqdm import tqdm
 
 def gram_matrix(feature):
     b, c, h, w = feature.size()
@@ -28,8 +26,8 @@ def inference(
     content_image_norm,
     style_features,
     apply_to_background,
-    lr=5e-2,
-    iterations=101,
+    lr=1.5e-2,
+    iterations=51,
     optim_caller=optim.AdamW,
     alpha=1,
     beta=1,
@@ -58,7 +56,7 @@ def inference(
         total_loss.backward()
         return total_loss
 
-    for _ in tqdm(range(iterations)):
+    for _ in range(iterations):
         optimizer.step(closure)
         if apply_to_background:
             with torch.no_grad():
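Note on the inference.py changes: the loop drops the tqdm progress bar but keeps the closure pattern, which is what lets one loop drive both AdamW and L-BFGS (LBFGS may call the closure several times per step to re-evaluate the loss). A self-contained toy sketch of that pattern; the tensors and loss here are stand-ins, not the real style-transfer objective:

import torch
import torch.optim as optim

x = torch.zeros(3, requires_grad=True)    # stand-in for the generated image
target = torch.tensor([1.0, 2.0, 3.0])    # stand-in for the target features
optimizer = optim.LBFGS([x], lr=0.5)

def closure():
    optimizer.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    return loss

for _ in range(10):    # mirrors `for _ in range(iterations)` above
    optimizer.step(closure)

Swapping in optim.AdamW works unchanged, since AdamW's step() also accepts an optional closure; that is why inference() can take optim_caller as a parameter.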
requirements.txt CHANGED
@@ -6,5 +6,4 @@ huggingface_hub
 pillow
 gradio
 spaces
-tqdm
 tensorboard
utils.py CHANGED
@@ -3,11 +3,8 @@ from PIL import Image
 import torch
 import torchvision.transforms as transforms
 
-def preprocess_img_from_path(path_to_image, img_size, normalize=False):
-    img = Image.open(path_to_image)
-    return preprocess_img(img, img_size, normalize)
-
-def preprocess_img(img: Image, img_size, normalize=False):
+def preprocess_img(img, img_size, normalize=False):
+    if type(img) == str: img = Image.open(img)
     original_size = img.size
 
     if normalize:
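Note on the utils.py changes: the two helpers are merged, and preprocess_img now accepts either a file path or an already-opened PIL image; that is why app.py can call it directly on the style image paths. A short usage sketch (the path is illustrative):

from PIL import Image
from utils import preprocess_img

# Path input: the function opens the file itself.
style_tensor, style_size = preprocess_img('./style_images/Starry_Night.jpg', 512)

# PIL input, e.g. the image handed over by the Gradio component.
img = Image.open('./style_images/Starry_Night.jpg')
content_tensor, original_size = preprocess_img(img, 512, normalize=True)

Minor nit: isinstance(img, str) would be more idiomatic than type(img) == str.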