somanchiu committed
Commit 0a5812e · verified · 1 Parent(s): 78265b2

Integrating a discriminator to guide the model toward generating more realistic facial details. This did introduce some texture to the faces.

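Before the full diff, a minimal, self-contained sketch of the adversarial wiring this commit experiments with. This is not the repository's exact code: the real discriminator is the iresnet50 added below, and TinyDiscriminator here is a hypothetical stand-in. The idea is standard: train the discriminator to score real crops as 1 and generated crops as 0 with BCE-with-logits, and give the generator an extra "fool the discriminator" term on top of its content/identity losses.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDiscriminator(nn.Module):
    # Hypothetical stand-in for the iresnet50 discriminator used in this commit.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 1))

    def forward(self, x):
        return self.net(x)  # raw logits; the loss below applies the sigmoid

def discriminator_step(D, opt_D, real, fake):
    # One discriminator update: push real crops toward label 1, generated crops toward 0.
    opt_D.zero_grad()
    real_logits = D(real)
    fake_logits = D(fake.detach())  # detach: no gradient flows back into the generator
    d_loss = (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits)) +
              F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
    d_loss.backward()
    opt_D.step()
    return d_loss.item()

def generator_adversarial_term(D, fake):
    # The generator is rewarded when the discriminator scores its output as real.
    logits = D(fake)
    return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

# Smoke test with random tensors in place of aligned face crops.
D = TinyDiscriminator()
opt_D = torch.optim.Adam(D.parameters(), lr=5e-5)  # same discriminator LR as the script below
real, fake = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
print(discriminator_step(D, opt_D, real, fake))
print(generator_adversarial_term(D, fake).item())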
Experimenting with Adversarial Loss/Discriminatorv3_3.py ADDED
@@ -0,0 +1,196 @@
+ # a modified version of https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/backbones/iresnet.py
+
+ import torch
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint
+
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+ using_ckpt = False
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes,
+                      out_planes,
+                      kernel_size=3,
+                      stride=stride,
+                      padding=dilation,
+                      groups=groups,
+                      bias=True,
+                      dilation=dilation)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes,
+                      out_planes,
+                      kernel_size=1,
+                      stride=stride,
+                      bias=True)
+
+
+ class IBasicBlock(nn.Module):
+     expansion = 1
+     def __init__(self, inplanes, planes, stride=1, downsample=None,
+                  groups=1, base_width=64, dilation=1):
+         super(IBasicBlock, self).__init__()
+         if groups != 1 or base_width != 64:
+             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+         if dilation > 1:
+             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+         self.conv1 = conv3x3(inplanes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+         self.prelu = nn.PReLU(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward_impl(self, x):
+         identity = x
+         out = self.bn1(x)
+         out = self.conv1(out)
+         out = self.bn2(out)
+         out = self.prelu(out)
+         out = self.conv2(out)
+         out = self.bn3(out)
+         if self.downsample is not None:
+             identity = self.downsample(x)
+         out += identity
+         return out
+
+     def forward(self, x):
+         if self.training and using_ckpt:
+             return checkpoint(self.forward_impl, x)
+         else:
+             return self.forward_impl(x)
+
+
+ class IResNet(nn.Module):
+     fc_scale = 14 * 14
+     def __init__(self,
+                  block, layers, dropout=0, num_features=512, zero_init_residual=False,
+                  groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+         super(IResNet, self).__init__()
+         self.extra_gflops = 0.0
+         self.fp16 = fp16
+         self.inplanes = 64
+         self.dilation = 1
+         if replace_stride_with_dilation is None:
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError("replace_stride_with_dilation should be None "
+                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+         self.groups = groups
+         self.base_width = width_per_group
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=True)
+         self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+         self.prelu = nn.PReLU(self.inplanes)
+         self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+         self.layer2 = self._make_layer(block,
+                                        128,
+                                        layers[1],
+                                        stride=2,
+                                        dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block,
+                                        256,
+                                        layers[2],
+                                        stride=2,
+                                        dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block,
+                                        512,
+                                        layers[3],
+                                        stride=2,
+                                        dilate=replace_stride_with_dilation[2])
+         self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+         self.dropout = nn.Dropout(p=dropout, inplace=True)
+         self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+         self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+         nn.init.constant_(self.features.weight, 1.0)
+         self.features.weight.requires_grad = False
+
+         # for m in self.modules():
+         #     if isinstance(m, nn.Conv2d):
+         #         nn.init.normal_(m.weight, 0, 0.1)
+         #     elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+         #         nn.init.constant_(m.weight, 1)
+         #         nn.init.constant_(m.bias, 0)
+
+         # if zero_init_residual:
+         #     for m in self.modules():
+         #         if isinstance(m, IBasicBlock):
+         #             nn.init.constant_(m.bn2.weight, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+             )
+         layers = []
+         layers.append(
+             block(self.inplanes, planes, stride, downsample, self.groups,
+                   self.base_width, previous_dilation))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(self.inplanes,
+                       planes,
+                       groups=self.groups,
+                       base_width=self.base_width,
+                       dilation=self.dilation))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         with torch.cuda.amp.autocast(self.fp16):
+             x = self.conv1(x)
+             x = self.bn1(x)
+             x = self.prelu(x)
+             x = self.layer1(x)
+             x = self.layer2(x)
+             x = self.layer3(x)
+             x = self.layer4(x)
+             x = self.bn2(x)
+             x = torch.flatten(x, 1)
+             x = self.dropout(x)
+         x = self.fc(x.float() if self.fp16 else x)
+         # x = self.features(x)
+         return x
+
+
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+     model = IResNet(block, layers, **kwargs)
+     if pretrained:
+         raise ValueError()
+     return model
+
+
+ def iresnet18(pretrained=False, progress=True, **kwargs):
+     return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+                     progress, **kwargs)
+
+
+ def iresnet34(pretrained=False, progress=True, **kwargs):
+     return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+                     progress, **kwargs)
+
+
+ def iresnet50(pretrained=False, progress=True, **kwargs):
+     return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+                     progress, **kwargs)
+
+
+ def iresnet100(pretrained=False, progress=True, **kwargs):
+     return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+                     progress, **kwargs)
+
+
+ def iresnet200(pretrained=False, progress=True, **kwargs):
+     return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+                     progress, **kwargs)
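A quick shape check for the backbone above, as a sketch (it assumes Discriminatorv3_3.py is importable from the working directory). With the stride-1 stem and four stride-2 stages, fc_scale = 14 * 14 implies 224×224 inputs (224 halves to 112, 56, 28, then 14), which is why the training script below resizes every crop to 224 before scoring it:

import torch
from Discriminatorv3_3 import iresnet50

D = iresnet50()                    # 512-dim output head; the script treats it as raw scores, no sigmoid
x = torch.randn(2, 3, 224, 224)    # 224 -> 112 -> 56 -> 28 -> 14 across layer1..layer4
scores = D(x)
print(scores.shape)                # torch.Size([2, 512])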
Experimenting with Adversarial Loss/discriminator-16796-16328-37280.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd44167f0badfb6adfa6975b76c29ff360fb9e4eaa80b248768e15cb0145bec1
+ size 328890738
Experimenting with Adversarial Loss/discriminator-580-596-640.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f0cfcd96fcefa9b06d182d5a471ddf94db35c9143d0c3afb5b073502ec1cc07
+ size 328887546
Experimenting with Adversarial Loss/reswapper-1679500.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0653401aad18b8c82a565ea4f954e044b2c2d72b5dde965b4c06e52abddac2cf
+ size 553194302
Experimenting with Adversarial Loss/reswapper-1683150.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c13e960d555a7075fb38bfaa2b0fa59a6c84ca470fd85b3b0c54f526ecb32e8f
+ size 553194302
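The .pth entries above are Git LFS pointers; the actual binaries arrive with git lfs pull. Once fetched, they load the same way the training script below loads them. A minimal sketch, assuming the checkpoints stay in this folder and the repo's modules are on the import path:

import torch
from Discriminatorv3_3 import iresnet50
from StyleTransferModel_128 import StyleTransferModel  # generator class used by this repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = StyleTransferModel().to(device)
G.load_state_dict(torch.load('Experimenting with Adversarial Loss/reswapper-1683150.pth',
                             map_location=device), strict=False)

D = iresnet50().to(device)
D.load_state_dict(torch.load('Experimenting with Adversarial Loss/discriminator-580-596-640.pth',
                             map_location=device), strict=False)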
Experimenting with Adversarial Loss/train_dis.3_3_1_Good_1.1.1.py ADDED
@@ -0,0 +1,435 @@
+ from datetime import datetime
+ import os
+ import random
+ import torch
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ from Discriminatorv3_3 import iresnet50
+ import Image
+ import ModelFormat
+ from StyleTransferLoss import StyleTransferLoss
+ import onnxruntime as rt
+
+ import cv2
+ from insightface.data import get_image as ins_get_image
+ from insightface.app import FaceAnalysis
+ import face_align
+
+ from StyleTransferModel_128 import StyleTransferModel
+ from torch.utils.tensorboard import SummaryWriter
+
+ inswapper_128_path = 'inswapper_128.onnx'
+ img_size = 128
+
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+
+ inswapperInferenceSession = rt.InferenceSession(inswapper_128_path, providers=providers)
+
+ faceAnalysis = FaceAnalysis(name='buffalo_l')
+ faceAnalysis.prepare(ctx_id=0, det_size=(512, 512))
+
+ class FocalLoss(torch.nn.Module):
+     def __init__(self, gamma=0, eps=1e-7):
+         super(FocalLoss, self).__init__()
+         self.gamma = gamma
+         self.eps = eps
+         self.ce = torch.nn.CrossEntropyLoss()
+
+     def forward(self, input, target):
+         logp = self.ce(input, target)
+         p = torch.exp(-logp)
+         loss = (1 - p) ** self.gamma * logp
+         return loss.mean()
+
+ def get_device():
+     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ style_loss_fn = StyleTransferLoss().to(get_device())
+
+ def patchgan_prediction(pred, threshold=0.5):
+     """Aggregate PatchGAN output into an image-level decision. (Unused below.)"""
+     # pred shape: (batch_size, 1, 8, 8)
+     probabilities = torch.sigmoid(pred)
+
+     # Two aggregation strategies
+     patch_confidence = probabilities.mean(dim=[1, 2, 3])  # average over all patches
+     # Tensor.any reduces a single dim, so flatten the patch dims first
+     any_patch_positive = (probabilities > threshold).flatten(1).any(dim=1).float()  # any patch thinks it's real
+
+     return patch_confidence, any_patch_positive
+
+ def compute_gradient_penalty(D, real, fake):
+     """WGAN-GP style gradient penalty. (Kept for experiments; unused below.)"""
+     alpha = torch.rand(real.size(0), 1, 1, 1).to(real.device)
+     interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
+     d_interpolates = D(interpolates)
+
+     gradients = torch.autograd.grad(
+         outputs=d_interpolates,
+         inputs=interpolates,
+         grad_outputs=torch.ones_like(d_interpolates),
+         create_graph=True,
+         retain_graph=True
+     )[0]
+
+     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+     return gradient_penalty
+
+ def cosin_metric(x1, x2):
+     return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
+
+ def createFakeImage(datasetDir, image, enableDataAugmentation, steps, resolution, device):
+     targetFaceIndex = random.randint(0, len(image) - 1)
+     sourceFaceIndex = random.randint(0, len(image) - 1)
+
+     target_img = cv2.imread(f"{datasetDir}/{image[targetFaceIndex]}")
+     if enableDataAugmentation and steps % 2 == 0:
+         target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY)
+         target_img = cv2.cvtColor(target_img, cv2.COLOR_GRAY2BGR)
+     faces = faceAnalysis.get(target_img)
+
+     if targetFaceIndex != sourceFaceIndex:
+         source_img = cv2.imread(f"{datasetDir}/{image[sourceFaceIndex]}")
+         faces2 = faceAnalysis.get(source_img)
+     else:
+         faces2 = faces
+
+     if len(faces) > 0 and len(faces2) > 0:
+         new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, img_size)
+         blob = Image.getBlob(new_aligned_face)
+         latent = Image.getLatent(faces2[0])
+     else:
+         # No detectable face in one of the images; resample.
+         return createFakeImage(datasetDir, image, enableDataAugmentation, steps, resolution, device)
+
+     if targetFaceIndex != sourceFaceIndex:
+         input = {inswapperInferenceSession.get_inputs()[0].name: blob,
+                  inswapperInferenceSession.get_inputs()[1].name: latent}
+
+         expected_output = inswapperInferenceSession.run([inswapperInferenceSession.get_outputs()[0].name], input)[0]
+     else:
+         expected_output = blob
+
+     expected_output_tensor = torch.from_numpy(expected_output).to(device)
+
+     if resolution != 128:
+         new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, resolution)
+         blob = Image.getBlob(new_aligned_face, (resolution, resolution))
+
+     latent_tensor = torch.from_numpy(latent).to(device)
+     target_input_tensor = torch.from_numpy(blob).to(device)
+
+     return target_input_tensor, latent_tensor, expected_output_tensor
+
+ def train(datasetDir, learning_rate=0.0001, model_path=None, outputModelFolder='', saveModelEachSteps=1, stopAtSteps=None, logDir=None, previewDir=None, saveAs_onnx=False, resolutions=[128], enableDataAugmentation=False):
+     device = get_device()
+     print(f"Using device: {device}")
+     train_g = True  # toggle generator updates
+     train_d = True  # toggle discriminator updates
+
+     model = StyleTransferModel().to(device)
+     discriminator = iresnet50().to(device)  # discriminator for the adversarial loss
+     # discriminator.features.weight.requires_grad = True
+     optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005)
+     fake_correct_count = 0
+     real_correct_count = 0
+     d_steps = 0
+
+     if model_path is not None:
+         model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+         lastSteps = 0
+         # lastSteps = 200
+         # Counters carried over from the checkpoint being resumed.
+         d_steps = 640
+         fake_correct_count = 580
+         real_correct_count = 596
+
+         discriminator.load_state_dict(torch.load(f"D:\\ReSwapper\\model\\discriminatorV4\\discriminator-{fake_correct_count}-{real_correct_count}-{d_steps}.pth", map_location=device), strict=False)
+         print(f"Loaded model from {model_path}")
+         if train_g:
+             lastSteps = int(model_path.split('-')[-1].split('.')[0])
+             print(f"Resuming training from step {lastSteps}")
+             d_steps *= 2
+     else:
+         lastSteps = 0
+
+     model.train()
+     model = model.to(device)
+     # criterion = FocalLoss(gamma=2).to(device)
+     # criterion = torch.nn.CrossEntropyLoss().to(device)
+     criterion = torch.nn.BCELoss().to(device)  # unused; BCE-with-logits is applied functionally below
+
+     # Initialize optimizer
+     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+     # torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
+
+     # Initialize TensorBoard writers
+     if logDir is not None:
+         train_writer = SummaryWriter(os.path.join(logDir, "training"))
+         val_writer = SummaryWriter(os.path.join(logDir, "validation"))
+
+     steps = 0
+
+     image = os.listdir(datasetDir)
+
+     resolutionIndex = 0
+
+     batch_size = 5
+
+     # Training loop
+     while True:
+         start_time = datetime.now()
+
+         resolution = resolutions[resolutionIndex % len(resolutions)]
+         optimizer.zero_grad()
+
+         # Collect a batch of real faces and a batch of generator outputs.
+         real_images_list = []
+         fake_images_list = []
+         while len(real_images_list) != batch_size:
+             realFaceIndex = random.randint(0, len(image) - 1)
+             real_img = cv2.imread(f"{datasetDir}/{image[realFaceIndex]}")
+             faces3 = faceAnalysis.get(real_img)
+             if len(faces3) == 0: continue
+
+             aligned_real_face, _ = face_align.norm_crop2(real_img, faces3[0].kps, resolution)
+             real_images = torch.from_numpy(Image.getBlob(aligned_real_face, (resolution, resolution))).to(device)
+
+             real_images = F.interpolate(real_images, size=(224, 224), mode='bilinear', align_corners=False)
+             real_images_list.append(real_images)
+
+         while len(fake_images_list) != batch_size:
+             target_input_tensor, latent_tensor, expected_output_tensor = createFakeImage(datasetDir, image, enableDataAugmentation, steps, resolution, device)
+
+             with torch.no_grad():
+                 output = model(target_input_tensor, latent_tensor)
+
+             fake_images = output.detach()  # detach to avoid backprop through the generator
+             fake_images = F.interpolate(fake_images, size=(224, 224), mode='bilinear', align_corners=False)
+
+             fake_images_list.append(fake_images)
+
+         if train_d and resolution == 256:
+             # ---------------------
+             #  Train Discriminator
+             # ---------------------
+             optimizer_D.zero_grad()
+
+             # Stack along dim 1: (1, batch, 3, 224, 224); indexing [0] below yields the (batch, 3, 224, 224) batch.
+             fake_images_list = torch.stack(fake_images_list, 1).to(device)
+             real_images_list = torch.stack(real_images_list, 1).to(device)
+
+             real_pred = discriminator(real_images_list[0])
+
+             d_loss_real = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
+             # d_loss_real = 1 - F.cosine_similarity(real_pred, torch.ones_like(real_pred))
+
+             # Per-sample decision: a sample counts as correct when its mean raw logit
+             # crosses 0.5 (this replaced an earlier batch-mean check).
+             mean_per_real_sample = real_pred.mean(dim=1)
+             real_mask = mean_per_real_sample > 0.5
+             real_correct_count += real_mask.sum().item()
+
+             # Use generator output as fake samples
+             fake_pred = discriminator(fake_images_list[0])
+
+             mean_per_fake_sample = fake_pred.mean(dim=1)
+             fake_mask = mean_per_fake_sample < 0.5
+             fake_correct_count += fake_mask.sum().item()
+
+             # (The original target `torch.zeros_like(real_pred) * -1` is numerically identical: all zeros.)
+             d_loss_fake = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
+             # d_loss_fake = 1 - F.cosine_similarity(fake_pred, torch.zeros_like(real_pred))
+             # d_loss_fake_v2 = 1 - cosin_metric(fake_pred[0], torch.zeros_like(real_pred)[0])
+             d_loss = d_loss_real + d_loss_fake
+             d_loss.backward()
+             optimizer_D.step()
+             d_steps += batch_size * 2
+
+             # real, p = patchgan_prediction(real_pred)
+             # fake, p2 = patchgan_prediction(fake_pred)
+
+         # Train Generator
+         if train_g:
+             target_input_tensor, latent_tensor, expected_output_tensor = createFakeImage(datasetDir, image, enableDataAugmentation, steps, resolution, device)
+
+             output = model(target_input_tensor, latent_tensor)
+
+             if resolution != 128:
+                 output_128 = F.interpolate(output, size=(128, 128), mode='bilinear', align_corners=False)
+             else:
+                 output_128 = output  # fix: output_128 was undefined when training at 128
+
+             content_loss, identity_loss = style_loss_fn(output_128, expected_output_tensor)
+
+             # Adversarial loss: reward outputs the discriminator scores as real.
+             output_224 = F.interpolate(output, size=(224, 224), mode='bilinear', align_corners=False)
+
+             fake_pred = discriminator(output_224)
+             adversarial_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
+
+             loss = content_loss + adversarial_loss
+
+             if identity_loss is not None:
+                 loss += identity_loss
+
+             loss.backward()
+             optimizer.step()
+
+         steps += 1
+         totalSteps = steps + lastSteps
+
+         # Guard: d_steps is 0 before the first discriminator update.
+         acc = (fake_correct_count + real_correct_count) / d_steps if d_steps else 0.0
+
+         if logDir is not None:
+             if train_g:
+                 train_writer.add_scalar("Loss/total", loss.item(), totalSteps)
+                 train_writer.add_scalar("Loss/content_loss", content_loss.item(), totalSteps)
+                 train_writer.add_scalar("Loss/adversarial_loss", adversarial_loss.item(), totalSteps)
+
+                 if identity_loss is not None:
+                     train_writer.add_scalar("Loss/identity_loss", identity_loss.item(), totalSteps)
+
+             if train_d:
+                 # (Assumes the discriminator branch ran this step, i.e. resolution == 256.)
+                 train_writer.add_scalar("Loss/d_acc", acc, totalSteps)
+                 train_writer.add_scalar("Loss/d_loss", d_loss.item(), totalSteps)
+                 train_writer.add_scalar("Loss/d_loss_fake", d_loss_fake.item(), totalSteps)
+                 train_writer.add_scalar("Loss/d_loss_real", d_loss_real.item(), totalSteps)
+
+         elapsed_time = datetime.now() - start_time
+
+         if train_d:
+             print(f"Total Steps: {totalSteps}, Step: {steps}, D_Loss: {d_loss.item():.4f}, d_loss_real: {d_loss_real.item():.4f}, d_loss_fake: {d_loss_fake.item():.4f}, acc: {acc:.4f}, Elapsed time: {elapsed_time}")
+         if train_g:
+             print(f"Total Steps: {totalSteps}, Step: {steps}, G_Loss: {loss.item():.4f}, Elapsed time: {elapsed_time}")
+
+         if steps % saveModelEachSteps == 0:
+             if train_g:
+                 outputModelPath = f"reswapper-{totalSteps}.pth"
+                 if outputModelFolder != '':
+                     outputModelPath = f"{outputModelFolder}/{outputModelPath}"
+                 saveModel(model, outputModelPath)
+             if train_d:
+                 discriminatorModelPath = f"discriminator-{fake_correct_count}-{real_correct_count}-{d_steps}.pth"
+                 if outputModelFolder != '':
+                     discriminatorModelPath = f"{outputModelFolder}/{discriminatorModelPath}"
+                 saveModel(discriminator, discriminatorModelPath)
+
+             if train_g:
+                 validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256 = validate(outputModelPath)
+                 if previewDir is not None:
+                     cv2.imwrite(f"{previewDir}/{totalSteps}.jpg", swapped_face)
+                     cv2.imwrite(f"{previewDir}/{totalSteps}_256.jpg", swapped_face_256)
+
+                 if logDir is not None:
+                     val_writer.add_scalar("Loss/total", validation_total_loss.item(), totalSteps)
+                     val_writer.add_scalar("Loss/content_loss", validation_content_loss.item(), totalSteps)
+                     if validation_identity_loss is not None:
+                         val_writer.add_scalar("Loss/identity_loss", validation_identity_loss.item(), totalSteps)
+
+                 if saveAs_onnx:
+                     ModelFormat.save_as_onnx_model(outputModelPath)
+
+         if stopAtSteps is not None and steps == stopAtSteps:
+             exit()
+
+         resolutionIndex += 1
+
+ def saveModel(model, outputModelPath):
+     torch.save(model.state_dict(), outputModelPath)
+
+ def load_model(model_path):
+     device = get_device()
+     model = StyleTransferModel().to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+
+     model.eval()
+     return model
+
+ def swap_face(model, target_face, source_face_latent):
+     device = get_device()
+
+     target_tensor = torch.from_numpy(target_face).to(device)
+     source_tensor = torch.from_numpy(source_face_latent).to(device)
+
+     with torch.no_grad():
+         swapped_tensor = model(target_tensor, source_tensor)
+
+     swapped_face = Image.postprocess_face(swapped_tensor)
+
+     return swapped_face, swapped_tensor
+
+ # test image
+ test_img = ins_get_image('t1')
+
+ test_faces = faceAnalysis.get(test_img)
+ test_faces = sorted(test_faces, key=lambda x: x.bbox[0])
+ test_target_face, _ = face_align.norm_crop2(test_img, test_faces[0].kps, img_size)
+ test_target_face = Image.getBlob(test_target_face)
+ test_l = Image.getLatent(test_faces[2])
+
+ test_target_face_256, _ = face_align.norm_crop2(test_img, test_faces[0].kps, 256)
+ test_target_face_256 = Image.getBlob(test_target_face_256, (256, 256))
+
+ test_input = {inswapperInferenceSession.get_inputs()[0].name: test_target_face,
+               inswapperInferenceSession.get_inputs()[1].name: test_l}
+
+ test_inswapperOutput = inswapperInferenceSession.run([inswapperInferenceSession.get_outputs()[0].name], test_input)[0]
+
+ def validate(modelPath):
+     model = load_model(modelPath)
+     swapped_face, swapped_tensor = swap_face(model, test_target_face, test_l)
+     swapped_face_256, _ = swap_face(model, test_target_face_256, test_l)
+
+     validation_content_loss, validation_identity_loss = style_loss_fn(swapped_tensor, torch.from_numpy(test_inswapperOutput).to(get_device()))
+
+     validation_total_loss = validation_content_loss
+     if validation_identity_loss is not None:
+         validation_total_loss += validation_identity_loss
+
+     return validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256
+
+ def main():
+     outputModelFolder = "model/discriminatorV4"
+     modelPath = None  # None trains from scratch
+     modelPath = "model/discriminatorV4/reswapper-1679500.pth"
+
+     logDir = "training/log/moreRes"
+     previewDir = "training/preview/moreRes"
+     datasetDir = "FFHQ"
+
+     os.makedirs(outputModelFolder, exist_ok=True)
+     os.makedirs(previewDir, exist_ok=True)
+
+     train(
+         datasetDir=datasetDir,
+         model_path=modelPath,
+         learning_rate=0.000001,
+         resolutions=[256],
+         enableDataAugmentation=True,
+         outputModelFolder=outputModelFolder,
+         saveModelEachSteps=100,
+         stopAtSteps=70000,
+         logDir=f"{logDir}/{datetime.now().strftime('%Y%m%d %H%M%S')}",
+         previewDir=previewDir)
+
+ if __name__ == "__main__":
+     main()