# MLPScaling/train_mlp.py
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_dataset

# Define the MLP model
class MLP(nn.Module):
    """A plain multilayer perceptron: Linear layers with ReLU between hidden layers."""

    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        layers = []
        sizes = [input_size] + hidden_sizes + [output_size]
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1]))
            # No activation after the final layer; CrossEntropyLoss expects raw logits.
            if i < len(sizes) - 2:
                layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
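
# Quick sanity check (hypothetical shapes, not part of the training run):
# a 2-hidden-layer MLP for 32x32 RGB inputs and 10 classes takes flattened
# vectors of size 32 * 32 * 3 = 3072 and returns one logit per class.
#
#     mlp = MLP(input_size=3072, hidden_sizes=[512, 512], output_size=10)
#     logits = mlp(torch.randn(8, 3072))  # -> torch.Size([8, 10])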

# Preprocess the images
def preprocess_image(example, image_size):
    # Load the JPEG from disk and force 3 channels.
    image = Image.open(example['image_path']).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # ImageNet channel statistics, a common default for normalizing RGB inputs.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image)
    return {'image': image, 'label': example['label']}
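
# Note: after this transform, example['image'] is a float tensor of shape
# (3, image_size, image_size); the training loop below flattens it to a vector
# of length 3 * image_size * image_size before it reaches the MLP.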

# Train the model
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        # Training pass
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            # Flatten each image from (C, H, W) to one vector per sample.
            inputs = batch['image'].view(batch['image'].size(0), -1)
            labels = batch['label']
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}')

        # Validation pass
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['image'].view(batch['image'].size(0), -1)
                labels = batch['label']
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f'Validation Loss: {val_loss / len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
    # Return the average validation loss from the final epoch.
    return val_loss / len(val_loader)
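
# Example usage (hypothetical; assumes loaders that yield {'image': ..., 'label': ...}
# batches, as built in main() below):
#
#     final_val_loss = train_model(model, train_loader, val_loader, epochs=10, lr=1e-3)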

# Main function
def main():
    parser = argparse.ArgumentParser(description='Train an MLP on a Hugging Face dataset with JPEG images and class labels.')
    parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
    parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
    args = parser.parse_args()

    # Load the dataset
    dataset = load_dataset('your_dataset_name')
    train_dataset = dataset['train']
    val_dataset = dataset['validation']

    # Determine the number of classes
    num_classes = len(set(train_dataset['label']))

    # Determine the fixed resolution of the images
    example_image = Image.open(train_dataset[0]['image_path'])
    image_size = example_image.size[0]  # Assuming the images are square

    # Preprocess the dataset
    train_dataset = train_dataset.map(lambda x: preprocess_image(x, image_size))
    val_dataset = val_dataset.map(lambda x: preprocess_image(x, image_size))
    # map() serializes tensors to nested lists; set_format converts the columns
    # back to torch tensors so the default DataLoader collate yields batched tensors.
    train_dataset.set_format(type='torch', columns=['image', 'label'])
    val_dataset.set_format(type='torch', columns=['image', 'label'])

    # Define the model
    input_size = image_size * image_size * 3  # Flattened RGB image
    hidden_sizes = [args.width] * args.layer_count
    output_size = num_classes
    model = MLP(input_size, hidden_sizes, output_size)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
    # Train the model and get the final validation loss
    final_loss = train_model(model, train_loader, val_loader)

    # Calculate the number of parameters
    param_count = sum(p.numel() for p in model.parameters())

    # Create the folder for the model
    model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
    os.makedirs(model_folder, exist_ok=True)

    # Save the model
    model_path = os.path.join(model_folder, 'model.pth')
    torch.save(model.state_dict(), model_path)

    # Write the results to a text file in the model folder
    result_path = os.path.join(model_folder, 'results.txt')
    with open(result_path, 'w') as f:
        f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')

    # Save a duplicate of the results in the 'results' folder
    results_folder = 'results'
    os.makedirs(results_folder, exist_ok=True)
    duplicate_result_path = os.path.join(results_folder, f'results_l{args.layer_count}w{args.width}.txt')
    with open(duplicate_result_path, 'w') as f:
        f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')

if __name__ == '__main__':
    main()
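
# Example invocation (hypothetical; 'your_dataset_name' above is a placeholder
# and must be replaced with a dataset exposing 'image_path' and 'label' columns):
#
#     python train_mlp.py --layer_count 3 --width 1024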