Add local files to repository
Browse files- app.py +273 -0
- maskextract.py +46 -0
- test.py +188 -0
- test_gradio.py +92 -0
- train.py +280 -0
- train_bit.py +246 -0
app.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
+
import requests
|
6 |
+
from copy import deepcopy
|
7 |
+
import cv2
|
8 |
+
from test_gradio import load_image, image_editing
|
9 |
+
|
10 |
+
import options.options as option
|
11 |
+
from utils.JPEG import DiffJPEG
|
12 |
+
from scipy.io.wavfile import read as wav_read
|
13 |
+
from scipy.io import wavfile
|
14 |
+
|
15 |
+
import os
|
16 |
+
import math
|
17 |
+
import argparse
|
18 |
+
import random
|
19 |
+
import logging
|
20 |
+
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.multiprocessing as mp
|
23 |
+
from data.data_sampler import DistIterSampler
|
24 |
+
|
25 |
+
from utils import util
|
26 |
+
from data.util import read_img
|
27 |
+
from models import create_model as create_model_editguard
|
28 |
+
import base64
|
29 |
+
import gradio as gr
|
30 |
+
|
31 |
+
from scipy.ndimage import zoom
|
32 |
+
|
33 |
+
import matplotlib.pyplot as plt
|
34 |
+
|
35 |
+
def img_to_base64(filepath):
    """Read the file at *filepath* and return its contents as a base64 string."""
    with open(filepath, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode()
|
38 |
+
|
39 |
+
# Banner logo, inlined as base64 so the page has no static-file dependency.
logo_base64 = img_to_base64("../logo.png")

# Page header markup (logo + application name) rendered via gr.HTML below.
html_content = f"""
<div style='display: flex; align-items: center; justify-content: center; padding: 20px;'>
    <img src='data:image/png;base64,{logo_base64}' alt='Logo' style='height: 50px; margin-right: 20px;'>
    <strong><font size='8'>EditGuard<font></strong>
</div>
"""

# Example images offered in the gr.Examples gallery; each entry is a
# one-element list because gr.Examples maps entries onto [image_input].
examples = [
    ["../dataset/examples/0011.png"],
    ["../dataset/examples/0012.png"],
    ["../dataset/examples/0003.png"],
    ["../dataset/examples/0004.png"],
    ["../dataset/examples/0005.png"],
    ["../dataset/examples/0006.png"],
    ["../dataset/examples/0007.png"],
    ["../dataset/examples/0008.png"],
    ["../dataset/examples/0009.png"],
    ["../dataset/examples/0010.png"],
    ["../dataset/examples/0002.png"],
]

# Image pre-loaded into the input widget when the page opens.
default_example = examples[0]
|
64 |
+
|
65 |
+
def hiding(image_input, bit_input, model):
    """Embed a binary watermark message into an image.

    Args:
        image_input: H x W x 3 uint8 numpy image from the Gradio widget.
        bit_input: string of '0'/'1' characters (the copyright watermark).
        model: loaded EditGuard model (from imgae_model_select).

    Returns:
        The watermarked image twice — once for the "watermarked image"
        widget and once to seed the tampering sketch widget.
    """
    # Map each character to an int, then shift to {-0.5, +0.5} as the
    # model's message encoding expects.
    message = np.array([int(ch) for ch in bit_input]) - 0.5

    val_data = load_image(image_input, message)
    model.feed_data(val_data)
    container = model.image_hiding()

    # NOTE: the original also built an unused PIL Image from `container`
    # (with a redundant local `from PIL import Image`); that dead code is
    # removed here — the numpy array is what Gradio consumes.
    return container, container
|
76 |
+
|
77 |
+
def rand(num_bits=64):
    """Return a random bit string of *num_bits* '0'/'1' characters."""
    return ''.join(random.choice('01') for _ in range(num_bits))
|
80 |
+
|
81 |
+
def ImageEdit(img, prompt, model_index):
    """Tamper the sketched region of *img* using the inpainting pipeline.

    Args:
        img: Gradio sketch dict with 'image' (numpy) and 'mask' (numpy) keys.
        prompt: text prompt for the inpainting model.
        model_index: selected tampering model index (currently unused beyond
            interface compatibility — only one model is offered).

    Returns:
        The edited image three times, for the "edited", "to-extract" and
        saved-state outputs.
    """
    source = img["image"]
    region = np.float32(img["mask"])
    edited = image_editing(source, region, prompt)
    return edited, edited, edited
|
86 |
+
|
87 |
+
|
88 |
+
def imgae_model_select(ckp_index=0):
    """Build the EditGuard model and load the checkpoint for *ckp_index*.

    Args:
        ckp_index: dropdown index of the selected model; only 0 is defined.

    Returns:
        The EditGuard model with test weights loaded.

    Raises:
        ValueError: if *ckp_index* does not map to a known checkpoint.
            (The original referenced `model_pth` without guaranteeing it was
            bound for indices other than 0; this makes the failure explicit.)
    """
    # options
    opt = option.parse("options/test_editguard.yml", is_train=True)
    # distributed training settings — the demo always runs single-process
    opt['dist'] = False
    rank = -1
    print('Disabled distributed training.')

    # loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    torch.backends.cudnn.benchmark = True

    # create model
    model = create_model_editguard(opt)

    if ckp_index == 0:
        model_pth = '../checkpoints/clean.pth'
    else:
        raise ValueError('Unknown checkpoint index: {}'.format(ckp_index))
    print(model_pth)
    model.load_test(model_pth)
    return model
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
def Gaussian_image_degradation(image, NL):
    """Add Gaussian noise of level *NL* (on a 0-255 scale) to *image*.

    NOTE(review): the input is clamped to [0, 1] after adding noise, so the
    caller is presumably expected to pass values already in [0, 1] — confirm
    against the caller; a raw uint8 0-255 image would saturate to white.

    Returns the degraded uint8 image twice (two output widgets).
    """
    tensor = torch.from_numpy(np.transpose(image, (2, 0, 1))).unsqueeze(0)
    sigma = NL / 255.0
    noise = torch.from_numpy(np.random.normal(0, sigma, tensor.shape)).float()
    noisy = torch.clamp(tensor + noise, 0, 1)
    noisy = noisy.permute(0, 2, 3, 1).cpu().detach().numpy().squeeze()
    degraded = (noisy * 255.0).astype(np.uint8)
    return degraded, degraded
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
def JPEG_image_degradation(image, NL):
    """Apply differentiable JPEG compression at quality *NL* to *image*.

    NOTE(review): like the Gaussian variant, the tensor fed to DiffJPEG is
    not rescaled, so the expected input range depends on DiffJPEG's
    contract — verify against utils.JPEG.

    Returns the degraded uint8 image twice (two output widgets).
    """
    tensor = torch.from_numpy(
        np.transpose(image.astype(np.float32), (2, 0, 1))
    ).unsqueeze(0)
    codec = DiffJPEG(differentiable=True, quality=int(NL))
    compressed = codec(tensor).permute(0, 2, 3, 1)
    compressed = compressed.cpu().detach().numpy().squeeze()
    degraded = (compressed * 255.0).astype(np.uint8)
    return degraded, degraded
|
148 |
+
|
149 |
+
|
150 |
+
def revealing(image_edited, input_bit, model_list, model):
    """Recover the tamper mask and embedded watermark from an edited image.

    Args:
        image_edited: the (possibly tampered) watermarked image, numpy uint8.
        input_bit: the original watermark bit string, used only to report
            recovery accuracy; may be empty if unknown.
        model_list: selected model index — kept for interface compatibility;
            the original's if/else assigned the identical threshold on both
            branches, so the dead branch is removed here.
        model: loaded EditGuard model.

    Returns:
        (mask PIL image, recovered bit string, accuracy string).
    """
    number = 0.2  # mask-detection threshold (was 0.2 on both branches)

    container_data = load_image(image_edited)  # load tampered images
    model.feed_data(container_data)
    mask, remesg = model.image_recovery(number)
    mask = Image.fromarray(mask.astype(np.uint8))
    remesg = remesg.cpu().numpy()[0]
    remesg = ''.join(str(int(x)) for x in remesg)
    bit_acc = calculate_similarity_percentage(input_bit, remesg)
    return mask, remesg, bit_acc
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
def calculate_similarity_percentage(str1, str2):
    """Return the percentage of positions at which *str1* and *str2* agree.

    Returns a Chinese status message instead when the reference watermark
    *str1* is empty ("original watermark unknown") or when the lengths
    differ ("input/output watermark lengths differ").
    """
    if not str1:
        return "原始版权水印未知"
    if len(str1) != len(str2):
        return "输入输出水印长度不同"
    matches = sum(a == b for a, b in zip(str1, str2))
    return f"{matches / len(str1) * 100}%"
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
# Page title markup (unused by the Blocks layout below, which renders
# html_content instead, but kept for compatibility).
title = "<center><strong><font size='8'>EditGuard<font></strong></center>"

# Global CSS for the Gradio page.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
185 |
+
|
186 |
+
# Gradio UI: three-step demo (embed watermark -> tamper -> extract).
with gr.Blocks(css=css, title="EditGuard") as demo:
    gr.HTML(html_content)
    # Session state holders (per-browser-session, not shared across users).
    model = gr.State(value = None)          # loaded EditGuard model
    save_h = gr.State(value = None)
    save_w = gr.State(value = None)
    sam_global_points = gr.State([])
    sam_global_point_label = gr.State([])
    sam_original_image = gr.State(value=None)
    sam_mask = gr.State(value=None)

    with gr.Tabs():
        with gr.TabItem('多功能取证水印'):

            # Usage instructions (in Chinese): 1) upload an image and a
            # 64-bit watermark, embed; 2) sketch a region and inpaint;
            # 3) extract the tamper mask and watermark.
            DESCRIPTION = """
## 使用方法:
- 上传图像和版权水印(64位比特序列),点击"嵌入水印"按钮,生成带水印的图像。
- 涂抹要编辑的区域,并使用Inpainting算法编辑图像。
- 点击"提取"按钮检测篡改区域并输出版权水印。"""

            gr.Markdown(DESCRIPTION)
            save_inpainted_image = gr.State(value=None)
            with gr.Column():
                with gr.Row():
                    model_list = gr.Dropdown(label="选择模型", choices=["模型1"], type = 'index')
                    clear_button = gr.Button("清除全部")
                # Step 1: embed the watermark.
                with gr.Box():
                    gr.Markdown("# 1. 嵌入水印")
                    with gr.Row():
                        with gr.Column():
                            image_input = gr.Image(source='upload', label="原始图片", interactive=True, type="numpy", value=default_example[0])
                            with gr.Row():
                                bit_input = gr.Textbox(label="输入版权水印(64位比特序列)", placeholder="在这里输入...")
                                rand_bit = gr.Button("🎲 随机生成版权水印")
                            hiding_button = gr.Button("嵌入水印")
                        with gr.Column():
                            image_watermark = gr.Image(source="upload", label="带有水印的图片", interactive=True, type="numpy")

                # Step 2: tamper the watermarked image via inpainting.
                with gr.Box():
                    gr.Markdown("# 2. 篡改图片")
                    with gr.Row():
                        with gr.Column():
                            image_edit = gr.Image(source='upload',tool="sketch", label="选取篡改区域", interactive=True, type="numpy")
                            inpainting_model_list = gr.Dropdown(label="选择篡改模型", choices=["模型1:SD_inpainting"], type = 'index')
                            text_prompt = gr.Textbox(label="篡改提示词")
                            inpainting_button = gr.Button("篡改图片")
                        with gr.Column():
                            image_edited = gr.Image(source="upload", label="篡改结果", interactive=True, type="numpy")

                # Step 3: recover the tamper mask and the watermark bits.
                with gr.Box():
                    gr.Markdown("# 3. 提取水印&篡改区域")
                    with gr.Row():
                        with gr.Column():
                            image_edited_1 = gr.Image(source="upload", label="待提取图片", interactive=True, type="numpy")
                            revealing_button = gr.Button("提取")
                        with gr.Column():
                            edit_mask = gr.Image(source='upload', label="编辑区域蒙版预测", interactive=True, type="numpy")
                            bit_output = gr.Textbox(label="版权水印预测")
                            acc_output = gr.Textbox(label="水印预测准确率")

                gr.Examples(
                    examples=examples,
                    inputs=[image_input],
                )

            # Event wiring.
            # Selecting a model loads the corresponding checkpoint into state.
            model_list.change(
                imgae_model_select, inputs = [model_list], outputs=[model]
            )
            # Embedding feeds the watermarked image to both step-1 output
            # and the step-2 sketch widget.
            hiding_button.click(
                hiding, inputs=[image_input, bit_input, model], outputs=[image_watermark, image_edit]
            )
            rand_bit.click(
                rand, inputs=[], outputs=[bit_input]
            )

            # Inpainting feeds its result to step-2 output, step-3 input,
            # and the saved-state holder.
            inpainting_button.click(
                ImageEdit, inputs = [image_edit, text_prompt, inpainting_model_list], outputs=[image_edited, image_edited_1, save_inpainted_image]
            )

            revealing_button.click(
                revealing, inputs=[image_edited_1, bit_input, model_list, model], outputs=[edit_mask, bit_output, acc_output]
            )

demo.launch(server_name="0.0.0.0", server_port=2004, share=True, favicon_path='../logo.png')
|
maskextract.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
if __name__ == "__main__":
    # Build binary tamper masks by thresholding the per-pixel residual
    # between each recovered image (SR_h) and its ground truth (LRGT).
    parser = argparse.ArgumentParser()
    parser.add_argument('--threshold', default=0.2, type=float,
                        help='Per-channel residual threshold above which a pixel is marked tampered.')
    args = parser.parse_args()

    input_folder = 'results/test_age-set'
    output_folder = 'results/mask'
    os.makedirs(output_folder, exist_ok=True)

    for filename in os.listdir(input_folder):
        if not filename.endswith('_0_0_LRGT.png'):
            continue
        digits = filename.split('_')[0]
        if not digits.isdigit():
            continue
        digits = int(digits)
        if 0 <= digits <= 1000:
            input_path_LRGT = os.path.join(input_folder, filename)
            input_path_SR_h = input_path_LRGT.replace('LRGT', 'SR_h')

            image_LRGT = Image.open(input_path_LRGT).convert("RGB")
            image_SR_h = Image.open(input_path_SR_h).convert("RGB")

            # Bring the ground truth to the recovered image's resolution.
            w, h = image_SR_h.size
            image_LRGT = image_LRGT.resize((w, h))

            array_LRGT = np.array(image_LRGT) / 255.
            array_SR_h = np.array(image_SR_h) / 255.

            residual = np.abs(array_LRGT - array_SR_h)
            mask = np.where(residual > args.threshold, 1, 0)

            # Collapse channels and clip to {0, 1} BEFORE scaling to 255.
            # The original multiplied the raw channel sum (0..3) by 255,
            # which overflowed uint8 (2*255 -> 254, 3*255 -> 253) instead
            # of producing pure white where several channels fired.
            mask = np.clip(np.sum(mask, axis=2), 0, 1)

            # Output is 1-indexed, zero-padded to four digits.
            output_path = os.path.join(output_folder, str(digits + 1).zfill(4) + '.png')
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))
            mask_image.save(output_path)
