1plus1 commited on
Commit
4637a55
·
verified ·
1 Parent(s): dc7419f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -29
app.py CHANGED
@@ -5,34 +5,14 @@ import numpy as np
5
  import cv2
6
  import gradio as gr
7
  import os
8
- import subprocess
9
- import gdown
10
-
11
- os.makedirs('/content', exist_ok=True)
12
- # Download model
13
- # Pix2pix model
14
- model_url = 'https://drive.google.com/drive/folders/1jOxiKyf8n7fwNZfgeZyUrNJ90LEmIL3S?usp=sharing'
15
- os.makedirs('/content/pix2pix', exist_ok=True)
16
- subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/pix2pix', '--folder'], check=True)
17
-
18
- # WD-Net
19
- model_url = 'https://drive.google.com/file/d/1M8EOE4Ej8oS4_0BHCEwExxu5CMFS5HQZ/view?usp=sharing'
20
- os.makedirs('/content/WD-Net', exist_ok=True)
21
- subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/WD-Net/model.zip'], check=True)
22
- subprocess.run(['unzip', '/content/WD-Net/model.zip', '-d', '/content/WD-Net'], check=True)
23
-
24
- # MS-UNet
25
- model_url = 'https://drive.google.com/file/d/1-0_bEWTItkILbCJQ4ViEBGg0zJPaIcC1/view?usp=sharing'
26
- os.makedirs('/content/MS-UNet', exist_ok=True)
27
- subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/MS-UNet/unet.keras'], check=True)
28
 
29
  # Load model
30
  # Load Pix2Pix
31
- pix2pix_path = '/model/wt_generator_best.keras'
32
  Pix2Pix = keras.saving.load_model(pix2pix_path)
33
 
34
  # Load MS-UNet
35
- unet_path = '/model/unet.keras'
36
  MS_UNet = keras.saving.load_model(unet_path)
37
 
38
  # Load WD-Net
@@ -43,16 +23,66 @@ class Clip(keras.layers.Layer):
43
  def call(self, input):
44
  return tf.clip_by_value(input, 0, 1)
45
 
46
- gen_path = '/model/generator_epoch:10.keras'
47
- WD_Net = keras.saving.load_model(gen_path)
 
 
48
 
49
  # Define infer function
