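"""Train a multilayer perceptron (MLP) image classifier on preprocessed
Hugging Face datasets saved to disk.

Expects `preprocessed_train_dataset` and `preprocessed_val_dataset`
directories whose 'image' column holds fixed-size image tensors and
whose 'label' column holds integer class ids.

Example invocation (assuming this file is saved as train_mlp.py):

    python train_mlp.py --layer_count 2 --width 512
"""
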
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_from_disk

# Define the MLP model
class MLP(nn.Module):
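    """Fully connected network: Linear layers with ReLU between hidden
    layers and a raw-logit output layer (no final activation, since
    nn.CrossEntropyLoss applies log-softmax internally).

    Example: MLP(3072, [512, 512], 10) maps flattened 32x32 RGB images
    to 10 class logits.
    """
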
    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        layers = []
        sizes = [input_size] + hidden_sizes + [output_size]
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if i < len(sizes) - 2:
                layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Custom collate function
def custom_collate(batch):
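    # Assumes preprocessing stored each 'image' as a torch.Tensor of a
    # fixed shape; torch.stack would fail on raw PIL images.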
    images = torch.stack([item['image'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    return {'image': images, 'label': labels}

# Train the model
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001, save_loss_path=None):
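    """Train for `epochs` epochs, validating after each one.

    Returns the final epoch's average validation loss; optionally writes
    per-epoch train/validation losses to save_loss_path.
    """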
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            inputs = batch['image'].view(batch['image'].size(0), -1)
            labels = batch['label']

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}')

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
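        # Gradients are not needed for evaluation, so disable autograd
        # to save memory and compute.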
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['image'].view(batch['image'].size(0), -1)
                labels = batch['label']

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {100 * correct / total:.2f}%')

    if save_loss_path:
        with open(save_loss_path, 'w') as f:
            for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
                f.write(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}\n')

    return avg_val_loss

# Main function
def main():
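    """Parse CLI arguments, train the MLP, and save the model and results."""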
    parser = argparse.ArgumentParser(description='Train an MLP on a Hugging Face dataset with JPEG images and class labels.')
    parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
    parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
    args = parser.parse_args()

    # Load the preprocessed datasets
    train_dataset = load_from_disk('preprocessed_train_dataset')
    val_dataset = load_from_disk('preprocessed_val_dataset')

    # Determine the number of classes
    num_classes = len(set(train_dataset['label']))

    # Determine the fixed resolution of the images
    image_size = train_dataset[0]['image'].size(1)  # Assumes (C, H, W) tensors and square images

    # Define the model
    input_size = image_size * image_size * 3  # flattened H * W * 3 RGB channels
    hidden_sizes = [args.width] * args.layer_count
    output_size = num_classes

    model = MLP(input_size, hidden_sizes, output_size)

    # Create data loaders with custom collate function
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate)

    # Train the model and get the final loss
    save_loss_path = 'losses.txt'
    final_loss = train_model(model, train_loader, val_loader, save_loss_path=save_loss_path)

    # Count all model parameters (weights and biases of every layer)
    param_count = sum(p.numel() for p in model.parameters())

    # Create the folder for the model
    model_folder = f'mlp_model_l{args.layer_count}w{args.width}'
    os.makedirs(model_folder, exist_ok=True)

    # Save the model
    model_path = os.path.join(model_folder, 'model.pth')
    torch.save(model.state_dict(), model_path)

    # Write the results to a text file in the model folder
    result_path = os.path.join(model_folder, 'results.txt')
    with open(result_path, 'w') as f:
        f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')

    # Save a duplicate of the results in the 'results' folder
    results_folder = 'results'
    os.makedirs(results_folder, exist_ok=True)
    duplicate_result_path = os.path.join(results_folder, f'results_l{args.layer_count}w{args.width}.txt')
    with open(duplicate_result_path, 'w') as f:
        f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}, Final Loss: {final_loss}\n')

if __name__ == '__main__':
    main()