vamcrizer commited on
Commit
5e67a62
·
verified ·
1 Parent(s): 04645f8

Update saicinpainting/training/trainers/default.py

Browse files
saicinpainting/training/trainers/default.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
- import cv2
3
  import torch
4
  import torch.nn.functional as F
5
  from omegaconf import OmegaConf
@@ -13,30 +13,6 @@ from saicinpainting.utils import add_prefix_to_keys, get_ramp
13
 
14
  LOGGER = logging.getLogger(__name__)
15
 
16
- def resize_to_square(image, target_size):
17
- h, w = image.shape[:2]
18
- if h == w:
19
- return cv2.resize(image, (target_size, target_size))
20
-
21
- dif = h if h > w else w
22
- interpolation = cv2.INTER_AREA if dif > target_size else cv2.INTER_CUBIC
23
-
24
- x_pos = (dif - w) // 2
25
- y_pos = (dif - h) // 2
26
-
27
- if len(image.shape) == 2:
28
- mask = np.zeros((dif, dif), dtype=image.dtype)
29
- mask[y_pos:y_pos+h, x_pos:x_pos+w] = image
30
- else:
31
- mask = np.zeros((dif, dif, image.shape[2]), dtype=image.dtype)
32
- mask[y_pos:y_pos+h, x_pos:x_pos+w, :] = image
33
-
34
- return cv2.resize(mask, (target_size, target_size), interpolation=interpolation)
35
-
36
- # Sử dụng
37
- target_size = 256
38
- resized_frame = resize_to_square(frame, target_size)
39
-
40
 
41
  def make_constant_area_crop_batch(batch, **kwargs):
42
  crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
@@ -48,9 +24,25 @@ def make_constant_area_crop_batch(batch, **kwargs):
48
 
49
 
50
  class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
51
- def __init__(self, *args, **kwargs):
 
 
 
 
52
  super().__init__(*args, **kwargs)
53
- self.target_size = 256 # Hoặc kích thước mong muốn khác
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def forward(self, batch):
56
  if self.training and self.rescale_size_getter is not None:
@@ -58,29 +50,6 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
58
  batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
59
  batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
60
 
61
- # Thêm đoạn code resize ở đây
62
- resized_images = []
63
- resized_masks = []
64
- for img, mask in zip(batch['image'], batch['mask']):
65
- # Chuyển từ tensor sang numpy array
66
- img_np = img.permute(1, 2, 0).cpu().numpy()
67
- mask_np = mask.squeeze().cpu().numpy()
68
-
69
- # Resize
70
- img_resized = resize_to_square(img_np, self.target_size)
71
- mask_resized = resize_to_square(mask_np, self.target_size)
72
-
73
- # Chuyển lại thành tensor
74
- img_resized = torch.from_numpy(img_resized).permute(2, 0, 1).float().to(img.device)
75
- mask_resized = torch.from_numpy(mask_resized).unsqueeze(0).float().to(mask.device)
76
-
77
- resized_images.append(img_resized)
78
- resized_masks.append(mask_resized)
79
-
80
- batch['image'] = torch.stack(resized_images)
81
- batch['mask'] = torch.stack(resized_masks)
82
-
83
- # Tiếp tục với phần còn lại của phương thức forward
84
  if self.training and self.const_area_crop_kwargs is not None:
85
  batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
86
 
@@ -203,4 +172,4 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
203
  metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
204
  metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
205
 
206
- return total_loss, metrics
 
1
  import logging
2
+
3
  import torch
4
  import torch.nn.functional as F
5
  from omegaconf import OmegaConf
 
13
 
14
  LOGGER = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def make_constant_area_crop_batch(batch, **kwargs):
18
  crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
 
24
 
25
 
26
  class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
27
+ def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
28
+ add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
29
+ distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
30
+ fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
31
+ **kwargs):
32
  super().__init__(*args, **kwargs)
33
+ self.concat_mask = concat_mask
34
+ self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
35
+ self.image_to_discriminator = image_to_discriminator
36
+ self.add_noise_kwargs = add_noise_kwargs
37
+ self.noise_fill_hole = noise_fill_hole
38
+ self.const_area_crop_kwargs = const_area_crop_kwargs
39
+ self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
40
+ if distance_weighter_kwargs is not None else None
41
+ self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
42
+
43
+ self.fake_fakes_proba = fake_fakes_proba
44
+ if self.fake_fakes_proba > 1e-3:
45
+ self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
46
 
47
  def forward(self, batch):
48
  if self.training and self.rescale_size_getter is not None:
 
50
  batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
51
  batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if self.training and self.const_area_crop_kwargs is not None:
54
  batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
55
 
 
172
  metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
173
  metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
174
 
175
+ return total_loss, metrics