import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from model import aeModel


class ImageDataset(Dataset):
    """Loads RGB images from a folder and resizes them to 64x64 tensors."""

    def __init__(self, folder_path):
        self.folder_path = folder_path
        # Case-insensitive extension check so .JPG / .PNG files are not skipped.
        self.image_files = [f for f in os.listdir(folder_path)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),  # scales pixel values to [0, 1]
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)
        return image


def train(model, dataloader, num_epochs, device):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)

            # Reconstruction objective: the autoencoder is trained to reproduce its input.
            output = model(batch)
            loss = criterion(output, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')


def visualize_results(model, dataloader, device):
    """Show the first five images of a batch (top row) next to their reconstructions (bottom row)."""
    model.eval()
    with torch.no_grad():
        images = next(iter(dataloader))
        images = images.to(device)
        reconstructions = model(images)

    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    for i in range(5):
        axes[0, i].imshow(images[i].cpu().permute(1, 2, 0))
        axes[0, i].axis('off')
        # Clamp to [0, 1] in case the model output is not already bounded.
        axes[1, i].imshow(reconstructions[i].cpu().permute(1, 2, 0).clamp(0, 1))
        axes[1, i].axis('off')
    plt.tight_layout()
    plt.show()


def main():
    # Prefer a CUDA GPU when one is available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset = ImageDataset('dataset/images/')
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = aeModel().to(device)
    # Optionally resume from a previous checkpoint:
    # model.load_state_dict(torch.load('autoencoder_250.pth'))

    num_epochs = 250
    train(model, dataloader, num_epochs, device)
    visualize_results(model, dataloader, device)

    torch.save(model.state_dict(), 'autoencoder.pth')


if __name__ == "__main__":
    main()
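

# ---------------------------------------------------------------------------
# Note: `aeModel` is imported from a local `model` module that is not shown
# in this file. The class below is a hypothetical sketch of a compatible
# implementation (a small convolutional autoencoder for 64x64 RGB images),
# included only as a reference; the author's actual architecture may differ.
# In practice such a class would live in `model.py`, not in this script.
# ---------------------------------------------------------------------------
class _SketchAeModel(nn.Module):
    """Illustrative stand-in for `model.aeModel`: maps (B, 3, 64, 64) -> (B, 3, 64, 64)."""

    def __init__(self):
        super().__init__()
        # Encoder: 64x64 -> 32x32 -> 16x16 -> 8x8, growing the channel depth.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),    # 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 8x8
            nn.ReLU(),
        )
        # Decoder mirrors the encoder with transposed convolutions.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),   # 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),    # 64x64
            nn.Sigmoid(),  # keeps outputs in [0, 1], matching the ToTensor() inputs
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))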