|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
import os
|
|
from tqdm import tqdm
|
|
import matplotlib.pyplot as plt
|
|
from model import aeModel
|
|
|
|
class ImageDataset(Dataset):
    """Unlabelled dataset of image tensors loaded from one flat folder.

    Only regular files whose names end in ``.jpg``, ``.jpeg`` or ``.png``
    (in any letter case) are included; subdirectories are not searched.
    Each item is the image converted to RGB and passed through
    ``self.transform``.
    """

    # Accepted file-name suffixes, compared against the lower-cased file name.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder_path, image_size=(64, 64), transform=None):
        """Index the image files found directly inside ``folder_path``.

        Args:
            folder_path: Directory containing the image files.
            image_size: Size passed to ``transforms.Resize`` when the default
                transform is built (default ``(64, 64)``). Ignored when
                ``transform`` is given.
            transform: Optional callable applied to each RGB PIL image.
                Defaults to resizing to ``image_size`` followed by
                ``transforms.ToTensor()``.
        """
        self.folder_path = folder_path
        # Sort so indices map to the same files on every run/filesystem
        # (os.listdir order is arbitrary). Lower-case the name so files such
        # as "photo.JPG" are not silently skipped, and skip directories.
        self.image_files = sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it after ``self.transform``."""
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # The context manager closes the file handle as soon as convert() has
        # decoded the pixels, instead of relying on garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        return self.transform(image)
|
|
|
|
def train(model, dataloader, num_epochs, device, lr=1e-3):
    """Train ``model`` to reconstruct its input by minimising MSE.

    Each batch from ``dataloader`` is both the model input and the target.
    A fresh Adam optimizer is created on every call.

    Args:
        model: Module whose output is compared element-wise with its input.
        dataloader: Iterable of input tensors (no labels).
        num_epochs: Number of full passes over ``dataloader``.
        device: Device each batch is moved to; presumably the same device
            ``model`` lives on — verify at call sites.
        lr: Adam learning rate (default ``1e-3``).

    Returns:
        List with the per-sample mean training loss of each epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_samples = 0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # MSELoss returns a batch mean; weight it by the batch size so the
            # epoch average is a true per-sample mean even when the final
            # batch is smaller than the rest.
            batch_size = batch.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

        if num_samples == 0:
            raise ValueError('dataloader yielded no samples; nothing to train on')
        avg_loss = total_loss / num_samples
        history.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    return history
|
|
|
|
|
|
def visualize_results(model, dataloader, device, num_images=5):
    """Plot originals (top row) against reconstructions (bottom row).

    Uses only the first batch drawn from ``dataloader``, and blocks in
    ``plt.show()`` until the window is closed.

    Args:
        model: Autoencoder to evaluate; switched to eval mode by this call.
        dataloader: Source of the batch. Each item is indexed and permuted
            with ``permute(1, 2, 0)``, so presumably (C, H, W) images with
            values in [0, 1] — confirm for other datasets.
        device: Device the batch is moved to before the forward pass.
        num_images: Maximum number of image pairs to show (default 5). Fewer
            are shown when the first batch is smaller.
    """
    model.eval()
    with torch.no_grad():
        images = next(iter(dataloader)).to(device)
        reconstructions = model(images)

    count = min(num_images, images.size(0))
    if count == 0:
        return
    # squeeze=False keeps ``axes`` 2-D even when a single column is drawn.
    fig, axes = plt.subplots(2, count, figsize=(12, 6), squeeze=False)
    for i in range(count):
        for row, batch in enumerate((images, reconstructions)):
            # The decoder output is not guaranteed to stay within [0, 1];
            # clamping avoids imshow's clipping warning for float RGB data.
            axes[row, i].imshow(batch[i].clamp(0, 1).cpu().permute(1, 2, 0))
            axes[row, i].axis('off')

    plt.tight_layout()
    plt.show()
|
|
|
|
def main(data_dir='dataset/images/', num_epochs=250, batch_size=32,
         checkpoint_path='autoencoder.pth'):
    """Train the autoencoder, save its weights, then show sample reconstructions.

    Args:
        data_dir: Folder of training images (default ``'dataset/images/'``).
        num_epochs: Number of training epochs (default 250).
        batch_size: DataLoader batch size (default 32).
        checkpoint_path: Where ``model.state_dict()`` is saved
            (default ``'autoencoder.pth'``).

    Raises:
        ValueError: If ``data_dir`` contains no usable images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset = ImageDataset(data_dir)
    if len(dataset) == 0:
        # Fail early with a clear message rather than an opaque sampler error.
        raise ValueError(f"No .jpg/.jpeg/.png images found in {data_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = aeModel().to(device)

    train(model, dataloader, num_epochs, device)
    # Save before visualising: plt.show() blocks, and a failure while plotting
    # must not discard the weights from a long training run.
    torch.save(model.state_dict(), checkpoint_path)
    visualize_results(model, dataloader, device)


if __name__ == "__main__":
    main()