Create train_mlp.py
Browse files- train_mlp.py +141 -0
train_mlp.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from PIL import Image
|
8 |
+
from datasets import load_dataset
|
9 |
+
|
10 |
+
# Define the MLP model
|
11 |
+
class MLP(nn.Module):
|
12 |
+
def __init__(self, input_size, hidden_sizes, output_size):
|
13 |
+
super(MLP, self).__init__()
|
14 |
+
layers = []
|
15 |
+
sizes = [input_size] + hidden_sizes + [output_size]
|
16 |
+
for i in range(len(sizes) - 1):
|
17 |
+
layers.append(nn.Linear(sizes[i], sizes[i+1]))
|
18 |
+
if i < len(sizes) - 2:
|
19 |
+
layers.append(nn.ReLU())
|
20 |
+
self.model = nn.Sequential(*layers)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return self.model(x)
|
24 |
+
|
25 |
+
# Preprocess the images
|
26 |
+
def preprocess_image(example, image_size):
|
27 |
+
image = Image.open(example['image_path']).convert('RGB')
|
28 |
+
transform = transforms.Compose([
|
29 |
+
transforms.Resize((image_size, image_size)),
|
30 |
+
transforms.ToTensor(),
|
31 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
32 |
+
])
|
33 |
+
image = transform(image)
|
34 |
+
return {'image': image, 'label': example['label']}
|
35 |
+
|
36 |
+
# Train the model
|
37 |
+
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
|
38 |
+
criterion = nn.CrossEntropyLoss()
|
39 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
40 |
+
|
41 |
+
for epoch in range(epochs):
|
42 |
+
model.train()
|
43 |
+
running_loss = 0.0
|
44 |
+
for batch in train_loader:
|
45 |
+
inputs = batch['image'].view(batch['image'].size(0), -1)
|
46 |
+
labels = batch['label']
|
47 |
+
|
48 |
+
optimizer.zero_grad()
|
49 |
+
outputs = model(inputs)
|
50 |
+
loss = criterion(outputs, labels)
|
51 |
+
loss.backward()
|
52 |
+
optimizer.step()
|
53 |
+
|
54 |
+
running_loss += loss.item()
|
55 |
+
|
56 |
+
print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
|
57 |
+
|
58 |
+
# Validation
|
59 |
+
model.eval()
|
60 |
+
val_loss = 0.0
|
61 |
+
correct = 0
|
62 |
+
total = 0
|
63 |
+
with torch.no_grad():
|
64 |
+
for batch in val_loader:
|
65 |
+
inputs = batch['image'].view(batch['image'].size(0), -1)
|
66 |
+
labels = batch['label']
|
67 |
+
|
68 |
+
outputs = model(inputs)
|
69 |
+
loss = criterion(outputs, labels)
|
70 |
+
val_loss += loss.item()
|
71 |
+
|
72 |
+
_, predicted = torch.max(outputs.data, 1)
|
73 |
+
total += labels.size(0)
|
74 |
+
correct += (predicted == labels).sum().item()
|
75 |
+
|
76 |
+
print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {100 * correct / total}%')
|
77 |
+
|
78 |
+
return val_loss / len(val_loader)
|
79 |
+
|
80 |
+
# Main function
|
81 |
+
def main():
|
82 |
+
parser = argparse.ArgumentParser(description='Train an MLP on a Hugging Face dataset with JPEG images and class labels.')
|
83 |
+
parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
|
84 |
+
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
# Load the dataset
|
88 |
+
dataset = load_dataset('your_dataset_name')
|
89 |
+
train_dataset = dataset['train']
|
90 |
+
val_dataset = dataset['validation']
|
91 |
+
|
92 |
+
# Determine the number of classes
|
93 |
+
num_classes = len(set(train_dataset['label']))
|
94 |
+
|
95 |
+
# Determine the fixed resolution of the images
|
96 |
+
example_image = Image.open(train_dataset[0]['image_path'])
|
97 |
+
image_size = example_image.size[0] # Assuming the images are square
|
98 |
+
|
99 |
+
# Preprocess the dataset
|
100 |
+
train_dataset = train_dataset.map(lambda x: preprocess_image(x, image_size))
|
101 |
+
val_dataset = val_dataset.map(lambda x: preprocess_image(x, image_size))
|
102 |
+
|
103 |
+
# Define the model
|
104 |
+
input_size = image_size * image_size * 3
|
105 |
+
hidden_sizes = [args.width] * args.layer_count
|
106 |
+
output_size = num_classes
|
107 |
+
|
108 |
+
model = MLP(input_size, hidden_sizes, output_size)
|
109 |
+
|
110 |
+
# Create data loaders
|
111 |
+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
|
112 |
+
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
|
113 |
+
|
114 |
+
# Train the model and get the final loss
|
115 |
+
final_loss = train_model(model, train_loader, val_loader)
|
116 |
+
|
117 |
+
# Calculate the number of parameters
|
118 |
+
param_count = sum(p.numel() for p in model.parameters())
|
119 |
+
|
120 |
+
# Create the folder for the model
|
121 |
+
model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
|
122 |
+
os.makedirs(model_folder, exist_ok=True)
|
123 |
+
|
124 |
+
# Save the model
|
125 |
+
model_path = os.path.join(model_folder, 'model.pth')
|
126 |
+
torch.save(model.state_dict(), model_path)
|
127 |
+
|
128 |
+
# Write the results to a text file in the model folder
|
129 |
+
result_path = os.path.join(model_folder, 'results.txt')
|
130 |
+
with open(result_path, 'w') as f:
|
131 |
+
f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')
|
132 |
+
|
133 |
+
# Save a duplicate of the results in the 'results' folder
|
134 |
+
results_folder = 'results'
|
135 |
+
os.makedirs(results_folder, exist_ok=True)
|
136 |
+
duplicate_result_path = os.path.join(results_folder, f'results_l{args.layer_count}w{args.width}.txt')
|
137 |
+
with open(duplicate_result_path, 'w') as f:
|
138 |
+
f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')
|
139 |
+
|
140 |
+
if __name__ == '__main__':
|
141 |
+
main()
|