|
test.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import logging
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.multiprocessing as mp
|
10 |
+
from data.data_sampler import DistIterSampler
|
11 |
+
|
12 |
+
import options.options as option
|
13 |
+
from utils import util
|
14 |
+
from data import create_dataloader, create_dataset
|
15 |
+
from models import create_model
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
|
19 |
+
def init_dist(backend='nccl', **kwargs):
    """Initialise torch.distributed for multi-process training.

    Reads this process's rank from the RANK environment variable (set by
    the launcher), pins the process to GPU ``rank % num_gpus`` and joins
    the process group with the given *backend*.
    """
    # if mp.get_start_method(allow_none=True) is None:
    if mp.get_start_method(allow_none=True) != 'spawn':
        # CUDA does not survive fork(); multiprocessing must use 'spawn'.
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])  # provided by the distributed launcher
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)  # one GPU per rank, round-robin
    dist.init_process_group(backend=backend, **kwargs)
|
28 |
+
|
29 |
+
def cal_pnsr(sr_img, gt_img):
    """Return the PSNR between two images on the 0-255 scale.

    Args:
        sr_img: restored image, uint8 numpy array.
        gt_img: ground-truth image, uint8 numpy array.

    The original divided both images by 255 and immediately multiplied
    them back before calling util.calculate_psnr — a round trip with no
    effect beyond float rounding; it is removed here.
    """
    return util.calculate_psnr(sr_img, gt_img)
