Update train.py
Browse files
train.py
CHANGED
@@ -37,7 +37,7 @@ def train(args):
|
|
37 |
test_data = Dataset(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
|
38 |
train_data = Dataset(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset, mode='train')
|
39 |
train_data = torch.utils.data.ConcatDataset([train_data, test_data])
|
40 |
-
print('Dataset length:', train_data
|
41 |
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
|
42 |
|
43 |
##########################################################################################
|
|
|
37 |
test_data = Dataset(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
|
38 |
train_data = Dataset(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset, mode='train')
|
39 |
train_data = torch.utils.data.ConcatDataset([train_data, test_data])
|
40 |
+
print('Dataset length:', len(train_data))
|
41 |
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
|
42 |
|
43 |
##########################################################################################
|