"""Train an attention residual U-Net for multi-class segmentation of MoS2 images.

Grayscale PNG images are read from ``<root>/images`` and integer ``.npy`` label
maps from ``<root>/labels``. ``EnhancedUNet`` is trained with a focal + Dice
loss while mean IoU and pixel accuracy are tracked. The best model, a rolling
checkpoint and periodic prediction previews are written to
``enhanced_training_results``.
"""

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from scipy.ndimage import gaussian_filter, map_coordinates
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import transforms  # currently unused
from tqdm import tqdm


class ResidualConvBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm + LeakyReLU layers with a residual skip.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. When it differs from
            ``in_channels``, a 1x1 convolution projects the skip path so the
            shapes match for the addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
        self.downsample = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        # Out-of-place add: the in-place LeakyReLU below then acts on a fresh tensor.
        return self.relu(out + residual)


class AttentionGate(nn.Module):
    """Additive attention gate that re-weights skip features per pixel.

    The gating signal ``g`` and skip features ``x`` are each projected to
    ``F_int`` channels with 1x1 convolutions and summed, then passed through
    LeakyReLU. A 1x1 conv, a per-sample InstanceNorm and a sigmoid reduce the
    result to a single-channel weight map in (0, 1). ``g`` and ``x`` must have
    the same spatial size.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, g, x):
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention  # single-channel map broadcasts over x's channels


class EnhancedUNet(nn.Module):
    """Residual U-Net with attention-gated skips and a dilated bottleneck.

    There are four 2x max-pool stages, so the input height and width must be
    divisible by 16 for the skip concatenations to line up.

    Args:
        n_channels: Input image channels.
        n_classes: Number of output classes.

    ``forward`` returns raw logits of shape ``(N, n_classes, H, W)``.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder
        self.inc = ResidualConvBlock(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))

        # Bottleneck: dilated convs enlarge the receptive field; padding equals
        # the dilation, so the spatial size is preserved.
        self.dilation = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True),
        )

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.up_conv4 = ResidualConvBlock(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.up_conv3 = ResidualConvBlock(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.up_conv2 = ResidualConvBlock(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.up_conv1 = ResidualConvBlock(128, 64)

        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _up_step(x, skip, up, att, conv):
        """Upsample ``x``, gate ``skip`` with it, concatenate and fuse."""
        x = up(x)
        skip = att(g=x, x=skip)
        return conv(torch.cat([skip, x], dim=1))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.dropout(self.down1(x1))
        x3 = self.dropout(self.down2(x2))
        x4 = self.dropout(self.down3(x3))
        x5 = self.dropout(self.dilation(self.down4(x4)))

        x = self.dropout(self._up_step(x5, x4, self.up4, self.att4, self.up_conv4))
        x = self.dropout(self._up_step(x, x3, self.up3, self.att3, self.up_conv3))
        x = self.dropout(self._up_step(x, x2, self.up2, self.att2, self.up_conv2))
        x = self._up_step(x, x1, self.up1, self.att1, self.up_conv1)
        return self.outc(x)


class MoS2Dataset(Dataset):
    """Grayscale PNG images paired with integer ``.npy`` label maps.

    Images are read from ``<root_dir>/images/*.png`` and labels from
    ``<root_dir>/labels``. The label name is built from the token after the
    first underscore of the image name, e.g. ``image_0001.png`` maps to
    ``image_0001.npy``. At construction time, images that fail PIL
    verification or have no matching label file are skipped.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable ``(image, label) -> (image, label)``.
            It receives the float32 image scaled to [0, 1] and the int64
            label array, e.g. an ``AugmentationTransform``.

    Items are ``(image, label)`` with ``image`` a float tensor of shape
    ``(1, H, W)`` and ``label`` a long tensor. When a sample fails to load,
    ``(None, None)`` is returned instead; ``collate_skip_none`` drops it.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        self.labels_dir = os.path.join(root_dir, 'labels')
        self.image_files = []
        for f in sorted(os.listdir(self.images_dir)):
            if not f.endswith('.png'):
                continue
            try:
                # The context manager releases the file handle after verify().
                with Image.open(os.path.join(self.images_dir, f)) as img:
                    img.verify()
            except Exception:  # Pillow's verify() raises varying exception types
                print(f"Skipping unreadable image: {f}")
                continue
            label_name = self._label_name(f)
            if label_name is None or not os.path.exists(
                os.path.join(self.labels_dir, label_name)
            ):
                print(f"Skipping image without label: {f}")
                continue
            self.image_files.append(f)

    @staticmethod
    def _label_name(img_name):
        """Return the label file name for ``img_name``, or None if underivable."""
        parts = img_name.split('_')
        if len(parts) < 2:
            return None
        return f"image_{parts[1].replace('.png', '.npy')}"

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        label_path = os.path.join(self.labels_dir, self._label_name(img_name))
        try:
            with Image.open(img_path) as img:
                image = np.array(img.convert('L'), dtype=np.float32) / 255.0
            label = np.load(label_path).astype(np.int64)
        except (PIL.UnidentifiedImageError, OSError, ValueError) as e:
            # Files may vanish or be corrupted after __init__ checked them;
            # the sample is dropped by collate_skip_none.
            print(f"Error loading image {img_path}: {str(e)}")
            return None, None

        if self.transform:
            image, label = self.transform(image, label)

        image = torch.from_numpy(image).float().unsqueeze(0)
        label = torch.from_numpy(label).long()
        return image, label


def collate_skip_none(batch):
    """Collate a batch after dropping samples that ``MoS2Dataset`` failed to load.

    Returns ``None`` when every sample in the batch was dropped.
    """
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        return None
    return default_collate(batch)


class AugmentationTransform:
    """Random photometric and elastic augmentation for an (image, label) pair.

    Each augmentation is applied independently with probability 0.5. The
    image is expected to be a float array in [0, 1]. The label is only
    touched by the elastic deformation, which uses nearest-neighbour sampling
    so class ids are preserved. The image is returned as float32.
    """

    def __init__(self):
        self.aug_functions = [
            self.random_brightness_contrast,
            self.random_gamma,
            self.random_noise,
            self.random_elastic_deform,
        ]

    def __call__(self, image, label):
        for aug_func in self.aug_functions:
            if random.random() < 0.5:  # 50% chance to apply each augmentation
                image, label = aug_func(image, label)
        return image.astype(np.float32), label  # noise/deform may upcast to float64

    def random_brightness_contrast(self, image, label):
        brightness = random.uniform(0.7, 1.3)
        contrast = random.uniform(0.7, 1.3)
        # Stretch around mid-grey, then scale brightness. The pair (1, 1)
        # leaves the image unchanged.
        image = np.clip((contrast * (image - 0.5) + 0.5) * brightness, 0, 1)
        return image, label

    def random_gamma(self, image, label):
        gamma = random.uniform(0.7, 1.3)
        image = np.power(image, gamma)
        return image, label

    def random_noise(self, image, label):
        noise = np.random.normal(0, 0.05, image.shape)
        image = np.clip(image + noise, 0, 1)
        return image, label

    def random_elastic_deform(self, image, label):
        alpha = random.uniform(10, 20)
        sigma = random.uniform(3, 5)
        shape = image.shape
        # Smooth random displacement fields, scaled by alpha (in pixels).
        dx = gaussian_filter(np.random.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter(np.random.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
        image = map_coordinates(image, indices, order=1).reshape(shape)
        # order=0 (nearest) keeps labels as valid class ids.
        label = map_coordinates(label, indices, order=0).reshape(shape)
        return image, label


def focal_loss(output, target, alpha=0.25, gamma=2):
    """Mean focal loss over all pixels.

    Args:
        output: Logits of shape ``(N, C, H, W)``.
        target: Class indices of shape ``(N, H, W)``.
        alpha: Scalar weight applied to every class.
        gamma: Focusing exponent; larger values down-weight easy pixels.
    """
    ce_loss = nn.CrossEntropyLoss(reduction='none')(output, target)
    pt = torch.exp(-ce_loss)  # probability assigned to the true class
    return (alpha * (1 - pt) ** gamma * ce_loss).mean()


def dice_loss(output, target, smooth=1e-5):
    """Return ``1 - mean soft Dice``, averaged over classes.

    Each class's Dice coefficient is computed over the whole batch.

    Args:
        output: Logits of shape ``(N, C, H, W)``; softmax is applied here.
        target: Class indices of shape ``(N, H, W)``.
        smooth: Additive smoothing that avoids division by zero.
    """
    probs = torch.softmax(output, dim=1)
    num_classes = probs.shape[1]
    dice_sum = 0
    for c in range(num_classes):
        pred_class = probs[:, c, :, :]
        target_class = (target == c).float()
        intersection = (pred_class * target_class).sum()
        union = pred_class.sum() + target_class.sum()
        dice_sum += (2. * intersection + smooth) / (union + smooth)
    return 1 - dice_sum / num_classes


def combined_loss(output, target):
    """Equal-weight sum of focal loss and Dice loss."""
    return 0.5 * focal_loss(output, target) + 0.5 * dice_loss(output, target)


def iou_score(output, target):
    """Mean intersection-over-union across classes and images.

    Args:
        output: Logits of shape ``(N, C, H, W)``.
        target: Class indices of shape ``(N, H, W)``.

    Returns:
        A float. Each (class, image) pair counts only if the class appears in
        the prediction or the target. Returns 1.0 if no class appears at all.
    """
    num_classes = output.shape[1]
    preds = torch.argmax(output, dim=1)
    classes = torch.arange(num_classes, device=preds.device).view(-1, 1, 1, 1)
    pred_masks = preds.unsqueeze(0) == classes      # (C, N, H, W)
    target_masks = target.unsqueeze(0) == classes   # (C, N, H, W)
    intersection = (pred_masks & target_masks).sum(dim=(2, 3))  # (C, N)
    union = (pred_masks | target_masks).sum(dim=(2, 3))         # (C, N)
    present = union > 0
    if not present.any():
        return 1.0
    return (intersection[present].float() / union[present].float()).mean().item()


def pixel_accuracy(output, target):
    """Fraction of pixels whose argmax class equals the target, as a float."""
    preds = torch.argmax(output, dim=1)
    return torch.eq(preds, target).sum().item() / target.numel()


def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader`` and return mean (loss, IoU, accuracy).

    The model is trained when ``optimizer`` is given; otherwise it is
    evaluated without gradients. Batches that ``collate_skip_none`` emptied
    (``None``) are skipped and not counted in the averages.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = total_iou = total_accuracy = 0.0
    num_batches = 0
    pbar = tqdm(dataloader, desc='Training' if training else 'Validation')
    with torch.set_grad_enabled(training):
        for batch in pbar:
            if batch is None:
                continue
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            num_batches += 1
            total_loss += loss.item()
            total_iou += iou_score(outputs.detach(), labels)
            total_accuracy += pixel_accuracy(outputs.detach(), labels)
            # Explicit counter: tqdm only updates pbar.n on refresh.
            pbar.set_postfix({'Loss': total_loss / num_batches,
                              'IoU': total_iou / num_batches,
                              'Accuracy': total_accuracy / num_batches})

    denom = max(num_batches, 1)
    return total_loss / denom, total_iou / denom, total_accuracy / denom


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train for one epoch; return mean (loss, IoU, accuracy)."""
    return _run_epoch(model, dataloader, criterion, device, optimizer=optimizer)


def validate(model, dataloader, criterion, device):
    """Evaluate without gradients; return mean (loss, IoU, accuracy)."""
    return _run_epoch(model, dataloader, criterion, device)


def visualize_prediction(model, val_loader, device, epoch, save_dir):
    """Save input / ground truth / prediction panels for the first validation batch.

    Up to two samples are plotted, one per row. The file is written to
    ``<save_dir>/prediction_epoch_<epoch>.png``. Nothing is written if the
    loader yields no usable batch.
    """
    model.eval()
    batch = next((b for b in val_loader if b is not None), None)
    if batch is None:
        return
    images, labels = batch
    with torch.no_grad():
        outputs = model(images.to(device))
    num_classes = outputs.shape[1]
    predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    images = images.numpy()
    labels = labels.numpy()

    num_rows = min(2, images.shape[0])
    fig, axs = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)
    # A fixed colour range makes a given class the same colour in both panels.
    class_range = {'cmap': 'viridis', 'vmin': 0, 'vmax': num_classes - 1}
    for row in range(num_rows):
        axs[row, 0].imshow(images[row, 0], cmap='gray')
        axs[row, 1].imshow(labels[row], **class_range)
        axs[row, 2].imshow(predictions[row], **class_range)
    for ax, title in zip(axs[0], ('Input Image', 'True Label', 'Prediction')):
        ax.set_title(title)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'prediction_epoch_{epoch}.png'))
    plt.close(fig)