|
37 |
+
|
38 |
+
def get_min_avg_and_indices(nums, k=900):
    """Average the *k* smallest elements of *nums* and log their indices.

    Args:
        nums: sequence of comparable numbers (e.g. per-image bit error rates).
        k: how many of the smallest elements to average (default 900,
           matching the original hard-coded value).

    Returns:
        The average of the up-to-*k* smallest elements (0.0 for empty input,
        matching the original's behaviour).

    Side effect: writes the selected indices, one per line, to "indices.txt"
    in the current working directory.
    """
    # Indices of the k smallest elements, ascending by value.
    indices = sorted(range(len(nums)), key=lambda i: nums[i])[:k]

    # The original always divided by 900, which under-reported the average
    # whenever fewer than 900 elements were available; divide by the actual
    # count instead.
    avg = sum(nums[i] for i in indices) / len(indices) if indices else 0.0

    # Write the indices to a txt file for later inspection.
    with open("indices.txt", "w") as file:
        for index in indices:
            file.write(str(index) + "\n")

    return avg
|
51 |
+
|
52 |
+
|
53 |
+
def main():
    """Evaluation entry point.

    Parses the YAML option file, builds the validation dataloader, loads a
    pre-trained EditGuard checkpoint, then for every batch saves the
    restored images and accumulates PSNR (cover / secret / stego) and the
    watermark bit-error rate, printing a summary at the end.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--ckpt', type=str, default='/userhome/NewIBSN/EditGuard_open/checkpoints/clean.pth', help='Path to pre-trained model.')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        print("phase", phase)
        # Both recognised phases build a plain (non-distributed) val loader.
        if phase == 'TD':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))

    # create model
    model = create_model(opt)
    model.load_test(args.ckpt)

    # validation accumulators: PSNR of cover, per-secret PSNRs, stego PSNR,
    # per-batch bit-error records, and the running image count.
    avg_psnr = 0.0
    avg_psnr_h = [0.0]*opt['num_image']
    avg_psnr_lr = 0.0
    biterr = []
    idx = 0
    for image_id, val_data in enumerate(val_loader):
        img_dir = os.path.join('results',opt['name'])
        util.mkdir(img_dir)

        model.feed_data(val_data)
        model.test(image_id)

        visuals = model.get_current_visuals()

        t_step = visuals['SR'].shape[0]  # frames in this batch
        idx += t_step
        n = len(visuals['SR_h'])  # number of hidden (secret) images

        # Recovered vs. embedded watermark message for this batch.
        a = visuals['recmessage'][0]
        b = visuals['message'][0]

        bitrecord = util.decoded_message_error_rate_batch(a, b)
        print(bitrecord)
        biterr.append(bitrecord)

        for i in range(t_step):
            sr_img = util.tensor2img(visuals['SR'][i])  # uint8
            sr_img_h = []
            for j in range(n):
                sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i]))  # uint8
            gt_img = util.tensor2img(visuals['GT'][i])  # uint8
            lr_img = util.tensor2img(visuals['LR'][i])
            lrgt_img = []
            for j in range(n):
                lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))

            # Save SR images for reference
            save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'SR'))
            util.save_img(sr_img, save_img_path)

            for j in range(n):
                save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'SR_h'))
                util.save_img(sr_img_h[j], save_img_path)

            save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
            util.save_img(gt_img, save_img_path)

            save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
            util.save_img(lr_img, save_img_path)

            for j in range(n):
                save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'LRGT'))
                util.save_img(lrgt_img[j], save_img_path)

            # Per-frame quality metrics.
            psnr = cal_pnsr(sr_img, gt_img)
            psnr_h = []
            for j in range(n):
                psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
            psnr_lr = cal_pnsr(lr_img, gt_img)

            avg_psnr += psnr
            for j in range(n):
                avg_psnr_h[j] += psnr_h[j]
            avg_psnr_lr += psnr_lr

    # Reduce accumulators to per-image averages and report.
    avg_psnr = avg_psnr / idx
    avg_biterr = sum(biterr) / len(biterr)
    print(get_min_avg_and_indices(biterr))

    avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
    avg_psnr_lr = avg_psnr_lr / idx
    res_psnr_h = ''
    for p in avg_psnr_h:
        res_psnr_h+=('_{:.4e}'.format(p))
    print('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_Error: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))


if __name__ == '__main__':
    main()
|
test_gradio.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
import argparse
|
6 |
+
import random
|
7 |
+
import logging
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
from data.data_sampler import DistIterSampler
|
13 |
+
|
14 |
+
import options.options as option
|
15 |
+
from utils import util
|
16 |
+
from data.util import read_img
|
17 |
+
from data import create_dataloader, create_dataset
|
18 |
+
from models import create_model
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
from diffusers import StableDiffusionInpaintPipeline
|
22 |
+
|
23 |
+
|
24 |
+
def init_dist(backend='nccl', **kwargs):
    """Initialise torch.distributed (duplicated from test.py).

    Reads this process's rank from the RANK environment variable, pins the
    process to GPU ``rank % num_gpus`` and joins the process group.
    """
    # if mp.get_start_method(allow_none=True) is None:
    if mp.get_start_method(allow_none=True) != 'spawn':
        # CUDA does not survive fork(); multiprocessing must use 'spawn'.
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])  # provided by the distributed launcher
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)  # one GPU per rank, round-robin
    dist.init_process_group(backend=backend, **kwargs)
|
33 |
+
|
34 |
+
|
35 |
+
def load_image(image, message = None):
    """Pack a single numpy image into the dict the EditGuard model consumes.

    Args:
        image: H x W x 3 image; divided by 255 here, so presumably a uint8
            0-255 array from Gradio — TODO confirm against callers.
        message: optional watermark message array (values in {-0.5, +0.5}
            at the hiding call site); passed through untouched.

    Returns:
        dict with 'GT' (the cover image, resized to 512x512), 'LQ' (a
        solid-blue placeholder secret image stack) and 'MES' (the message).
    """
    # img_GT = read_img(None, image_path)
    img_GT = image / 255
    # print(img_GT)
    img_GT = img_GT[:, :, [2, 1, 0]]  # RGB -> BGR channel swap
    img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
    img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None)
    img_GT = img_GT.unsqueeze(0)  # add a time dimension: (1, T=1, C, 512, 512)

    # NOTE(review): the names W, H follow the original; after the resize
    # above both are 512, so any transposition is harmless here.
    _, T, C, W, H = img_GT.shape
    list_h = []
    R = 0
    G = 0
    B = 255
    # Solid blue frame used as the placeholder "secret" image.
    image = Image.new('RGB', (W, H), (R, G, B))
    result = np.array(image) / 255.
    expanded_matrix = np.expand_dims(result, axis=0)
    expanded_matrix = np.repeat(expanded_matrix, T, axis=0)  # replicate per frame
    imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
    imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)  # T,H,W,C -> T,C,H,W
    imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None)
    imgs_LQ = imgs_LQ.unsqueeze(0)

    list_h.append(imgs_LQ)

    list_h = torch.stack(list_h, dim=0)  # stack of secret-image tensors

    return {
        'LQ': list_h,
        'GT': img_GT,
        'MES': message
    }
|
67 |
+
|
68 |
+
|
69 |
+
def image_editing(image_numpy, mask_image, prompt):
    """Inpaint the masked region of *image_numpy* with Stable Diffusion.

    Args:
        image_numpy: H x W x 3 uint8 image.
        mask_image: mask array; converted to a single-channel PIL image and
            used both as the diffusion mask and for the final compositing.
        prompt: text prompt driving the inpainting.

    Returns:
        Float image in [0, 1]: the original outside the mask, the inpainted
        result inside it.

    NOTE(review): the pipeline is re-downloaded/re-loaded on every call —
    hoisting it to module level would avoid repeated multi-GB loads.
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    ).to("cuda")

    pil_image = Image.fromarray(image_numpy)
    print(mask_image.shape)
    print("maskmin", mask_image.min(), "maskmax", mask_image.max())
    mask_image = Image.fromarray(mask_image.astype(np.uint8)).convert("L")
    image_init = pil_image.convert("RGB").resize((512, 512))

    # PIL .size is (width, height); passed below as height=w, width=h.
    # NOTE(review): for non-square masks this swaps the dimensions — the
    # demo always works at 512x512 so it is benign here; confirm if reused.
    h, w = mask_image.size

    image_inpaint = pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0]
    image_inpaint = np.array(image_inpaint) / 255.
    image = np.array(image_init) / 255.
    mask_image = np.array(mask_image)
    mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
    # NOTE(review): astype(uint8) after /255. truncates every value below
    # 255 to 0 — i.e. only fully-opaque mask pixels survive; intermediate
    # brush opacities are dropped. Presumably intended as binarisation.
    mask_image = mask_image.astype(np.uint8)
    # Composite: keep the original where mask==0, inpainted where mask==1.
    image_fuse = image * (1 - mask_image) + image_inpaint * mask_image

    return image_fuse
