cappuch committed on
Commit
d1ef8ee
1 Parent(s): 361ee4f

Upload 4 files

model weights, inference code

Files changed (4)
  1. autoencoder.pth +3 -0
  2. autoencoder.py +81 -0
  3. autoencoder_inf.py +60 -0
  4. model.py +120 -0
autoencoder.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0e83451feb425b0c1e6d795c9e94ec2f10ef0444bb979c02b25de0ae76bfd71
+ size 11511830
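The weights file committed here is a Git LFS pointer; the actual ~11.5 MB blob has to be fetched (for example with git lfs pull) before it can be opened. A minimal loading sketch, assuming the checkpoint holds a plain state_dict for aeModel from model.py:

import torch
from model import aeModel

# Assumes autoencoder.pth has been resolved by Git LFS and is the real weights blob.
model = aeModel()
model.load_state_dict(torch.load("autoencoder.pth", map_location="cpu"))
model.eval()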
autoencoder.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from PIL import Image
+ import os
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ from model import aeModel
+
+ class ImageDataset(Dataset):
+     def __init__(self, folder_path):
+         self.folder_path = folder_path
+         self.image_files = [f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
+         self.transform = transforms.Compose([
+             transforms.Resize((64, 64)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_files)
+
+     def __getitem__(self, idx):
+         img_path = os.path.join(self.folder_path, self.image_files[idx])
+         image = Image.open(img_path).convert('RGB')
+         image = self.transform(image)
+         return image
+
+ def train(model, dataloader, num_epochs, device):
+     criterion = nn.MSELoss()
+     optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+     for epoch in range(num_epochs):
+         model.train()
+         total_loss = 0
+         for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
+             batch = batch.to(device)
+             output = model(batch)
+             loss = criterion(output, batch)
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+
+         avg_loss = total_loss / len(dataloader)
+         print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')
+
+
+ def visualize_results(model, dataloader, device):
+     model.eval()
+     with torch.no_grad():
+         images = next(iter(dataloader))
+         images = images.to(device)
+
+         reconstructions = model(images)
+         fig, axes = plt.subplots(2, 5, figsize=(12, 6))
+         for i in range(5):
+             axes[0, i].imshow(images[i].cpu().permute(1, 2, 0))
+             axes[0, i].axis('off')
+             axes[1, i].imshow(reconstructions[i].cpu().permute(1, 2, 0))
+             axes[1, i].axis('off')
+
+         plt.tight_layout()
+         plt.show()
+
+ def main():
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use CUDA when available, otherwise fall back to CPU
+     print(f"Using device: {device}")
+     dataset = ImageDataset('dataset/images/')
+     dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+     model = aeModel().to(device)
+     #model.load_state_dict(torch.load('autoencoder_250.pth'))  # uncomment to resume from an earlier checkpoint
+     num_epochs = 250
+     train(model, dataloader, num_epochs, device)
+     visualize_results(model, dataloader, device)
+     torch.save(model.state_dict(), 'autoencoder.pth')
+
+ if __name__ == "__main__":
+     main()
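The commented-out load_state_dict call in main() suggests training can be resumed from a previous checkpoint. A minimal sketch of continuing from the uploaded weights; the extra epoch count and the output filename are assumptions, not part of the original script:

import torch
from torch.utils.data import DataLoader
from model import aeModel
from autoencoder import ImageDataset, train  # reuse the dataset and training loop above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = aeModel().to(device)
model.load_state_dict(torch.load("autoencoder.pth", map_location=device))  # warm start from the committed weights

dataloader = DataLoader(ImageDataset("dataset/images/"), batch_size=32, shuffle=True)
train(model, dataloader, num_epochs=50, device=device)  # hypothetical number of additional epochs
torch.save(model.state_dict(), "autoencoder_resumed.pth")  # hypothetical output name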
autoencoder_inf.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from model import aeModel
+
+ def load_model(model_path, device):
+     model = aeModel().to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model
+
+ def process_single_image(image_path, model, device):
+     transform = transforms.Compose([
+         transforms.Resize((64, 64)),
+         transforms.ToTensor(),
+     ])
+
+     image = Image.open(image_path).convert('RGB')
+     image_tensor = transform(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         encoded = model.encode(image_tensor)
+         reconstruction = model.decode(encoded)
+
+     print(f'Original shape: {image_tensor.shape}')
+     print(f'Encoded shape: {encoded.shape}')
+     print(f'Decoded shape: {reconstruction.shape}')
+
+     return image_tensor.squeeze(0).cpu(), reconstruction.squeeze(0).cpu()
+
+ def visualize_original_and_reconstruction(original, reconstruction):
+     original = torch.clamp(original, 0, 1)
+     reconstruction = torch.clamp(reconstruction, 0, 1)
+
+     fig, axes = plt.subplots(1, 2, figsize=(8, 4))
+
+     axes[0].imshow(original.permute(1, 2, 0))
+     axes[0].set_title("Original")
+     axes[0].axis("off")
+
+     axes[1].imshow(reconstruction.permute(1, 2, 0))
+     axes[1].set_title("Decoded")
+     axes[1].axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
+
+ if __name__ == "__main__":
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     model_path = 'autoencoder.pth'
+     model = load_model(model_path, device)
+
+     image_path = r"dataset\images\proof_2.png"
+
+     original, reconstruction = process_single_image(image_path, model, device)
+     visualize_original_and_reconstruction(original, reconstruction)
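The script above reconstructs one image at a time. A small sketch of encoding a whole folder into latents for later reuse, using the same 64x64 preprocessing; the folder and output paths are placeholders:

import os
import torch
from PIL import Image
from torchvision import transforms
from autoencoder_inf import load_model  # reuse the loader defined above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("autoencoder.pth", device)
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

folder = "dataset/images/"  # placeholder path
latents = []
with torch.no_grad():
    for name in sorted(os.listdir(folder)):
        if not name.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        image = Image.open(os.path.join(folder, name)).convert("RGB")
        latents.append(model.encode(transform(image).unsqueeze(0).to(device)).cpu())

torch.save(torch.cat(latents), "latents.pt")  # shape (N, 256, 4, 4) for 64x64 inputs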
model.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ from torch import nn
+
+ class SelfAttention(nn.Module):  # spatial self-attention over the H*W positions of a feature map
+     def __init__(self, in_channels):
+         super(SelfAttention, self).__init__()
+         self.query = nn.Conv2d(in_channels, in_channels//8, 1)
+         self.key = nn.Conv2d(in_channels, in_channels//8, 1)
+         self.value = nn.Conv2d(in_channels, in_channels, 1)
+         self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual gate, starts at 0
+
+     def forward(self, x):
+         batch_size, C, H, W = x.size()
+
+         q = self.query(x).view(batch_size, -1, H*W).permute(0, 2, 1)
+         k = self.key(x).view(batch_size, -1, H*W)
+         v = self.value(x).view(batch_size, -1, H*W)
+
+         attention = torch.bmm(q, k)
+         attention = torch.softmax(attention, dim=-1)
+
+         out = torch.bmm(v, attention.permute(0, 2, 1))
+         out = out.view(batch_size, C, H, W)
+
+         return self.gamma * out + x
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
+         self.bn1 = nn.BatchNorm2d(channels)
+         self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
+         self.bn2 = nn.BatchNorm2d(channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += residual
+         out = self.relu(out)
+         return out
+
+ class aeModel(nn.Module):
+     def __init__(self):
+         super(aeModel, self).__init__()
+
+         self.encoder = nn.ModuleList([  # four stride-2 stages; 3x64x64 inputs become 256x4x4 latents
+             nn.Sequential(
+                 nn.Conv2d(3, 32, 3, stride=2, padding=1),
+                 nn.BatchNorm2d(32),
+                 nn.ReLU(),
+                 ResidualBlock(32)
+             ),
+             nn.Sequential(
+                 nn.Conv2d(32, 64, 3, stride=2, padding=1),
+                 nn.BatchNorm2d(64),
+                 nn.ReLU(),
+                 ResidualBlock(64)
+             ),
+             nn.Sequential(
+                 nn.Conv2d(64, 128, 3, stride=2, padding=1),
+                 nn.BatchNorm2d(128),
+                 nn.ReLU(),
+                 ResidualBlock(128),
+                 SelfAttention(128)
+             ),
+             nn.Sequential(
+                 nn.Conv2d(128, 256, 3, stride=2, padding=1),
+                 nn.BatchNorm2d(256),
+                 nn.ReLU(),
+                 ResidualBlock(256),
+                 SelfAttention(256)
+             )
+         ])
+
+         self.decoder = nn.ModuleList([  # mirror of the encoder, ending in Sigmoid for [0, 1] outputs
+             nn.Sequential(
+                 nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
+                 nn.BatchNorm2d(128),
+                 nn.ReLU(),
+                 ResidualBlock(128),
+                 SelfAttention(128)
+             ),
+             nn.Sequential(
+                 nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
+                 nn.BatchNorm2d(64),
+                 nn.ReLU(),
+                 ResidualBlock(64)
+             ),
+             nn.Sequential(
+                 nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
+                 nn.BatchNorm2d(32),
+                 nn.ReLU(),
+                 ResidualBlock(32)
+             ),
+             nn.Sequential(
+                 nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1),
+                 nn.Sigmoid()
+             )
+         ])
+
+     def forward(self, x):
+         for encoder_block in self.encoder:
+             x = encoder_block(x)
+
+         for decoder_block in self.decoder:
+             x = decoder_block(x)
+
+         return x
+
+     def encode(self, x):
+         for encoder_block in self.encoder:
+             x = encoder_block(x)
+         return x
+
+     def decode(self, x):
+         for decoder_block in self.decoder:
+             x = decoder_block(x)
+         return x
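With 64x64 inputs, the four stride-2 encoder stages shrink the feature map to 4x4 at 256 channels, and the decoder mirrors them back to a 3x64x64 image in [0, 1]. A quick shape check:

import torch
from model import aeModel

model = aeModel().eval()
x = torch.randn(2, 3, 64, 64)  # dummy batch of two 64x64 RGB images
with torch.no_grad():
    z = model.encode(x)  # spatial size 64 -> 32 -> 16 -> 8 -> 4
    y = model.decode(z)

print(z.shape)  # torch.Size([2, 256, 4, 4])
print(y.shape)  # torch.Size([2, 3, 64, 64])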