50
  def infer(img, model='WD-Net'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Image original shape
52
  org_shape = img.shape
 
 
 
 
53
  # Choose model
54
  if model == 'WD-Net':
55
- generator = WD_Net
56
  img = tf.image.resize(img, [256, 256], method='area')
57
  # Normalize image and return
58
  img = tf.cast(img, tf.float32) / 255.
@@ -81,15 +111,28 @@ def infer(img, model='WD-Net'):
81
  rm_wt = rm_wt[0]
82
  rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
83
  out_img = (rm_wt * 255).astype(np.uint8)
84
- return out_img
85
 
86
  # Main gradio code
 
 
 
 
 
87
  model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
88
 
89
  demo = gr.Interface(
90
  fn=infer,
91
- inputs=[gr.Image(), gr.Dropdown(model_list)],
92
- outputs=gr.Image(),
93
  )
94
 
95
- demo.launch()
 
 
 
 
 
 
 
 
 
5
  import cv2
6
  import gradio as gr
7
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Load model
10
  # Load Pix2Pix
11
+ pix2pix_path = './model/wt_generator_best.keras'
12
  Pix2Pix = keras.saving.load_model(pix2pix_path)
13
 
14
  # Load MS-UNet
15
+ unet_path = './model/unet.keras'
16
  MS_UNet = keras.saving.load_model(unet_path)
17
 
18
  # Load WD-Net
 
23
  def call(self, input):
24
  return tf.clip_by_value(input, 0, 1)
25
 
26
+ old_gen_path = './model/generator_epoch_10.keras'
27
+ WD_Net_old = keras.saving.load_model(old_gen_path)
28
+ new_gen_path = './model/WD-Net_generator.keras'
29
+ WD_Net_new = keras.saving.load_model(new_gen_path)
30
 
31
  # Define infer function
32
  def infer(img, model='WD-Net'):
33
+ # Read image
34
+ # img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
35
+ # Image original shape
36
+ org_shape = img.shape
37
+ org_img = tf.image.resize(img, [256, 256], method='area')
38
+ org_img = tf.cast(org_img, tf.uint8).numpy()
39
+ org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))
40
+
41
+ # Choose model
42
+ if model == 'WD-Net':
43
+ generator = WD_Net_old
44
+ img = tf.image.resize(img, [256, 256], method='area')
45
+ # Normalize image and return
46
+ img = tf.cast(img, tf.float32) / 255.
47
+ img = tf.expand_dims(img, axis=0)
48
+ rm_wt = generator.predict(img, verbose=0)
49
+ rm_wt = rm_wt['I'][0]
50
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
51
+ out_img = (rm_wt * 255).astype(np.uint8)
52
+ elif model == 'MS-UNet':
53
+ generator = MS_UNet
54
+ img = tf.image.resize(img, [256, 256], method='area')
55
+ # Normalize image and return
56
+ img = (tf.cast(img, tf.float32) - 127.5) / 127.5
57
+ img = tf.expand_dims(img, axis=0)
58
+ rm_wt = generator.predict(img, verbose=0)
59
+ rm_wt = rm_wt[0]
60
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
61
+ out_img = ((rm_wt + 1) / 2 * 255).astype(np.uint8)
62
+ elif model == 'Pix2Pix':
63
+ generator = Pix2Pix
64
+ img = tf.image.resize(img, [256, 256], method='area')
65
+ # Normalize image and return
66
+ img = tf.cast(img, tf.float32) / 255.
67
+ img = tf.expand_dims(img, axis=0)
68
+ rm_wt = generator.predict(img, verbose=0)
69
+ rm_wt = rm_wt[0]
70
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
71
+ out_img = (rm_wt * 255).astype(np.uint8)
72
+ return org_img, out_img
73
+
74
+ def infer_v1(img_path, model="WD_Net"):
75
+ # Read image
76
+ img = tf.image.decode_png(tf.io.read_file('./data/' + img_path), channels=3)
77
  # Image original shape
78
  org_shape = img.shape
79
+ # org_img = tf.image.resize(img, [256, 256], method='area')
80
+ org_img = tf.cast(img, tf.uint8).numpy()
81
+ # org_img = cv2.resize(org_img, (org_shape[1], org_shape[0]))
82
+
83
  # Choose model
84
  if model == 'WD-Net':
85
+ generator = WD_Net_new
86
  img = tf.image.resize(img, [256, 256], method='area')
87
  # Normalize image and return
88
  img = tf.cast(img, tf.float32) / 255.
 
111
  rm_wt = rm_wt[0]
112
  rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
113
  out_img = (rm_wt * 255).astype(np.uint8)
114
+ return org_img, out_img
115
 
116
  # Main gradio code
117
+ # Define data and sort it
118
+ data = os.listdir('./data')
119
+ data.sort()
120
+
121
+ # Model list
122
  model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
123
 
124
  demo = gr.Interface(
125
  fn=infer,
126
+ inputs=[gr.Image(label="Choose an Image"), gr.Dropdown(model_list, label="Model")],
127
+ outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
128
  )
129
 
130
+ demo_v1 = gr.Interface(
131
+ fn=infer_v1,
132
+ inputs=[gr.Dropdown(data, label="Choose an Image"), gr.Dropdown(model_list, label="Model")],
133
+ outputs=[gr.Image(label="Watermarked Image"), gr.Image(label="Removed Watermarked Image")],
134
+ )
135
+
136
+ tabbed_interface = gr.TabbedInterface([demo, demo_v1], ["Document", "Patch"], title="Watermark Removal")
137
+
138
+ tabbed_interface.launch()