|
train.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import logging
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.multiprocessing as mp
|
10 |
+
from data.data_sampler import DistIterSampler
|
11 |
+
|
12 |
+
import options.options as option
|
13 |
+
from utils import util
|
14 |
+
from data import create_dataloader, create_dataset
|
15 |
+
from models import create_model
|
16 |
+
|
17 |
+
|
18 |
+
def init_dist(backend='nccl', **kwargs):
    """Initialise torch.distributed for multi-process training.

    Reads this process's rank from the RANK environment variable, pins the
    process to GPU ``rank % num_gpus`` and joins the process group.
    """
    # if mp.get_start_method(allow_none=True) is None:
    if mp.get_start_method(allow_none=True) != 'spawn':
        # CUDA does not survive fork(); multiprocessing must use 'spawn'.
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])  # provided by the distributed launcher
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)  # one GPU per rank, round-robin
    dist.init_process_group(backend=backend, **kwargs)
|
27 |
+
|
28 |
+
def cal_pnsr(sr_img, gt_img):
    """Return the PSNR between two uint8 images via util.calculate_psnr.

    Both inputs are normalised to [0, 1] and rescaled back to the 0-255
    range before the call, matching the project's other cal_pnsr copies.
    """
    gt_scaled = gt_img / 255.
    sr_scaled = sr_img / 255.
    return util.calculate_psnr(sr_scaled * 255, gt_scaled * 255)
