Update train_mlp.py
Browse files- train_mlp.py +1 -1
train_mlp.py
CHANGED
@@ -105,7 +105,7 @@ def main():
|
|
105 |
|
106 |
# Split the dataset into train and validation sets
|
107 |
train_dataset = dataset['train']
|
108 |
-
val_dataset = dataset['
|
109 |
|
110 |
# Determine the number of classes
|
111 |
num_classes = len(set(train_dataset['label']))
|
|
|
105 |
|
106 |
# Split the dataset into train and validation sets
|
107 |
train_dataset = dataset['train']
|
108 |
+
val_dataset = dataset['valid'] # Assuming 'validation' is the correct key
|
109 |
|
110 |
# Determine the number of classes
|
111 |
num_classes = len(set(train_dataset['label']))
|