Update train_model.py
train_model.py  (+35, -17)
@@ -27,21 +27,40 @@ datasets = [load_dataset(dataset_id) for dataset_id in dataset_ids]
 # Initialize tokenizer
 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Replace with your model's tokenizer
 
-# Tokenize datasets
 def tokenize_function(examples):
-
-
+    """
+    Attempts to find a common text column (e.g., 'text', 'content', 'question', or 'passage').
+    If not found, it falls back to the first available key.
+    Ensures that all inputs are converted to strings.
+    """
+    possible_keys = ['text', 'content', 'question', 'passage']
+    key = None
+    for k in possible_keys:
+        if k in examples:
+            key = k
+            break
+    if key is None:
+        key = list(examples.keys())[0]  # fallback if none of the common keys exist
+    # Convert all items to string in case they're not
+    texts = [str(t) for t in examples[key]]
+    return tokenizer(texts, padding='max_length', truncation=True, max_length=128)
+
+# Apply tokenization to all datasets
 tokenized_datasets = [dataset.map(tokenize_function, batched=True) for dataset in datasets]
 
-
-# Prepare DataLoader
 def prepare_dataloader(dataset, batch_size=32):
+    """
+    Sets the format for the dataset to PyTorch and returns a DataLoader.
+    This function assumes the dataset contains 'input_ids', 'attention_mask', and 'label'.
+    If the label column is missing, you'll need to adjust this accordingly.
+    """
+    # You may need to adjust the columns if your datasets use a different label column name.
     dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
     return DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
-
-
-
+# Only include splits that exist to avoid key errors
+train_dataloaders = [prepare_dataloader(ds['train']) for ds in tokenized_datasets if 'train' in ds]
+val_dataloaders = [prepare_dataloader(ds['validation']) for ds in tokenized_datasets if 'validation' in ds]
 
 # Model setup
 model = FourDimensionalTransformer(
@@ -54,13 +73,13 @@ model = FourDimensionalTransformer(
 
 # Loss function and optimizer
 criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=1e-4)
+optimizer = optim.Adam(model.parameters(), lr=1e-4)
 
-# Training loop
 def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
     for epoch in range(num_epochs):
         model.train()
         total_loss = 0
+        # Iterate over each training dataloader (from each dataset)
         for dataloader in train_dataloaders:
             for batch in dataloader:
                 input_ids = batch['input_ids']
@@ -72,32 +91,31 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
                 loss = criterion(outputs, labels)
                 loss.backward()
                 optimizer.step()
-
                 total_loss += loss.item()
 
+        # Use the last dataloader's length to compute average loss
         avg_loss = total_loss / len(dataloader)
         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
 
-        # Validation
+        # Validation loop
         model.eval()
         total_correct = 0
+        total_samples = 0
         with torch.no_grad():
             for dataloader in val_dataloaders:
                 for batch in dataloader:
                     input_ids = batch['input_ids']
                     attention_mask = batch['attention_mask']
                     labels = batch['label']
-
                     outputs = model(input_ids, attention_mask)
                     _, predicted = torch.max(outputs, 1)
                     total_correct += (predicted == labels).sum().item()
-
-        accuracy = total_correct /
+                    total_samples += labels.size(0)
+        accuracy = total_correct / total_samples if total_samples > 0 else 0
         print(f'Validation Accuracy: {accuracy:.4f}')
 
 # Save the trained model
 torch.save(model.state_dict(), 'trained_model.pth')
 
-
-# Train the model
+# Start training
 train(model, train_dataloaders, val_dataloaders)
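The hunks above begin at line 27 of train_model.py, so the imports and dataset loading that the changed code relies on are not shown. The sketch below reconstructs a plausible preamble from the identifiers used in the diff; the third-party imports follow the standard paths for those names, while the FourDimensionalTransformer import and the dataset_ids values are placeholders assumed for illustration, not taken from this repository.

```python
# Hypothetical preamble for train_model.py (everything above the first hunk).
# Third-party imports match the identifiers used in the diff; the local model
# import and the dataset IDs below are placeholders, not the repo's actual values.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import load_dataset        # Hugging Face datasets
from transformers import AutoTokenizer   # Hugging Face transformers

# from four_dimensional_transformer import FourDimensionalTransformer  # assumed local module

# Placeholder dataset IDs; the real dataset_ids list is defined earlier in the script.
dataset_ids = ['imdb', 'ag_news']

# Context line shown in the first hunk header:
datasets = [load_dataset(dataset_id) for dataset_id in dataset_ids]
```

Because Dataset.map(tokenize_function, batched=True) passes each batch as a dict of column name to list of values, the new tokenize_function can probe several likely text columns before falling back to the first key. A toy call, assuming the tokenizer and tokenize_function from the diff are in scope (the column name and values here are made up):

```python
# Dict-of-lists batch, as Dataset.map(..., batched=True) would provide it.
example_batch = {
    'content': ['first example text', 'second example text'],
    'label': [0, 1],
}
encoded = tokenize_function(example_batch)  # picks the 'content' column
print(sorted(encoded.keys()))               # includes 'attention_mask' and 'input_ids'
```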