|
35 |
+
|
36 |
+
def main():
|
37 |
+
# options
|
38 |
+
parser = argparse.ArgumentParser()
|
39 |
+
parser.add_argument('-opt', type=str, help='Path to option YMAL file.') # config 文件
|
40 |
+
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
41 |
+
help='job launcher')
|
42 |
+
parser.add_argument('--local_rank', type=int, default=0)
|
43 |
+
args = parser.parse_args()
|
44 |
+
opt = option.parse(args.opt, is_train=True)
|
45 |
+
|
46 |
+
# distributed training settings
|
47 |
+
if args.launcher == 'none': # disabled distributed training
|
48 |
+
opt['dist'] = False
|
49 |
+
rank = -1
|
50 |
+
print('Disabled distributed training.')
|
51 |
+
else:
|
52 |
+
opt['dist'] = True
|
53 |
+
init_dist()
|
54 |
+
world_size = torch.distributed.get_world_size()
|
55 |
+
rank = torch.distributed.get_rank()
|
56 |
+
|
57 |
+
# loading resume state if exists
|
58 |
+
if opt['path'].get('resume_state', None):
|
59 |
+
# distributed resuming: all load into default GPU
|
60 |
+
device_id = torch.cuda.current_device()
|
61 |
+
resume_state = torch.load(opt['path']['resume_state'],
|
62 |
+
map_location=lambda storage, loc: storage.cuda(device_id))
|
63 |
+
# resume_state = torch.load(opt['path']['resume_state'],
|
64 |
+
# map_location=lambda storage, loc: storage.cuda(device_id), strict=False)
|
65 |
+
option.check_resume(opt, resume_state['iter']) # check resume options
|
66 |
+
else:
|
67 |
+
resume_state = None
|
68 |
+
|
69 |
+
# mkdir and loggers
|
70 |
+
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
71 |
+
if resume_state is None:
|
72 |
+
util.mkdir_and_rename(
|
73 |
+
opt['path']['experiments_root']) # rename experiment folder if exists
|
74 |
+
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
75 |
+
and 'pretrain_model' not in key and 'resume' not in key))
|
76 |
+
|
77 |
+
# config loggers. Before it, the log will not work
|
78 |
+
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
79 |
+
screen=True, tofile=True)
|
80 |
+
util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
|
81 |
+
screen=True, tofile=True)
|
82 |
+
logger = logging.getLogger('base')
|
83 |
+
logger.info(option.dict2str(opt))
|
84 |
+
# tensorboard logger
|
85 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
86 |
+
version = float(torch.__version__[0:3])
|
87 |
+
if version >= 1.1: # PyTorch 1.1
|
88 |
+
from torch.utils.tensorboard import SummaryWriter
|
89 |
+
else:
|
90 |
+
logger.info(
|
91 |
+
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
92 |
+
from tensorboardX import SummaryWriter
|
93 |
+
tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
|
94 |
+
else:
|
95 |
+
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
96 |
+
logger = logging.getLogger('base')
|
97 |
+
|
98 |
+
# convert to NoneDict, which returns None for missing keys
|
99 |
+
opt = option.dict_to_nonedict(opt)
|
100 |
+
|
101 |
+
# random seed
|
102 |
+
seed = opt['train']['manual_seed']
|
103 |
+
if seed is None:
|
104 |
+
seed = random.randint(1, 10000)
|
105 |
+
if rank <= 0:
|
106 |
+
logger.info('Random seed: {}'.format(seed))
|
107 |
+
util.set_random_seed(seed)
|
108 |
+
|
109 |
+
torch.backends.cudnn.benchmark = True
|
110 |
+
# torch.backends.cudnn.deterministic = True
|
111 |
+
|
112 |
+
#### create train and val dataloader
|
113 |
+
dataset_ratio = 200 # enlarge the size of each epoch
|
114 |
+
for phase, dataset_opt in opt['datasets'].items():
|
115 |
+
if phase == 'train':
|
116 |
+
train_set = create_dataset(dataset_opt)
|
117 |
+
train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
|
118 |
+
total_iters = int(opt['train']['niter'])
|
119 |
+
total_epochs = int(math.ceil(total_iters / train_size))
|
120 |
+
if opt['dist']:
|
121 |
+
train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
|
122 |
+
total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
123 |
+
else:
|
124 |
+
train_sampler = None
|
125 |
+
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
126 |
+
if rank <= 0:
|
127 |
+
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
128 |
+
len(train_set), train_size))
|
129 |
+
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
130 |
+
total_epochs, total_iters))
|
131 |
+
elif phase == 'val':
|
132 |
+
val_set = create_dataset(dataset_opt)
|
133 |
+
val_loader = create_dataloader(val_set, dataset_opt, opt, None)
|
134 |
+
if rank <= 0:
|
135 |
+
logger.info('Number of val images in [{:s}]: {:d}'.format(
|
136 |
+
dataset_opt['name'], len(val_set)))
|
137 |
+
else:
|
138 |
+
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
139 |
+
assert train_loader is not None
|
140 |
+
|
141 |
+
# create model
|
142 |
+
model = create_model(opt)
|
143 |
+
# resume training
|
144 |
+
if resume_state:
|
145 |
+
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
146 |
+
resume_state['epoch'], resume_state['iter']))
|
147 |
+
|
148 |
+
start_epoch = resume_state['epoch']
|
149 |
+
current_step = resume_state['iter']
|
150 |
+
model.resume_training(resume_state) # handle optimizers and schedulers
|
151 |
+
else:
|
152 |
+
current_step = 0
|
153 |
+
start_epoch = 0
|
154 |
+
|
155 |
+
# training
|
156 |
+
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
157 |
+
for epoch in range(start_epoch, total_epochs + 1):
|
158 |
+
if opt['dist']:
|
159 |
+
train_sampler.set_epoch(epoch)
|
160 |
+
for _, train_data in enumerate(train_loader):
|
161 |
+
current_step += 1
|
162 |
+
if current_step > total_iters:
|
163 |
+
break
|
164 |
+
# training
|
165 |
+
model.feed_data(train_data)
|
166 |
+
model.optimize_parameters(current_step)
|
167 |
+
|
168 |
+
# update learning rate
|
169 |
+
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
170 |
+
|
171 |
+
# log
|
172 |
+
if current_step % opt['logger']['print_freq'] == 0:
|
173 |
+
logs = model.get_current_log()
|
174 |
+
message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
|
175 |
+
epoch, current_step, model.get_current_learning_rate())
|
176 |
+
for k, v in logs.items():
|
177 |
+
message += '{:s}: {:.4e} '.format(k, v)
|
178 |
+
# tensorboard logger
|
179 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
180 |
+
if rank <= 0:
|
181 |
+
tb_logger.add_scalar(k, v, current_step)
|
182 |
+
if rank <= 0:
|
183 |
+
logger.info(message)
|
184 |
+
|
185 |
+
# validation
|
186 |
+
if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
|
187 |
+
avg_psnr = 0.0
|
188 |
+
avg_psnr_h = [0.0]*opt['num_image']
|
189 |
+
avg_psnr_lr = 0.0
|
190 |
+
avg_biterr = 0.0
|
191 |
+
idx = 0
|
192 |
+
for image_id, val_data in enumerate(val_loader):
|
193 |
+
img_dir = os.path.join(opt['path']['val_images'])
|
194 |
+
util.mkdir(img_dir)
|
195 |
+
|
196 |
+
model.feed_data(val_data)
|
197 |
+
model.test(image_id)
|
198 |
+
|
199 |
+
visuals = model.get_current_visuals()
|
200 |
+
|
201 |
+
t_step = visuals['SR'].shape[0]
|
202 |
+
idx += t_step
|
203 |
+
n = len(visuals['SR_h'])
|
204 |
+
|
205 |
+
avg_biterr += util.decoded_message_error_rate_batch(visuals['recmessage'][0], visuals['message'][0])
|
206 |
+
|
207 |
+
for i in range(t_step):
|
208 |
+
|
209 |
+
sr_img = util.tensor2img(visuals['SR'][i]) # uint8
|
210 |
+
sr_img_h = []
|
211 |
+
for j in range(n):
|
212 |
+
sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i])) # uint8
|
213 |
+
gt_img = util.tensor2img(visuals['GT'][i]) # uint8
|
214 |
+
lr_img = util.tensor2img(visuals['LR'][i])
|
215 |
+
lrgt_img = []
|
216 |
+
for j in range(n):
|
217 |
+
lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
|
218 |
+
|
219 |
+
# Save SR images for reference
|
220 |
+
save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'SR'))
|
221 |
+
util.save_img(sr_img, save_img_path)
|
222 |
+
|
223 |
+
for j in range(n):
|
224 |
+
save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'SR_h'))
|
225 |
+
util.save_img(sr_img_h[j], save_img_path)
|
226 |
+
|
227 |
+
save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
|
228 |
+
util.save_img(gt_img, save_img_path)
|
229 |
+
|
230 |
+
save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
|
231 |
+
util.save_img(lr_img, save_img_path)
|
232 |
+
|
233 |
+
for j in range(n):
|
234 |
+
save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'LRGT'))
|
235 |
+
util.save_img(lrgt_img[j], save_img_path)
|
236 |
+
|
237 |
+
psnr = cal_pnsr(sr_img, gt_img)
|
238 |
+
psnr_h = []
|
239 |
+
for j in range(n):
|
240 |
+
psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
|
241 |
+
psnr_lr = cal_pnsr(lr_img, gt_img)
|
242 |
+
|
243 |
+
avg_psnr += psnr
|
244 |
+
for j in range(n):
|
245 |
+
avg_psnr_h[j] += psnr_h[j]
|
246 |
+
avg_psnr_lr += psnr_lr
|
247 |
+
|
248 |
+
avg_psnr = avg_psnr / idx
|
249 |
+
avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
|
250 |
+
avg_psnr_lr = avg_psnr_lr / idx
|
251 |
+
avg_biterr = avg_biterr / idx
|
252 |
+
|
253 |
+
# log
|
254 |
+
res_psnr_h = ''
|
255 |
+
for p in avg_psnr_h:
|
256 |
+
res_psnr_h+=('_{:.4e}'.format(p))
|
257 |
+
|
258 |
+
logger.info('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))
|
259 |
+
logger_val = logging.getLogger('val') # validation logger
|
260 |
+
logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(
|
261 |
+
epoch, current_step, avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))
|
262 |
+
# tensorboard logger
|
263 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
264 |
+
tb_logger.add_scalar('psnr', avg_psnr, current_step)
|
265 |
+
|
266 |
+
# save models and training states
|
267 |
+
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
268 |
+
if rank <= 0:
|
269 |
+
logger.info('Saving models and training states.')
|
270 |
+
model.save(current_step)
|
271 |
+
model.save_training_state(epoch, current_step)
|
272 |
+
|
273 |
+
if rank <= 0:
|
274 |
+
logger.info('Saving the final model.')
|
275 |
+
model.save('latest')
|
276 |
+
logger.info('End of training.')
|
277 |
+
|
278 |
+
|
279 |
+
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|
train_bit.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import logging
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.multiprocessing as mp
|
10 |
+
from data.data_sampler import DistIterSampler
|
11 |
+
|
12 |
+
import options.options as option
|
13 |
+
from utils import util
|
14 |
+
from data import create_dataloader, create_dataset
|
15 |
+
from models import create_model
|
16 |
+
|
17 |
+
|
18 |
+
def init_dist(backend='nccl', **kwargs):
    """Initialize the distributed training environment.

    Forces the 'spawn' multiprocessing start method (required for CUDA),
    binds this process to a GPU chosen round-robin from its RANK, and
    joins the default process group.
    """
    start_method = mp.get_start_method(allow_none=True)
    if start_method != 'spawn':
        mp.set_start_method('spawn')
    proc_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    # Map each process onto a device: rank modulo the local GPU count.
    torch.cuda.set_device(proc_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
|
27 |
+
|
28 |
+
def cal_pnsr(sr_img, gt_img):
    """Return the PSNR between two images given in the 0-255 (uint8) range.

    Note: the (misspelled) function name is kept for caller compatibility.
    """
    # The original divided both images by 255. and then multiplied them back
    # by 255 before the PSNR call — a redundant round trip. Promote to float
    # directly instead (avoids uint8 wrap-around in the difference) and pass
    # the values straight through in the same 0-255 range.
    return util.calculate_psnr(sr_img * 1.0, gt_img * 1.0)
|
35 |
+
|
36 |
+
def main():
    """Entry point for bit-message training (train_bit).

    Parses the YAML option file, configures (optionally distributed)
    training, builds train/val dataloaders and the model, then runs the
    training loop with periodic console/TensorBoard logging, validation
    (stego PSNR and decoded-bit accuracy), and checkpointing.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')  # config file
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all ranks load onto their current default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    # mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            # NOTE(review): this prefix parse is fragile for multi-digit minor
            # versions (e.g. '1.10' -> 1.1); kept as-is for consistency with
            # the rest of the project.
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)
    # resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # update learning rate
            model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation: stego PSNR and decoded-bit error, rank 0 only
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                # FIX: dropped the unused `avg_psnr` / `avg_psnr_h` accumulators
                # that were copied over from train.py but never updated here.
                avg_psnr_lr = 0.0
                avg_biterr = 0.0
                idx = 0
                for image_id, val_data in enumerate(val_loader):
                    img_dir = os.path.join(opt['path']['val_images'])
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test(image_id)

                    visuals = model.get_current_visuals()

                    t_step = visuals['recmessage'].shape[0]
                    idx += t_step
                    # FIX: compute the batch bit-error rate once instead of
                    # twice (it was recomputed just to print it).
                    batch_biterr = util.decoded_message_error_rate_batch(
                        visuals['recmessage'][0], visuals['message'][0])
                    avg_biterr += batch_biterr
                    print(batch_biterr)

                    for i in range(t_step):
                        gt_img = util.tensor2img(visuals['GT'][i])  # uint8
                        lr_img = util.tensor2img(visuals['LR'][i])

                        # Save cover/stego images for visual inspection.
                        save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
                        util.save_img(gt_img, save_img_path)

                        save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
                        util.save_img(lr_img, save_img_path)
                        psnr_lr = cal_pnsr(lr_img, gt_img)
                        avg_psnr_lr += psnr_lr

                avg_psnr_lr = avg_psnr_lr / idx
                avg_biterr = avg_biterr / idx

                logger.info('# Validation # PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(avg_psnr_lr, avg_biterr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(
                    epoch, current_step, avg_psnr_lr, avg_biterr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    # FIX: log the metric that is actually accumulated here;
                    # the original logged `avg_psnr`, which stayed 0.0 in this
                    # script and always recorded a flat zero curve.
                    tb_logger.add_scalar('psnr', avg_psnr_lr, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
|
243 |
+
|
244 |
+
|
245 |
+
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|