1aurent committed on
Commit 1d84969 · 1 Parent(s): bd80219

Update README.md

Files changed (1)
  1. README.md +11 -7
README.md CHANGED
@@ -48,20 +48,22 @@ class ConvStem(nn.Module):
     Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
     """

-    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=False, **kwargs):
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, **kwargs):
         super().__init__()

-        assert patch_size == 4
-        assert embed_dim % 8 == 0
+        # Check input constraints
+        assert patch_size == 4, "Patch size must be 4"
+        assert embed_dim % 8 == 0, "Embedding dimension must be a multiple of 8"

         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
+
         self.img_size = img_size
         self.patch_size = patch_size
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
-        self.flatten = flatten

+        # Create stem network
         stem = []
         input_dim, output_dim = 3, embed_dim // 8
         for l in range(2):
@@ -73,15 +75,17 @@ class ConvStem(nn.Module):
         stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
         self.proj = nn.Sequential(*stem)

+        # Apply normalization layer (if provided)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

     def forward(self, x):
         B, C, H, W = x.shape
+
+        # Check input image size
         assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
         x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
         x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
         x = self.norm(x)
         return x
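For context, below is a minimal sketch of the updated ConvStem after this commit, together with a dummy forward pass showing the new behavior: the `flatten` option and the BCHW -> BNC branch are gone, so the stem always returns a spatial BHWC feature map. The stem-building loop body is not visible in the diff; it is assumed here from the linked TransPath ctran.py (two stride-2 3x3 convolutions with BatchNorm/ReLU followed by a 1x1 projection), so treat this as an illustrative reconstruction rather than the exact README code.

```python
import torch
import torch.nn as nn
from timm.layers import to_2tuple  # older timm versions: timm.models.layers


class ConvStem(nn.Module):
    """Convolutional patch-embedding stem (sketch of the post-commit version)."""

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, **kwargs):
        super().__init__()

        # Check input constraints
        assert patch_size == 4, "Patch size must be 4"
        assert embed_dim % 8 == 0, "Embedding dimension must be a multiple of 8"

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Create stem network. The loop body is assumed from the linked TransPath
        # ctran.py: two stride-2 3x3 convs give a total downsampling of 4 (= patch_size),
        # then a 1x1 conv projects to embed_dim.
        stem = []
        input_dim, output_dim = 3, embed_dim // 8
        for l in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        # Apply normalization layer (if provided)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape

        # Check input image size
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        return x


if __name__ == "__main__":
    stem = ConvStem(img_size=224, embed_dim=768)
    out = stem(torch.randn(1, 3, 224, 224))
    # With the flatten option removed, the output stays spatial: (1, 56, 56, 768)
    print(out.shape)
```

Keeping the output in BHWC layout (rather than the old optional flattened BNC tokens) matches what downstream models that expect a spatial feature map, such as Swin-style backbones, consume directly.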