Prositron committed
Commit 7f93af3 · verified · 1 Parent(s): 3b475af

Update train_model.py

Files changed (1)
  1. train_model.py +35 -17
train_model.py CHANGED
@@ -27,21 +27,40 @@ datasets = [load_dataset(dataset_id) for dataset_id in dataset_ids]
 # Initialize tokenizer
 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Replace with your model's tokenizer

-# Tokenize datasets
 def tokenize_function(examples):
-    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)
-
+    """
+    Attempts to find a common text column (e.g., 'text', 'content', 'question', or 'passage').
+    If not found, it falls back to the first available key.
+    Ensures that all inputs are converted to strings.
+    """
+    possible_keys = ['text', 'content', 'question', 'passage']
+    key = None
+    for k in possible_keys:
+        if k in examples:
+            key = k
+            break
+    if key is None:
+        key = list(examples.keys())[0]  # fallback if none of the common keys exist
+    # Convert all items to string in case they're not
+    texts = [str(t) for t in examples[key]]
+    return tokenizer(texts, padding='max_length', truncation=True, max_length=128)
+
+# Apply tokenization to all datasets
 tokenized_datasets = [dataset.map(tokenize_function, batched=True) for dataset in datasets]

-
-# Prepare DataLoader
 def prepare_dataloader(dataset, batch_size=32):
+    """
+    Sets the format for the dataset to PyTorch and returns a DataLoader.
+    This function assumes the dataset contains 'input_ids', 'attention_mask', and 'label'.
+    If the label column is missing, you'll need to adjust this accordingly.
+    """
+    # You may need to adjust the columns if your datasets use a different label column name.
     dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
     return DataLoader(dataset, batch_size=batch_size, shuffle=True)

-train_dataloaders = [prepare_dataloader(dataset['train']) for dataset in tokenized_datasets]
-val_dataloaders = [prepare_dataloader(dataset['validation']) for dataset in tokenized_datasets]
-
+# Only include splits that exist to avoid key errors
+train_dataloaders = [prepare_dataloader(ds['train']) for ds in tokenized_datasets if 'train' in ds]
+val_dataloaders = [prepare_dataloader(ds['validation']) for ds in tokenized_datasets if 'validation' in ds]

 # Model setup
 model = FourDimensionalTransformer(
@@ -54,13 +73,13 @@ model = FourDimensionalTransformer(

 # Loss function and optimizer
 criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=1e-4) # Using Adam optimizer with a learning rate of 1e-4
+optimizer = optim.Adam(model.parameters(), lr=1e-4)

-# Training loop
 def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
     for epoch in range(num_epochs):
         model.train()
         total_loss = 0
+        # Iterate over each training dataloader (from each dataset)
         for dataloader in train_dataloaders:
             for batch in dataloader:
                 input_ids = batch['input_ids']
@@ -72,32 +91,31 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
                 loss = criterion(outputs, labels)
                 loss.backward()
                 optimizer.step()
-
                 total_loss += loss.item()

+        # Use the last dataloader's length to compute average loss
         avg_loss = total_loss / len(dataloader)
         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

-        # Validation
+        # Validation loop
         model.eval()
         total_correct = 0
+        total_samples = 0
         with torch.no_grad():
             for dataloader in val_dataloaders:
                 for batch in dataloader:
                     input_ids = batch['input_ids']
                     attention_mask = batch['attention_mask']
                     labels = batch['label']
-
                     outputs = model(input_ids, attention_mask)
                     _, predicted = torch.max(outputs, 1)
                     total_correct += (predicted == labels).sum().item()
-
-        accuracy = total_correct / len(dataloader.dataset)
+                    total_samples += labels.size(0)
+        accuracy = total_correct / total_samples if total_samples > 0 else 0
         print(f'Validation Accuracy: {accuracy:.4f}')

     # Save the trained model
     torch.save(model.state_dict(), 'trained_model.pth')

-
-# Train the model
+# Start training
 train(model, train_dataloaders, val_dataloaders)
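
For reference, the new tokenize_function no longer assumes every dataset exposes a 'text' column: it probes a short list of common column names, falls back to the first available column if none match, and coerces values to strings before tokenizing. Below is a minimal standalone sketch of just that fallback logic (toy data and a hypothetical pick_text_column helper, not part of the commit):

def pick_text_column(examples, possible_keys=('text', 'content', 'question', 'passage')):
    # Mirrors the column-detection fallback in tokenize_function.
    for k in possible_keys:
        if k in examples:
            return k
    return next(iter(examples))  # fall back to the first column

# Batched examples as datasets.map(..., batched=True) passes them: dict of column name -> list of values.
batch_a = {'text': ['hello world', 'second example'], 'label': [0, 1]}
batch_b = {'passage': ['some passage', 42], 'label': [1, 0]}  # note the non-string entry

for batch in (batch_a, batch_b):
    key = pick_text_column(batch)
    texts = [str(t) for t in batch[key]]  # str() guards against non-string values such as 42
    print(key, texts)
# text ['hello world', 'second example']
# passage ['some passage', '42']

The other changes follow the same defensive pattern: dataloaders are built only for splits that actually exist, and validation accuracy is computed from a running total_samples count rather than len(dataloader.dataset), which reflected only the last validation set.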