def main():
    """Train EnhancedUNet on ``dataset_with_noise_npy`` and save the results."""
    # Hyperparameters
    num_classes = 4
    batch_size = 64
    num_epochs = 100
    learning_rate = 1e-4
    weight_decay = 1e-5
    visualize_every = 5
    split_seed = 42

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Augmentation is currently disabled. To enable it, pass
    # transform=AugmentationTransform(). Note that random_split returns subsets
    # of one shared dataset, so the validation subset would then be augmented
    # too unless separate dataset instances are built for train and val.
    dataset = MoS2Dataset('dataset_with_noise_npy')
    if len(dataset) == 0:
        raise RuntimeError("No usable image/label pairs found in 'dataset_with_noise_npy'")

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # A fixed seed keeps the split identical across runs, so the validation
    # IoU stays comparable and no training images leak into validation.
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=4,
        collate_fn=collate_skip_none,
        pin_memory=device.type == 'cuda',
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Model, loss and optimizer
    model = EnhancedUNet(n_channels=1, n_classes=num_classes).to(device)
    criterion = combined_loss
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # `verbose` is deprecated for LR schedulers; LR drops are printed below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10)

    # Directory for saved models and visualizations
    save_dir = 'enhanced_training_results'
    os.makedirs(save_dir, exist_ok=True)

    best_val_iou = 0.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_iou, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_iou, val_accuracy = validate(model, val_loader, criterion, device)
        print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Accuracy: {train_accuracy:.4f}")
        print(f"Val - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Accuracy: {val_accuracy:.4f}")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_iou)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            print(f"New best model saved with IoU: {best_val_iou:.4f}")

        # Overwrite a single rolling checkpoint. With the AdamW state included,
        # each one can be hundreds of MB, so one file per epoch fills the disk.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
        }, os.path.join(save_dir, 'last_checkpoint.pth'))

        # Visualize predictions every `visualize_every` epochs.
        if epoch % visualize_every == 0:
            visualize_prediction(model, val_loader, device, epoch, save_dir)

    print("Training completed!")


if __name__ == "__main__":
    main()