Ricoooo committed on
Commit
8da8f47
1 Parent(s): 3a6de9b

Add local files to repository

Files changed (6)
  1. app.py +273 -0
  2. maskextract.py +46 -0
  3. test.py +188 -0
  4. test_gradio.py +92 -0
  5. train.py +280 -0
  6. train_bit.py +246 -0
app.py ADDED
@@ -0,0 +1,273 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image, ImageDraw
+ import requests
+ from copy import deepcopy
+ import cv2
+ from test_gradio import load_image, image_editing
+
+ import options.options as option
+ from utils.JPEG import DiffJPEG
+ from scipy.io.wavfile import read as wav_read
+ from scipy.io import wavfile
+
+ import os
+ import math
+ import argparse
+ import random
+ import logging
+
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from data.data_sampler import DistIterSampler
+
+ from utils import util
+ from data.util import read_img
+ from models import create_model as create_model_editguard
+ import base64
+
+ from scipy.ndimage import zoom
+
+ import matplotlib.pyplot as plt
+
+ def img_to_base64(filepath):
+     with open(filepath, "rb") as img_file:
+         return base64.b64encode(img_file.read()).decode()
+
+ logo_base64 = img_to_base64("../logo.png")
+
+ html_content = f"""
+ <div style='display: flex; align-items: center; justify-content: center; padding: 20px;'>
+     <img src='data:image/png;base64,{logo_base64}' alt='Logo' style='height: 50px; margin-right: 20px;'>
+     <strong><font size='8'>EditGuard</font></strong>
+ </div>
+ """
+
+ # Examples
+ examples = [
+     ["../dataset/examples/0011.png"],
+     ["../dataset/examples/0012.png"],
+     ["../dataset/examples/0003.png"],
+     ["../dataset/examples/0004.png"],
+     ["../dataset/examples/0005.png"],
+     ["../dataset/examples/0006.png"],
+     ["../dataset/examples/0007.png"],
+     ["../dataset/examples/0008.png"],
+     ["../dataset/examples/0009.png"],
+     ["../dataset/examples/0010.png"],
+     ["../dataset/examples/0002.png"],
+ ]
+
+ default_example = examples[0]
+
+ def hiding(image_input, bit_input, model):
+     # parse the bit string into a message in {-0.5, 0.5}
+     message = np.array([int(b) for b in bit_input])
+     message = message - 0.5
+     val_data = load_image(image_input, message)
+     model.feed_data(val_data)
+     container = model.image_hiding()
+     # returned twice: once for display, once as the tampering canvas
+     return container, container
+
+ def rand(num_bits=64):
+     # generate a random watermark as a string of num_bits 0/1 characters
+     random_str = ''.join([str(random.randint(0, 1)) for _ in range(num_bits)])
+     return random_str
+
+ def ImageEdit(img, prompt, model_index):
+     image, mask = img["image"], np.float32(img["mask"])
+     received_image = image_editing(image, mask, prompt)
+     return received_image, received_image, received_image
+
+
+ def image_model_select(ckp_index=0):
+     # options
+     opt = option.parse("options/test_editguard.yml", is_train=True)
+     # distributed training settings
+     opt['dist'] = False
+     rank = -1
+     print('Disabled distributed training.')
+
+     # load resume state if it exists
+     if opt['path'].get('resume_state', None):
+         # distributed resuming: all load into default GPU
+         device_id = torch.cuda.current_device()
+         resume_state = torch.load(opt['path']['resume_state'],
+                                   map_location=lambda storage, loc: storage.cuda(device_id))
+         option.check_resume(opt, resume_state['iter'])  # check resume options
+     else:
+         resume_state = None
+
+     # convert to NoneDict, which returns None for missing keys
+     opt = option.dict_to_nonedict(opt)
+     torch.backends.cudnn.benchmark = True
+
+     # create model
+     model = create_model_editguard(opt)
+
+     if ckp_index == 0:  # only one checkpoint is currently wired up
+         model_pth = '../checkpoints/clean.pth'
+         print(model_pth)
+     model.load_test(model_pth)
+     return model
+
+
+
+ def Gaussian_image_degradation(image, NL):
+     # add zero-mean Gaussian noise with std NL/255 to an HWC float image in [0, 1]
+     image = torch.from_numpy(np.transpose(image, (2, 0, 1)))
+     image = image.unsqueeze(0)
+     NL = NL / 255.0
+     noise = np.random.normal(0, NL, image.shape)
+     torchnoise = torch.from_numpy(noise).float()
+     y_forw = image + torchnoise
+     y_forw = torch.clamp(y_forw, 0, 1)
+     y_forw = y_forw.permute(0, 2, 3, 1)
+     y_forw = y_forw.cpu().detach().numpy().squeeze()
+
+     y_forw = (y_forw * 255.0).astype(np.uint8)
+     return y_forw, y_forw
+
+
+ def JPEG_image_degradation(image, NL):
+     # apply differentiable JPEG compression at quality factor NL
+     image = image.astype(np.float32)
+     image = torch.from_numpy(np.transpose(image, (2, 0, 1)))
+     image = image.unsqueeze(0)
+     JPEG = DiffJPEG(differentiable=True, quality=int(NL))
+     y_forw = JPEG(image)
+     y_forw = y_forw.permute(0, 2, 3, 1)
+     y_forw = y_forw.cpu().detach().numpy().squeeze()
+     y_forw = (y_forw * 255.0).astype(np.uint8)
+
+     return y_forw, y_forw
+
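Neither degradation helper is wired into the UI below; here is a minimal robustness-check sketch, assuming it runs inside this module with an HWC float image in [0, 1] (the array is a synthetic stand-in):

    container_img = np.random.rand(512, 512, 3)                       # stand-in watermarked image
    noisy, _ = Gaussian_image_degradation(container_img, NL=10)       # Gaussian noise, sigma = 10/255
    compressed, _ = JPEG_image_degradation(container_img, NL=70)      # JPEG at quality 70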
+ def revealing(image_edited, input_bit, model_list, model):
+     # the recovery threshold is currently the same for every model choice
+     number = 0.2
+
+     container_data = load_image(image_edited)  # load the tampered image
+     model.feed_data(container_data)
+     mask, remesg = model.image_recovery(number)
+     mask = Image.fromarray(mask.astype(np.uint8))
+     remesg = remesg.cpu().numpy()[0]
+     remesg = ''.join([str(int(x)) for x in remesg])
+     bit_acc = calculate_similarity_percentage(input_bit, remesg)
+     return mask, remesg, bit_acc
+
+
+ def calculate_similarity_percentage(str1, str2):
+     if len(str1) == 0:
+         return "Original copyright watermark unknown"
+     elif len(str1) != len(str2):
+         return "Input and output watermarks differ in length"
+     total_length = len(str1)
+     same_count = sum(1 for x, y in zip(str1, str2) if x == y)
+     similarity_percentage = (same_count / total_length) * 100
+     return f"{similarity_percentage}%"
+
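A quick sanity check of the accuracy metric above (the strings are illustrative):

    print(calculate_similarity_percentage("1010", "1000"))  # "75.0%": 3 of 4 bits match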
+ # Description
+ title = "<center><strong><font size='8'>EditGuard</font></strong></center>"
+
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
+
+ with gr.Blocks(css=css, title="EditGuard") as demo:
+     gr.HTML(html_content)
+     model = gr.State(value=None)
+     save_h = gr.State(value=None)
+     save_w = gr.State(value=None)
+     sam_global_points = gr.State([])
+     sam_global_point_label = gr.State([])
+     sam_original_image = gr.State(value=None)
+     sam_mask = gr.State(value=None)
+
+     with gr.Tabs():
+         with gr.TabItem('Multi-purpose forensic watermarking'):
+
+             DESCRIPTION = """
+             ## Usage:
+             - Upload an image and a copyright watermark (a 64-bit sequence), then click "Embed watermark" to generate the watermarked image.
+             - Paint over the region to edit, and edit the image with an inpainting model.
+             - Click "Extract" to detect the tampered region and output the copyright watermark."""
+
+             gr.Markdown(DESCRIPTION)
+             save_inpainted_image = gr.State(value=None)
+             with gr.Column():
+                 with gr.Row():
+                     model_list = gr.Dropdown(label="Select model", choices=["Model 1"], type='index')
+                     clear_button = gr.Button("Clear all")
+                 with gr.Box():
+                     gr.Markdown("# 1. Embed watermark")
+                     with gr.Row():
+                         with gr.Column():
+                             image_input = gr.Image(source='upload', label="Original image", interactive=True, type="numpy", value=default_example[0])
+                             with gr.Row():
+                                 bit_input = gr.Textbox(label="Copyright watermark (64-bit sequence)", placeholder="Type here...")
+                                 rand_bit = gr.Button("🎲 Random copyright watermark")
+                             hiding_button = gr.Button("Embed watermark")
+                         with gr.Column():
+                             image_watermark = gr.Image(source="upload", label="Watermarked image", interactive=True, type="numpy")
+
+                 with gr.Box():
+                     gr.Markdown("# 2. Tamper with the image")
+                     with gr.Row():
+                         with gr.Column():
+                             image_edit = gr.Image(source='upload', tool="sketch", label="Select the region to tamper", interactive=True, type="numpy")
+                             inpainting_model_list = gr.Dropdown(label="Select tampering model", choices=["Model 1: SD_inpainting"], type='index')
+                             text_prompt = gr.Textbox(label="Tampering prompt")
+                             inpainting_button = gr.Button("Tamper with image")
+                         with gr.Column():
+                             image_edited = gr.Image(source="upload", label="Tampered result", interactive=True, type="numpy")
+
+                 with gr.Box():
+                     gr.Markdown("# 3. Extract watermark & tampered region")
+                     with gr.Row():
+                         with gr.Column():
+                             image_edited_1 = gr.Image(source="upload", label="Image to extract from", interactive=True, type="numpy")
+
+                             revealing_button = gr.Button("Extract")
+                         with gr.Column():
+                             edit_mask = gr.Image(source='upload', label="Predicted mask of the edited region", interactive=True, type="numpy")
+                             bit_output = gr.Textbox(label="Predicted copyright watermark")
+                             acc_output = gr.Textbox(label="Watermark prediction accuracy")
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=[image_input],
+             )
+
+     model_list.change(
+         image_model_select, inputs=[model_list], outputs=[model]
+     )
+     hiding_button.click(
+         hiding, inputs=[image_input, bit_input, model], outputs=[image_watermark, image_edit]
+     )
+     rand_bit.click(
+         rand, inputs=[], outputs=[bit_input]
+     )
+
+     inpainting_button.click(
+         ImageEdit, inputs=[image_edit, text_prompt, inpainting_model_list], outputs=[image_edited, image_edited_1, save_inpainted_image]
+     )
+
+     revealing_button.click(
+         revealing, inputs=[image_edited_1, bit_input, model_list, model], outputs=[edit_mask, bit_output, acc_output]
+     )
+
+ demo.launch(server_name="0.0.0.0", server_port=2004, share=True, favicon_path='../logo.png')
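A minimal non-UI round trip through the helpers above, as a sketch; it assumes the relative paths used in this file (options/test_editguard.yml, ../checkpoints/clean.pth, ../dataset/examples/) resolve from the working directory and that a GPU is available:

    model = image_model_select(0)                                  # EditGuard with clean.pth
    img = cv2.cvtColor(cv2.imread("../dataset/examples/0011.png"), cv2.COLOR_BGR2RGB)
    bits = rand(64)                                                # random 64-bit watermark
    container, _ = hiding(img, bits, model)                        # embed
    mask, recovered, acc = revealing(container, bits, 0, model)    # predicted mask + watermark
    print(recovered, acc)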
maskextract.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ from PIL import Image
+ import numpy as np
+ import argparse
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--threshold', default=0.2, type=float, help='Residual threshold for binarizing the tamper mask.')
+     args = parser.parse_args()
+
+     input_folder = 'results/test_age-set'
+     output_folder = 'results/mask'
+
+     for filename in os.listdir(input_folder):
+         if filename.endswith('_0_0_LRGT.png'):
+             digits = filename.split('_')[0]
+             if digits.isdigit():
+                 digits = int(digits)
+
+                 if digits >= 0 and digits <= 1000:
+                     input_path_LRGT = os.path.join(input_folder, filename)
+                     input_path_SR_h = os.path.join(input_folder, filename).replace('LRGT', 'SR_h')
+
+                     image_LRGT = Image.open(input_path_LRGT).convert("RGB")
+                     image_SR_h = Image.open(input_path_SR_h).convert("RGB")
+
+                     w, h = image_SR_h.size
+                     image_LRGT = image_LRGT.resize((w, h))
+
+                     array_LRGT = np.array(image_LRGT) / 255.
+                     array_SR_h = np.array(image_SR_h) / 255.
+
+                     residual = np.abs(array_LRGT - array_SR_h)
+
+                     threshold = args.threshold
+                     mask = np.where(residual > threshold, 1, 0)
+
+                     os.makedirs(output_folder, exist_ok=True)
+
+                     output_path = os.path.join(output_folder, str(digits + 1).zfill(4) + '.png')
+
+                     # flag a pixel if any channel exceeds the threshold
+                     # (summing channels here would overflow uint8 after the *255 scaling)
+                     mask = np.any(mask, axis=2)
+
+                     mask_image = Image.fromarray((mask * 255).astype(np.uint8))
+                     mask_image.save(output_path)
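The script is meant to run after test.py has written the LRGT/SR_h pairs under results/, e.g. `python maskextract.py --threshold 0.2`. The core rule it applies, as a standalone sketch on two hypothetical aligned arrays in [0, 1]:

    import numpy as np
    a = np.zeros((8, 8, 3))
    b = a.copy()
    b[2:5, 2:5] = 0.6                               # a tampered patch
    mask = np.any(np.abs(a - b) > 0.2, axis=2)      # True where any channel's residual exceeds the threshold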
test.py ADDED
@@ -0,0 +1,188 @@
+ import os
+ import math
+ import argparse
+ import random
+ import logging
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from data.data_sampler import DistIterSampler
+
+ import options.options as option
+ from utils import util
+ from data import create_dataloader, create_dataset
+ from models import create_model
+ import numpy as np
+
+
+ def init_dist(backend='nccl', **kwargs):
+     '''Initialization for distributed training.'''
+     if mp.get_start_method(allow_none=True) != 'spawn':
+         mp.set_start_method('spawn')
+     rank = int(os.environ['RANK'])
+     num_gpus = torch.cuda.device_count()
+     torch.cuda.set_device(rank % num_gpus)
+     dist.init_process_group(backend=backend, **kwargs)
+
+ def cal_psnr(sr_img, gt_img):
+     # calculate PSNR on the 0-255 range
+     gt_img = gt_img / 255.
+     sr_img = sr_img / 255.
+
+     psnr = util.calculate_psnr(sr_img * 255, gt_img * 255)
+
+     return psnr
+
+ def get_min_avg_and_indices(nums):
+     # Get the indices of the 900 smallest elements (assumes len(nums) >= 900)
+     indices = sorted(range(len(nums)), key=lambda i: nums[i])[:900]
+
+     # Calculate the average of these elements
+     avg = sum(nums[i] for i in indices) / 900
+
+     # Write the indices to a txt file
+     with open("indices.txt", "w") as file:
+         for index in indices:
+             file.write(str(index) + "\n")
+
+     return avg
+
+
+ def main():
+     # options
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                         help='job launcher')
+     parser.add_argument('--ckpt', type=str, default='/userhome/NewIBSN/EditGuard_open/checkpoints/clean.pth', help='Path to pre-trained model.')
+     parser.add_argument('--local_rank', type=int, default=0)
+     args = parser.parse_args()
+     opt = option.parse(args.opt, is_train=True)
+
+     # distributed training settings
+     if args.launcher == 'none':  # disabled distributed training
+         opt['dist'] = False
+         rank = -1
+         print('Disabled distributed training.')
+     else:
+         opt['dist'] = True
+         init_dist()
+         world_size = torch.distributed.get_world_size()
+         rank = torch.distributed.get_rank()
+
+     # load resume state if it exists
+     if opt['path'].get('resume_state', None):
+         # distributed resuming: all load into default GPU
+         device_id = torch.cuda.current_device()
+         resume_state = torch.load(opt['path']['resume_state'],
+                                   map_location=lambda storage, loc: storage.cuda(device_id))
+         option.check_resume(opt, resume_state['iter'])  # check resume options
+     else:
+         resume_state = None
+
+     # convert to NoneDict, which returns None for missing keys
+     opt = option.dict_to_nonedict(opt)
+
+     torch.backends.cudnn.benchmark = True
+     # torch.backends.cudnn.deterministic = True
+
+     #### create train and val dataloader
+     dataset_ratio = 200  # enlarge the size of each epoch
+     for phase, dataset_opt in opt['datasets'].items():
+         print("phase", phase)
+         if phase == 'TD':
+             val_set = create_dataset(dataset_opt)
+             val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+         elif phase == 'val':
+             val_set = create_dataset(dataset_opt)
+             val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+         else:
+             raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
+
+     # create model
+     model = create_model(opt)
+     model.load_test(args.ckpt)
+
+     # validation
+     avg_psnr = 0.0
+     avg_psnr_h = [0.0] * opt['num_image']
+     avg_psnr_lr = 0.0
+     biterr = []
+     idx = 0
+     for image_id, val_data in enumerate(val_loader):
+         img_dir = os.path.join('results', opt['name'])
+         util.mkdir(img_dir)
+
+         model.feed_data(val_data)
+         model.test(image_id)
+
+         visuals = model.get_current_visuals()
+
+         t_step = visuals['SR'].shape[0]
+         idx += t_step
+         n = len(visuals['SR_h'])
+
+         a = visuals['recmessage'][0]
+         b = visuals['message'][0]
+
+         bitrecord = util.decoded_message_error_rate_batch(a, b)
+         print(bitrecord)
+         biterr.append(bitrecord)
+
+         for i in range(t_step):
+             sr_img = util.tensor2img(visuals['SR'][i])  # uint8
+             sr_img_h = []
+             for j in range(n):
+                 sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i]))  # uint8
+             gt_img = util.tensor2img(visuals['GT'][i])  # uint8
+             lr_img = util.tensor2img(visuals['LR'][i])
+             lrgt_img = []
+             for j in range(n):
+                 lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
+
+             # Save SR images for reference
+             save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'SR'))
+             util.save_img(sr_img, save_img_path)
+
+             for j in range(n):
+                 save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'SR_h'))
+                 util.save_img(sr_img_h[j], save_img_path)
+
+             save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
+             util.save_img(gt_img, save_img_path)
+
+             save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
+             util.save_img(lr_img, save_img_path)
+
+             for j in range(n):
+                 save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'LRGT'))
+                 util.save_img(lrgt_img[j], save_img_path)
+
+             psnr = cal_psnr(sr_img, gt_img)
+             psnr_h = []
+             for j in range(n):
+                 psnr_h.append(cal_psnr(sr_img_h[j], lrgt_img[j]))
+             psnr_lr = cal_psnr(lr_img, gt_img)
+
+             avg_psnr += psnr
+             for j in range(n):
+                 avg_psnr_h[j] += psnr_h[j]
+             avg_psnr_lr += psnr_lr
+
+     avg_psnr = avg_psnr / idx
+     avg_biterr = sum(biterr) / len(biterr)
+     print(get_min_avg_and_indices(biterr))
+
+     avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
+     avg_psnr_lr = avg_psnr_lr / idx
+     res_psnr_h = ''
+     for p in avg_psnr_h:
+         res_psnr_h += ('_{:.4e}'.format(p))
+     print('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_Error: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))
+
+
+ if __name__ == '__main__':
+     main()
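A hypothetical invocation (the option-file name is an assumption; --ckpt defaults to the absolute path baked in above):

    # python test.py -opt options/test_editguard.yml --ckpt ../checkpoints/clean.pth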
test_gradio.py ADDED
@@ -0,0 +1,92 @@
+ import sys
+
+ import os
+ import math
+ import argparse
+ import random
+ import logging
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from data.data_sampler import DistIterSampler
+
+ import options.options as option
+ from utils import util
+ from data.util import read_img
+ from data import create_dataloader, create_dataset
+ from models import create_model
+ import numpy as np
+ from PIL import Image
+ from diffusers import StableDiffusionInpaintPipeline
+
+
+ def init_dist(backend='nccl', **kwargs):
+     '''Initialization for distributed training.'''
+     if mp.get_start_method(allow_none=True) != 'spawn':
+         mp.set_start_method('spawn')
+     rank = int(os.environ['RANK'])
+     num_gpus = torch.cuda.device_count()
+     torch.cuda.set_device(rank % num_gpus)
+     dist.init_process_group(backend=backend, **kwargs)
+
+
+ def load_image(image, message=None):
+     # normalize to [0, 1] and swap RGB -> BGR to match the training data layout
+     img_GT = image / 255
+     img_GT = img_GT[:, :, [2, 1, 0]]
+     img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
+     img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest')
+     img_GT = img_GT.unsqueeze(0)
+
+     _, T, C, W, H = img_GT.shape
+     list_h = []
+     # a solid blue image is used as the hidden (secret) image
+     R = 0
+     G = 0
+     B = 255
+     image = Image.new('RGB', (W, H), (R, G, B))
+     result = np.array(image) / 255.
+     expanded_matrix = np.expand_dims(result, axis=0)
+     expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
+     imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
+     imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)
+     imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest')
+     imgs_LQ = imgs_LQ.unsqueeze(0)
+
+     list_h.append(imgs_LQ)
+
+     list_h = torch.stack(list_h, dim=0)
+
+     return {
+         'LQ': list_h,
+         'GT': img_GT,
+         'MES': message
+     }
+
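A shape sanity check for load_image on a hypothetical input (a sketch; message may be left as None):

    batch = load_image((np.random.rand(256, 256, 3) * 255).astype(np.uint8))
    print(batch['GT'].shape)  # torch.Size([1, 1, 3, 512, 512])
    print(batch['LQ'].shape)  # torch.Size([1, 1, 1, 3, 512, 512])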
+ def image_editing(image_numpy, mask_image, prompt):
+     # the pipeline is reloaded on every call; caching it would be faster
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-inpainting",
+         torch_dtype=torch.float16,
+     ).to("cuda")
+
+     pil_image = Image.fromarray(image_numpy)
+     print(mask_image.shape)
+     print("maskmin", mask_image.min(), "maskmax", mask_image.max())
+     mask_image = Image.fromarray(mask_image.astype(np.uint8)).convert("L")
+     image_init = pil_image.convert("RGB").resize((512, 512))
+
+     w, h = mask_image.size  # PIL size is (width, height); assumed 512x512 like image_init
+
+     image_inpaint = pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=h, width=w).images[0]
+     image_inpaint = np.array(image_inpaint) / 255.
+     image = np.array(image_init) / 255.
+     # binarize the mask (truncating after /255. would zero out any value below 255)
+     mask_image = np.array(mask_image)
+     mask_image = (np.stack([mask_image] * 3, axis=-1) > 127).astype(np.float32)
+     image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
+
+     return image_fuse
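A minimal sketch of calling image_editing directly; it assumes a CUDA GPU and that the stabilityai/stable-diffusion-2-inpainting weights can be fetched, and the image and mask are synthetic stand-ins:

    img = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)    # stand-in RGB image
    mask = np.zeros((512, 512), dtype=np.float32)
    mask[128:256, 128:256] = 255.0                                # region to inpaint
    edited = image_editing(img, mask, "a red apple on a table")   # float array in [0, 1]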
train.py ADDED
@@ -0,0 +1,280 @@
+ import os
+ import math
+ import argparse
+ import random
+ import logging
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from data.data_sampler import DistIterSampler
+
+ import options.options as option
+ from utils import util
+ from data import create_dataloader, create_dataset
+ from models import create_model
+
+
+ def init_dist(backend='nccl', **kwargs):
+     '''Initialization for distributed training.'''
+     if mp.get_start_method(allow_none=True) != 'spawn':
+         mp.set_start_method('spawn')
+     rank = int(os.environ['RANK'])
+     num_gpus = torch.cuda.device_count()
+     torch.cuda.set_device(rank % num_gpus)
+     dist.init_process_group(backend=backend, **kwargs)
+
+ def cal_psnr(sr_img, gt_img):
+     # calculate PSNR on the 0-255 range
+     gt_img = gt_img / 255.
+     sr_img = sr_img / 255.
+     psnr = util.calculate_psnr(sr_img * 255, gt_img * 255)
+
+     return psnr
+
+ def main():
+     # options
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-opt', type=str, help='Path to option YAML file.')  # config file
+     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                         help='job launcher')
+     parser.add_argument('--local_rank', type=int, default=0)
+     args = parser.parse_args()
+     opt = option.parse(args.opt, is_train=True)
+
+     # distributed training settings
+     if args.launcher == 'none':  # disabled distributed training
+         opt['dist'] = False
+         rank = -1
+         print('Disabled distributed training.')
+     else:
+         opt['dist'] = True
+         init_dist()
+         world_size = torch.distributed.get_world_size()
+         rank = torch.distributed.get_rank()
+
+     # load resume state if it exists
+     if opt['path'].get('resume_state', None):
+         # distributed resuming: all load into default GPU
+         device_id = torch.cuda.current_device()
+         resume_state = torch.load(opt['path']['resume_state'],
+                                   map_location=lambda storage, loc: storage.cuda(device_id))
+         option.check_resume(opt, resume_state['iter'])  # check resume options
+     else:
+         resume_state = None
+
+     # mkdir and loggers
+     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
+         if resume_state is None:
+             util.mkdir_and_rename(
+                 opt['path']['experiments_root'])  # rename experiment folder if exists
+             util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+                          and 'pretrain_model' not in key and 'resume' not in key))
+
+         # config loggers. Before this, logging does not work
+         util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
+                           screen=True, tofile=True)
+         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
+                           screen=True, tofile=True)
+         logger = logging.getLogger('base')
+         logger.info(option.dict2str(opt))
+         # tensorboard logger
+         if opt['use_tb_logger'] and 'debug' not in opt['name']:
+             version = float(torch.__version__[0:3])
+             if version >= 1.1:  # PyTorch 1.1
+                 from torch.utils.tensorboard import SummaryWriter
+             else:
+                 logger.info(
+                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
+                 from tensorboardX import SummaryWriter
+             tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
+     else:
+         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
+         logger = logging.getLogger('base')
+
+     # convert to NoneDict, which returns None for missing keys
+     opt = option.dict_to_nonedict(opt)
+
+     # random seed
+     seed = opt['train']['manual_seed']
+     if seed is None:
+         seed = random.randint(1, 10000)
+     if rank <= 0:
+         logger.info('Random seed: {}'.format(seed))
+     util.set_random_seed(seed)
+
+     torch.backends.cudnn.benchmark = True
+     # torch.backends.cudnn.deterministic = True
+
+     #### create train and val dataloader
+     dataset_ratio = 200  # enlarge the size of each epoch
+     for phase, dataset_opt in opt['datasets'].items():
+         if phase == 'train':
+             train_set = create_dataset(dataset_opt)
+             train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
+             total_iters = int(opt['train']['niter'])
+             total_epochs = int(math.ceil(total_iters / train_size))
+             if opt['dist']:
+                 train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
+                 total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
+             else:
+                 train_sampler = None
+             train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
+             if rank <= 0:
+                 logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
+                     len(train_set), train_size))
+                 logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
+                     total_epochs, total_iters))
+         elif phase == 'val':
+             val_set = create_dataset(dataset_opt)
+             val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+             if rank <= 0:
+                 logger.info('Number of val images in [{:s}]: {:d}'.format(
+                     dataset_opt['name'], len(val_set)))
+         else:
+             raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
+     assert train_loader is not None
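A worked sizing example for the epoch arithmetic above (the numbers are illustrative): 5,000 training images at batch size 4 give train_size = ceil(5000 / 4) = 1250 iterations per epoch; with niter = 500,000 that is ceil(500000 / 1250) = 400 plain epochs, but only ceil(500000 / (1250 * 200)) = 2 epochs once DistIterSampler enlarges each epoch by dataset_ratio = 200.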
+     # create model
+     model = create_model(opt)
+     # resume training
+     if resume_state:
+         logger.info('Resuming training from epoch: {}, iter: {}.'.format(
+             resume_state['epoch'], resume_state['iter']))
+
+         start_epoch = resume_state['epoch']
+         current_step = resume_state['iter']
+         model.resume_training(resume_state)  # handle optimizers and schedulers
+     else:
+         current_step = 0
+         start_epoch = 0
+
+     # training
+     logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+     for epoch in range(start_epoch, total_epochs + 1):
+         if opt['dist']:
+             train_sampler.set_epoch(epoch)
+         for _, train_data in enumerate(train_loader):
+             current_step += 1
+             if current_step > total_iters:
+                 break
+             # training
+             model.feed_data(train_data)
+             model.optimize_parameters(current_step)
+
+             # update learning rate
+             model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+             # log
+             if current_step % opt['logger']['print_freq'] == 0:
+                 logs = model.get_current_log()
+                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
+                     epoch, current_step, model.get_current_learning_rate())
+                 for k, v in logs.items():
+                     message += '{:s}: {:.4e} '.format(k, v)
+                     # tensorboard logger
+                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                         if rank <= 0:
+                             tb_logger.add_scalar(k, v, current_step)
+                 if rank <= 0:
+                     logger.info(message)
+
+             # validation
+             if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
+                 avg_psnr = 0.0
+                 avg_psnr_h = [0.0] * opt['num_image']
+                 avg_psnr_lr = 0.0
+                 avg_biterr = 0.0
+                 idx = 0
+                 for image_id, val_data in enumerate(val_loader):
+                     img_dir = os.path.join(opt['path']['val_images'])
+                     util.mkdir(img_dir)
+
+                     model.feed_data(val_data)
+                     model.test(image_id)
+
+                     visuals = model.get_current_visuals()
+
+                     t_step = visuals['SR'].shape[0]
+                     idx += t_step
+                     n = len(visuals['SR_h'])
+
+                     avg_biterr += util.decoded_message_error_rate_batch(visuals['recmessage'][0], visuals['message'][0])
+
+                     for i in range(t_step):
+                         sr_img = util.tensor2img(visuals['SR'][i])  # uint8
+                         sr_img_h = []
+                         for j in range(n):
+                             sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i]))  # uint8
+                         gt_img = util.tensor2img(visuals['GT'][i])  # uint8
+                         lr_img = util.tensor2img(visuals['LR'][i])
+                         lrgt_img = []
+                         for j in range(n):
+                             lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
+
+                         # Save SR images for reference
+                         save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'SR'))
+                         util.save_img(sr_img, save_img_path)
+
+                         for j in range(n):
+                             save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'SR_h'))
+                             util.save_img(sr_img_h[j], save_img_path)
+
+                         save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
+                         util.save_img(gt_img, save_img_path)
+
+                         save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
+                         util.save_img(lr_img, save_img_path)
+
+                         for j in range(n):
+                             save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'LRGT'))
+                             util.save_img(lrgt_img[j], save_img_path)
+
+                         psnr = cal_psnr(sr_img, gt_img)
+                         psnr_h = []
+                         for j in range(n):
+                             psnr_h.append(cal_psnr(sr_img_h[j], lrgt_img[j]))
+                         psnr_lr = cal_psnr(lr_img, gt_img)
+
+                         avg_psnr += psnr
+                         for j in range(n):
+                             avg_psnr_h[j] += psnr_h[j]
+                         avg_psnr_lr += psnr_lr
+
+                 avg_psnr = avg_psnr / idx
+                 avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
+                 avg_psnr_lr = avg_psnr_lr / idx
+                 avg_biterr = avg_biterr / idx
+
+                 # log
+                 res_psnr_h = ''
+                 for p in avg_psnr_h:
+                     res_psnr_h += ('_{:.4e}'.format(p))
+
+                 logger.info('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))
+                 logger_val = logging.getLogger('val')  # validation logger
+                 logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(
+                     epoch, current_step, avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))
+                 # tensorboard logger
+                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                     tb_logger.add_scalar('psnr', avg_psnr, current_step)
+
+             # save models and training states
+             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                 if rank <= 0:
+                     logger.info('Saving models and training states.')
+                     model.save(current_step)
+                     model.save_training_state(epoch, current_step)
+
+     if rank <= 0:
+         logger.info('Saving the final model.')
+         model.save('latest')
+         logger.info('End of training.')
+
+
+ if __name__ == '__main__':
+     main()
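Hypothetical launch commands (the option-file name is an assumption):

    # python train.py -opt options/train_editguard.yml
    # python -m torch.distributed.launch --nproc_per_node=4 train.py -opt options/train_editguard.yml --launcher pytorch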
train_bit.py ADDED
@@ -0,0 +1,246 @@
+ import os
+ import math
+ import argparse
+ import random
+ import logging
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from data.data_sampler import DistIterSampler
+
+ import options.options as option
+ from utils import util
+ from data import create_dataloader, create_dataset
+ from models import create_model
+
+
+ def init_dist(backend='nccl', **kwargs):
+     '''Initialization for distributed training.'''
+     if mp.get_start_method(allow_none=True) != 'spawn':
+         mp.set_start_method('spawn')
+     rank = int(os.environ['RANK'])
+     num_gpus = torch.cuda.device_count()
+     torch.cuda.set_device(rank % num_gpus)
+     dist.init_process_group(backend=backend, **kwargs)
+
+ def cal_psnr(sr_img, gt_img):
+     # calculate PSNR on the 0-255 range
+     gt_img = gt_img / 255.
+     sr_img = sr_img / 255.
+     psnr = util.calculate_psnr(sr_img * 255, gt_img * 255)
+
+     return psnr
+
+ def main():
+     # options
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-opt', type=str, help='Path to option YAML file.')  # config file
+     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
+                         help='job launcher')
+     parser.add_argument('--local_rank', type=int, default=0)
+     args = parser.parse_args()
+     opt = option.parse(args.opt, is_train=True)
+
+     # distributed training settings
+     if args.launcher == 'none':  # disabled distributed training
+         opt['dist'] = False
+         rank = -1
+         print('Disabled distributed training.')
+     else:
+         opt['dist'] = True
+         init_dist()
+         world_size = torch.distributed.get_world_size()
+         rank = torch.distributed.get_rank()
+
+     # load resume state if it exists
+     if opt['path'].get('resume_state', None):
+         # distributed resuming: all load into default GPU
+         device_id = torch.cuda.current_device()
+         resume_state = torch.load(opt['path']['resume_state'],
+                                   map_location=lambda storage, loc: storage.cuda(device_id))
+         option.check_resume(opt, resume_state['iter'])  # check resume options
+     else:
+         resume_state = None
+
+     # mkdir and loggers
+     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
+         if resume_state is None:
+             util.mkdir_and_rename(
+                 opt['path']['experiments_root'])  # rename experiment folder if exists
+             util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+                          and 'pretrain_model' not in key and 'resume' not in key))
+
+         # config loggers. Before this, logging does not work
+         util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
+                           screen=True, tofile=True)
+         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
+                           screen=True, tofile=True)
+         logger = logging.getLogger('base')
+         logger.info(option.dict2str(opt))
+         # tensorboard logger
+         if opt['use_tb_logger'] and 'debug' not in opt['name']:
+             version = float(torch.__version__[0:3])
+             if version >= 1.1:  # PyTorch 1.1
+                 from torch.utils.tensorboard import SummaryWriter
+             else:
+                 logger.info(
+                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
+                 from tensorboardX import SummaryWriter
+             tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
+     else:
+         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
+         logger = logging.getLogger('base')
+
+     # convert to NoneDict, which returns None for missing keys
+     opt = option.dict_to_nonedict(opt)
+
+     # random seed
+     seed = opt['train']['manual_seed']
+     if seed is None:
+         seed = random.randint(1, 10000)
+     if rank <= 0:
+         logger.info('Random seed: {}'.format(seed))
+     util.set_random_seed(seed)
+
+     torch.backends.cudnn.benchmark = True
+     # torch.backends.cudnn.deterministic = True
+
+     #### create train and val dataloader
+     dataset_ratio = 200  # enlarge the size of each epoch
+     for phase, dataset_opt in opt['datasets'].items():
+         if phase == 'train':
+             train_set = create_dataset(dataset_opt)
+             train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
+             total_iters = int(opt['train']['niter'])
+             total_epochs = int(math.ceil(total_iters / train_size))
+             if opt['dist']:
+                 train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
+                 total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
+             else:
+                 train_sampler = None
+             train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
+             if rank <= 0:
+                 logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
+                     len(train_set), train_size))
+                 logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
+                     total_epochs, total_iters))
+         elif phase == 'val':
+             val_set = create_dataset(dataset_opt)
+             val_loader = create_dataloader(val_set, dataset_opt, opt, None)
+             if rank <= 0:
+                 logger.info('Number of val images in [{:s}]: {:d}'.format(
+                     dataset_opt['name'], len(val_set)))
+         else:
+             raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
+     assert train_loader is not None
+
+     # create model
+     model = create_model(opt)
+     # resume training
+     if resume_state:
+         logger.info('Resuming training from epoch: {}, iter: {}.'.format(
+             resume_state['epoch'], resume_state['iter']))
+
+         start_epoch = resume_state['epoch']
+         current_step = resume_state['iter']
+         model.resume_training(resume_state)  # handle optimizers and schedulers
+     else:
+         current_step = 0
+         start_epoch = 0
+
+     # training
+     logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
+     for epoch in range(start_epoch, total_epochs + 1):
+         if opt['dist']:
+             train_sampler.set_epoch(epoch)
+         for _, train_data in enumerate(train_loader):
+             current_step += 1
+             if current_step > total_iters:
+                 break
+             # training
+             model.feed_data(train_data)
+             model.optimize_parameters(current_step)
+
+             # update learning rate
+             model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
+
+             # log
+             if current_step % opt['logger']['print_freq'] == 0:
+                 logs = model.get_current_log()
+                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
+                     epoch, current_step, model.get_current_learning_rate())
+                 for k, v in logs.items():
+                     message += '{:s}: {:.4e} '.format(k, v)
+                     # tensorboard logger
+                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                         if rank <= 0:
+                             tb_logger.add_scalar(k, v, current_step)
+                 if rank <= 0:
+                     logger.info(message)
+
+             # validation: only stego PSNR and bit accuracy are tracked here
+             if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
+                 avg_psnr_lr = 0.0
+                 avg_biterr = 0.0
+                 idx = 0
+                 for image_id, val_data in enumerate(val_loader):
+                     img_dir = os.path.join(opt['path']['val_images'])
+                     util.mkdir(img_dir)
+
+                     model.feed_data(val_data)
+                     model.test(image_id)
+
+                     visuals = model.get_current_visuals()
+
+                     t_step = visuals['recmessage'].shape[0]
+                     idx += t_step
+
+                     avg_biterr += util.decoded_message_error_rate_batch(visuals['recmessage'][0], visuals['message'][0])
+                     print(util.decoded_message_error_rate_batch(visuals['recmessage'][0], visuals['message'][0]))
+
+                     for i in range(t_step):
+                         gt_img = util.tensor2img(visuals['GT'][i])  # uint8
+                         lr_img = util.tensor2img(visuals['LR'][i])
+
+                         save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
+                         util.save_img(gt_img, save_img_path)
+
+                         save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
+                         util.save_img(lr_img, save_img_path)
+                         psnr_lr = cal_psnr(lr_img, gt_img)
+                         avg_psnr_lr += psnr_lr
+
+                 avg_psnr_lr = avg_psnr_lr / idx
+                 avg_biterr = avg_biterr / idx
+
+                 logger.info('# Validation # PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(avg_psnr_lr, avg_biterr))
+                 logger_val = logging.getLogger('val')  # validation logger
+                 logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Stego: {:.4e}, Bit_acc: {: .4e}'.format(
+                     epoch, current_step, avg_psnr_lr, avg_biterr))
+                 # tensorboard logger
+                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                     tb_logger.add_scalar('psnr', avg_psnr_lr, current_step)  # was avg_psnr, which is never accumulated here
+
+             # save models and training states
+             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                 if rank <= 0:
+                     logger.info('Saving models and training states.')
+                     model.save(current_step)
+                     model.save_training_state(epoch, current_step)
+
+     if rank <= 0:
+         logger.info('Saving the final model.')
+         model.save('latest')
+         logger.info('End of training.')
+
+
+ if __name__ == '__main__':
+     main()
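train_bit.py launches the same way as train.py; its validation loop only tracks stego PSNR and bit accuracy. A hypothetical invocation (the option-file name is an assumption):

    # python train_bit.py -opt options/train_bit.yml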