vamcrizer committed on
Commit
04645f8
·
verified ·
1 Parent(s): 8be3094

Update saicinpainting/training/trainers/default.py

Browse files
saicinpainting/training/trainers/default.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
-
3
  import torch
4
  import torch.nn.functional as F
5
  from omegaconf import OmegaConf
@@ -13,6 +13,30 @@ from saicinpainting.utils import add_prefix_to_keys, get_ramp
13
 
14
  LOGGER = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def make_constant_area_crop_batch(batch, **kwargs):
18
  crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
@@ -24,25 +48,9 @@ def make_constant_area_crop_batch(batch, **kwargs):
24
 
25
 
26
  class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
27
- def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
28
- add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
29
- distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
30
- fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
31
- **kwargs):
32
  super().__init__(*args, **kwargs)
33
- self.concat_mask = concat_mask
34
- self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
35
- self.image_to_discriminator = image_to_discriminator
36
- self.add_noise_kwargs = add_noise_kwargs
37
- self.noise_fill_hole = noise_fill_hole
38
- self.const_area_crop_kwargs = const_area_crop_kwargs
39
- self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
40
- if distance_weighter_kwargs is not None else None
41
- self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
42
-
43
- self.fake_fakes_proba = fake_fakes_proba
44
- if self.fake_fakes_proba > 1e-3:
45
- self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
46
 
47
  def forward(self, batch):
48
  if self.training and self.rescale_size_getter is not None:
@@ -50,6 +58,29 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
50
  batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
51
  batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if self.training and self.const_area_crop_kwargs is not None:
54
  batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
55
 
 
1
  import logging
2
+ import cv2
3
  import torch
4
  import torch.nn.functional as F
5
  from omegaconf import OmegaConf
 
13
 
14
  LOGGER = logging.getLogger(__name__)
15
 
16
def resize_to_square(image, target_size):
    """Pad ``image`` to a square with zeros, then resize it to ``target_size``.

    A non-square image is centered on a zero-filled square canvas whose side
    equals the longer image side, so the content keeps its aspect ratio. The
    canvas is then resized to ``(target_size, target_size)``.

    Args:
        image: Array indexed as ``(height, width)`` or
            ``(height, width, channels)``; it is handed to ``cv2.resize``.
        target_size: Side length of the square output.

    Returns:
        The array produced by ``cv2.resize`` for a
        ``(target_size, target_size)`` output size.
    """
    # Function-scope import: the padding path needs NumPy, and the imports
    # visible at the top of this module do not include it (using the bare
    # name ``np`` would raise NameError for every non-square input).
    import numpy as np

    h, w = image.shape[:2]
    if h == w:
        # Already square: no padding required; cv2's default interpolation
        # is used on this path.
        return cv2.resize(image, (target_size, target_size))

    side = max(h, w)
    # INTER_AREA when shrinking, INTER_CUBIC when enlarging.
    interpolation = cv2.INTER_AREA if side > target_size else cv2.INTER_CUBIC

    # Offsets that center the image on the square canvas.
    x_pos = (side - w) // 2
    y_pos = (side - h) // 2

    # One canvas construction covers both 2-D and channel-last 3-D inputs:
    # any trailing dimensions of the image are carried over unchanged.
    canvas = np.zeros((side, side) + tuple(image.shape[2:]), dtype=image.dtype)
    canvas[y_pos:y_pos + h, x_pos:x_pos + w] = image

    return cv2.resize(canvas, (target_size, target_size), interpolation=interpolation)
35
+
36
# Usage example. Kept as a comment because `frame` is not defined in this
# module, so running the call at import time would raise NameError:
#     resized_frame = resize_to_square(frame, target_size)
target_size = 256
39
+
40
 
41
  def make_constant_area_crop_batch(batch, **kwargs):
42
  crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
 
48
 
49
 
50
  class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
51
def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None,
             image_to_discriminator='predicted_image', add_noise_kwargs=None,
             noise_fill_hole=False, const_area_crop_kwargs=None,
             distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
             fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
             target_size=256, **kwargs):
    """Store the training options and forward everything else to the base class.

    ``forward`` reads ``rescale_size_getter``, ``const_area_crop_kwargs`` and
    ``target_size`` from the instance, so they have to be assigned here;
    leaving them out makes ``forward`` fail with AttributeError.

    Args:
        concat_mask: Stored as ``self.concat_mask``.
        rescale_scheduler_kwargs: Keyword arguments for ``get_ramp``; when
            ``None`` no rescale size getter is created.
        image_to_discriminator: Stored as ``self.image_to_discriminator``.
        add_noise_kwargs: Stored as ``self.add_noise_kwargs``.
        noise_fill_hole: Stored as ``self.noise_fill_hole``.
        const_area_crop_kwargs: Keyword arguments later passed to
            ``make_constant_area_crop_batch``; ``None`` disables that step.
        distance_weighter_kwargs: Keyword arguments for
            ``make_mask_distance_weighter``; when ``None`` no weighter is
            created.
        distance_weighted_mask_for_discr: Stored as
            ``self.distance_weighted_mask_for_discr``.
        fake_fakes_proba: Stored as ``self.fake_fakes_proba``; a
            ``FakeFakesGenerator`` is only built when it exceeds ``1e-3``.
        fake_fakes_generator_kwargs: Keyword arguments for
            ``FakeFakesGenerator`` (``None`` means no arguments).
        target_size: Side length used when resizing images and masks to a
            square; 256 by default, or any other desired size.
        *args, **kwargs: Forwarded unchanged to the base class constructor.
    """
    super().__init__(*args, **kwargs)
    self.target_size = target_size
    self.concat_mask = concat_mask
    self.rescale_size_getter = (
        get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
    )
    self.image_to_discriminator = image_to_discriminator
    self.add_noise_kwargs = add_noise_kwargs
    self.noise_fill_hole = noise_fill_hole
    self.const_area_crop_kwargs = const_area_crop_kwargs
    self.refine_mask_for_losses = (
        make_mask_distance_weighter(**distance_weighter_kwargs)
        if distance_weighter_kwargs is not None else None
    )
    self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr

    self.fake_fakes_proba = fake_fakes_proba
    if self.fake_fakes_proba > 1e-3:
        # Only built when fake fakes can actually be sampled.
        self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def forward(self, batch):
56
  if self.training and self.rescale_size_getter is not None:
 
58
  batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
59
  batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
60
 
61
+ # Resize every image/mask pair in the batch to a square of side target_size
62
+ resized_images = []
63
+ resized_masks = []
64
+ for img, mask in zip(batch['image'], batch['mask']):
65
+ # Convert from tensor to NumPy array
66
+ img_np = img.permute(1, 2, 0).cpu().numpy()
67
+ mask_np = mask.squeeze().cpu().numpy()
68
+
69
+ # Resize
70
+ img_resized = resize_to_square(img_np, self.target_size)
71
+ mask_resized = resize_to_square(mask_np, self.target_size)
72
+
73
+ # Convert back to tensor
74
+ img_resized = torch.from_numpy(img_resized).permute(2, 0, 1).float().to(img.device)
75
+ mask_resized = torch.from_numpy(mask_resized).unsqueeze(0).float().to(mask.device)
76
+
77
+ resized_images.append(img_resized)
78
+ resized_masks.append(mask_resized)
79
+
80
+ batch['image'] = torch.stack(resized_images)
81
+ batch['mask'] = torch.stack(resized_masks)
82
+
83
+ # Continue with the rest of the forward method
84
  if self.training and self.const_area_crop_kwargs is not None:
85
  batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
86