File size: 2,839 Bytes
d1ef8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from model import aeModel

class ImageDataset(Dataset):
    """Dataset of RGB images loaded from a flat folder as 64x64 tensors.

    Each item is a single image tensor of shape (3, 64, 64) with values in
    [0, 1]; there are no labels (intended for autoencoder training).
    """

    # Recognized image extensions; compared case-insensitively.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder_path):
        """
        Args:
            folder_path: directory containing the images (not searched
                recursively).
        """
        self.folder_path = folder_path
        # lower() so '.JPG' / '.PNG' files are not silently skipped;
        # sorted() so dataset order is deterministic across filesystems
        # (os.listdir order is arbitrary).
        self.image_files = sorted(
            f for f in os.listdir(folder_path)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert to RGB, and transform the image at index `idx`."""
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # convert('RGB') normalizes grayscale/RGBA/palette inputs to 3 channels.
        image = Image.open(img_path).convert('RGB')
        return self.transform(image)

def train(model, dataloader, num_epochs, device):
    """Train `model` to reconstruct its input with MSE loss (Adam, lr=1e-3).

    Args:
        model: autoencoder mapping a batch to a reconstruction of the same
            shape; must already live on `device`.
        dataloader: yields unlabeled image batches.
        num_epochs: number of full passes over the dataset.
        device: torch device batches are moved to before the forward pass.

    Returns:
        List of per-epoch average losses (length `num_epochs`). Callers that
        ignore the return value are unaffected.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)
            output = model(batch)
            # Reconstruction target is the input itself.
            loss = criterion(output, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Average of per-batch means; slightly biased if the last batch is
        # smaller, matching the original reporting.
        avg_loss = total_loss / len(dataloader)
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    return epoch_losses


def visualize_results(model, dataloader, device):
    """Plot originals (top row) against reconstructions (bottom row).

    Takes the first batch from `dataloader` and shows up to 5 image pairs.
    Runs under `torch.no_grad()` in eval mode; blocks on `plt.show()`.
    """
    model.eval()
    with torch.no_grad():
        images = next(iter(dataloader))
        images = images.to(device)
        reconstructions = model(images)

        # Guard against a first batch smaller than 5 (small dataset or
        # drop_last=False remainder) — the original indexed past the batch.
        n = min(5, images.size(0))
        # squeeze=False keeps axes 2-D even when n == 1.
        fig, axes = plt.subplots(2, n, figsize=(12, 6), squeeze=False)
        for i in range(n):
            # permute CHW -> HWC for matplotlib.
            axes[0, i].imshow(images[i].cpu().permute(1, 2, 0))
            axes[0, i].axis('off')
            axes[1, i].imshow(reconstructions[i].cpu().permute(1, 2, 0))
            axes[1, i].axis('off')

        plt.tight_layout()
        plt.show()

def main():
    """Entry point: build the dataset, train the autoencoder, save weights."""
    # Use the GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_data = ImageDataset('dataset/images/')
    loader = DataLoader(train_data, batch_size=32, shuffle=True)

    model = aeModel().to(device)
    # Uncomment to resume from a previous checkpoint:
    #model.load_state_dict(torch.load('autoencoder_250.pth'))

    num_epochs = 250
    train(model, loader, num_epochs, device)
    visualize_results(model, loader, device)

    # Persist trained weights for later inference or fine-tuning.
    torch.save(model.state_dict(), 'autoencoder.pth')

if __name__ == "__main__":
    main()