TeacherPuffy committed · verified
Commit a989bbc · 1 Parent(s): f7421f7

Update train_mlp.py

Files changed (1)
  1. train_mlp.py +31 -31
train_mlp.py CHANGED
@@ -3,9 +3,7 @@ import os
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torchvision.transforms as transforms
-from PIL import Image
-from datasets import load_dataset
+from datasets import load_from_disk
 
 # Define the MLP model
 class MLP(nn.Module):
@@ -22,22 +20,20 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.model(x)
 
-# Preprocess the images
-def preprocess_image(example, image_size):
-    image = Image.open(example['image_path']).convert('RGB')
-    transform = transforms.Compose([
-        transforms.Resize((image_size, image_size)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-    ])
-    image = transform(image)
-    return {'image': image, 'label': example['label']}
+# Custom collate function
+def custom_collate(batch):
+    images = torch.stack([item['image'] for item in batch])
+    labels = torch.tensor([item['label'] for item in batch])
+    return {'image': images, 'label': labels}
 
 # Train the model
-def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
+def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=lr)
 
+    train_losses = []
+    val_losses = []
+
     for epoch in range(epochs):
         model.train()
         running_loss = 0.0
@@ -53,7 +49,9 @@ def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
 
             running_loss += loss.item()
 
-        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
+        avg_train_loss = running_loss / len(train_loader)
+        train_losses.append(avg_train_loss)
+        print(f'Epoch {epoch+1}, Loss: {avg_train_loss}')
 
         # Validation
         model.eval()
@@ -73,9 +71,16 @@ def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
                 total += labels.size(0)
                 correct += (predicted == labels).sum().item()
 
-        print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {100 * correct / total}%')
+        avg_val_loss = val_loss / len(val_loader)
+        val_losses.append(avg_val_loss)
+        print(f'Validation Loss: {avg_val_loss}, Accuracy: {100 * correct / total}%')
+
+        if save_loss_path:
+            with open(save_loss_path, 'w') as f:
+                for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
+                    f.write(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}\n')
 
-    return val_loss / len(val_loader)
+    return avg_val_loss
 
 # Main function
 def main():
@@ -84,21 +89,15 @@ def main():
     parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
     args = parser.parse_args()
 
-    # Load the dataset
-    dataset = load_dataset('your_dataset_name')
-    train_dataset = dataset['train']
-    val_dataset = dataset['validation']
+    # Load the preprocessed datasets
+    train_dataset = load_from_disk('preprocessed_train_dataset')
+    val_dataset = load_from_disk('preprocessed_val_dataset')
 
     # Determine the number of classes
     num_classes = len(set(train_dataset['label']))
 
     # Determine the fixed resolution of the images
-    example_image = Image.open(train_dataset[0]['image_path'])
-    image_size = example_image.size[0]  # Assuming the images are square
-
-    # Preprocess the dataset
-    train_dataset = train_dataset.map(lambda x: preprocess_image(x, image_size))
-    val_dataset = val_dataset.map(lambda x: preprocess_image(x, image_size))
+    image_size = train_dataset[0]['image'].size(1)  # Assuming the images are square
 
     # Define the model
     input_size = image_size * image_size * 3
@@ -107,12 +106,13 @@ def main():
 
     model = MLP(input_size, hidden_sizes, output_size)
 
-    # Create data loaders
-    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
-    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
+    # Create data loaders with custom collate function
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate)
 
     # Train the model and get the final loss
-    final_loss = train_model(model, train_loader, val_loader)
+    save_loss_path = 'losses.txt'
+    final_loss = train_model(model, train_loader, val_loader, save_loss_path=save_loss_path)
 
     # Calculate the number of parameters
     param_count = sum(p.numel() for p in model.parameters())
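
For context, the new load_from_disk() calls assume a separate preprocessing step whose output lives in preprocessed_train_dataset/ and preprocessed_val_dataset/. A minimal sketch of such a companion script is below, reusing the transform this commit removed from train_mlp.py; the script name, the fixed resolution of 224, and the 'your_dataset_name' / 'image_path' identifiers are placeholders carried over from the old code, not part of this commit:

# preprocess_mlp_data.py -- hypothetical companion script (not part of this
# commit): it recreates the preprocessing removed from train_mlp.py and
# saves the result where load_from_disk() now expects to find it.
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_dataset

image_size = 224  # assumed fixed square resolution; the old code read it from the first image

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_image(example):
    # Same per-example logic as the preprocess_image() removed above
    image = Image.open(example['image_path']).convert('RGB')
    return {'image': transform(image), 'label': example['label']}

dataset = load_dataset('your_dataset_name')  # placeholder name kept from the old code
for split, out_dir in [('train', 'preprocessed_train_dataset'),
                       ('validation', 'preprocessed_val_dataset')]:
    processed = dataset[split].map(preprocess_image)
    # Tensor-format the columns: custom_collate() calls torch.stack on
    # item['image'], and main() reads train_dataset[0]['image'].size(1),
    # both of which need torch.Tensors rather than nested Python lists.
    processed.set_format(type='torch', columns=['image', 'label'])
    processed.save_to_disk(out_dir)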