Update README.md
Browse files
README.md
CHANGED
@@ -48,20 +48,22 @@ class ConvStem(nn.Module):
|
|
48 |
Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
|
49 |
"""
|
50 |
|
51 |
-
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None,
|
52 |
super().__init__()
|
53 |
|
54 |
-
|
55 |
-
assert
|
|
|
56 |
|
57 |
img_size = to_2tuple(img_size)
|
58 |
patch_size = to_2tuple(patch_size)
|
|
|
59 |
self.img_size = img_size
|
60 |
self.patch_size = patch_size
|
61 |
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
62 |
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
63 |
-
self.flatten = flatten
|
64 |
|
|
|
65 |
stem = []
|
66 |
input_dim, output_dim = 3, embed_dim // 8
|
67 |
for l in range(2):
|
@@ -73,15 +75,17 @@ class ConvStem(nn.Module):
|
|
73 |
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
74 |
self.proj = nn.Sequential(*stem)
|
75 |
|
|
|
76 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
77 |
|
78 |
def forward(self, x):
|
79 |
B, C, H, W = x.shape
|
|
|
|
|
80 |
assert H == self.img_size[0] and W == self.img_size[1], \
|
81 |
-
|
|
|
82 |
x = self.proj(x)
|
83 |
-
if self.flatten:
|
84 |
-
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
85 |
x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
|
86 |
x = self.norm(x)
|
87 |
return x
|
|
|
48 |
Adapted from https://github.com/Xiyue-Wang/TransPath/blob/main/ctran.py#L6-L44
|
49 |
"""
|
50 |
|
51 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, **kwargs):
|
52 |
super().__init__()
|
53 |
|
54 |
+
# Check input constraints
|
55 |
+
assert patch_size == 4, "Patch size must be 4"
|
56 |
+
assert embed_dim % 8 == 0, "Embedding dimension must be a multiple of 8"
|
57 |
|
58 |
img_size = to_2tuple(img_size)
|
59 |
patch_size = to_2tuple(patch_size)
|
60 |
+
|
61 |
self.img_size = img_size
|
62 |
self.patch_size = patch_size
|
63 |
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
64 |
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
|
|
65 |
|
66 |
+
# Create stem network
|
67 |
stem = []
|
68 |
input_dim, output_dim = 3, embed_dim // 8
|
69 |
for l in range(2):
|
|
|
75 |
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
76 |
self.proj = nn.Sequential(*stem)
|
77 |
|
78 |
+
# Apply normalization layer (if provided)
|
79 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
80 |
|
81 |
def forward(self, x):
    """Project an image batch through the stem and return channels-last features.

    Args:
        x: Input whose ``shape`` unpacks into four values ``(B, C, H, W)``.
            ``H`` and ``W`` must equal ``self.img_size[0]`` and
            ``self.img_size[1]`` respectively.

    Returns:
        The result of ``self.norm`` applied to ``self.proj(x)`` after its
        axes are permuted from BCHW to BHWC.

    Raises:
        AssertionError: If the spatial size of ``x`` differs from
            ``self.img_size``.
    """
    # Only the spatial dimensions are needed; batch and channel are unused here.
    _, _, H, W = x.shape

    # Check input image size. The error is raised explicitly (instead of via
    # an ``assert`` statement) so the check is not stripped under ``python -O``;
    # the exception type and message are the same as before.
    if H != self.img_size[0] or W != self.img_size[1]:
        raise AssertionError(
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        )

    x = self.proj(x)
    x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
    x = self.norm(x)
    return x
|