Prositron committed
Commit dff6ebd · verified · 1 Parent(s): 7f93af3

Update train_model.py

Files changed (1)
  1. train_model.py +100 -52
train_model.py CHANGED
@@ -4,7 +4,12 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 from datasets import load_dataset
 from transformers import AutoTokenizer
-from tensor_network import FourDimensionalTransformer # Adjust based on your model's location
+from sklearn.preprocessing import LabelEncoder
+
+# Import your model from tensor_network.py
+from tensor_network import FourDimensionalTransformer # Adjust the import path as needed
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 # List of dataset identifiers
 dataset_ids = [
@@ -21,55 +26,88 @@ dataset_ids = [
     "yitingxie/rlhf-reward-datasets"
 ]
 
-# Load datasets
-datasets = [load_dataset(dataset_id) for dataset_id in dataset_ids]
-
 # Initialize tokenizer
-tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Replace with your model's tokenizer
+tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 def tokenize_function(examples):
-    """
-    Attempts to find a common text column (e.g., 'text', 'content', 'question', or 'passage').
-    If not found, it falls back to the first available key.
-    Ensures that all inputs are converted to strings.
-    """
-    possible_keys = ['text', 'content', 'question', 'passage']
-    key = None
-    for k in possible_keys:
-        if k in examples:
-            key = k
-            break
-    if key is None:
-        key = list(examples.keys())[0] # fallback if none of the common keys exist
-    # Convert all items to string in case they're not
-    texts = [str(t) for t in examples[key]]
-    return tokenizer(texts, padding='max_length', truncation=True, max_length=128)
-
-# Apply tokenization to all datasets
-tokenized_datasets = [dataset.map(tokenize_function, batched=True) for dataset in datasets]
-
-def prepare_dataloader(dataset, batch_size=32):
-    """
-    Sets the format for the dataset to PyTorch and returns a DataLoader.
-    This function assumes the dataset contains 'input_ids', 'attention_mask', and 'label'.
-    If the label column is missing, you'll need to adjust this accordingly.
-    """
-    # You may need to adjust the columns if your datasets use a different label column name.
-    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
-    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
-
-# Only include splits that exist to avoid key errors
-train_dataloaders = [prepare_dataloader(ds['train']) for ds in tokenized_datasets if 'train' in ds]
-val_dataloaders = [prepare_dataloader(ds['validation']) for ds in tokenized_datasets if 'validation' in ds]
-
-# Model setup
+    possible_text_keys = ['text', 'content', 'question', 'passage', 'prompt', 'input']
+    possible_label_keys = ['label', 'answer', 'response', 'output', 'target']
+
+    text_key = next((k for k in possible_text_keys if k in examples), None)
+    if text_key is None:
+        text_key = list(examples.keys())[0]
+
+    label_key = next((k for k in possible_label_keys if k in examples), None)
+    if label_key is None:
+        labels = [0] * len(examples[text_key]) # Default label
+    else:
+        labels = examples[label_key]
+
+    texts = [str(t) for t in examples[text_key]]
+    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
+    tokenized_inputs['labels'] = labels
+    return tokenized_inputs
+
+# Initialize LabelEncoder
+label_encoder = LabelEncoder()
+all_labels = []
+
+# Process each dataset individually
+tokenized_datasets = []
+for dataset_id in dataset_ids:
+    try:
+        dataset = load_dataset(dataset_id)
+        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        # Collect labels for label encoding
+        for split in tokenized_dataset.keys():
+            if 'labels' in tokenized_dataset[split].features:
+                all_labels.extend(tokenized_dataset[split]['labels'])
+
+        tokenized_datasets.append(tokenized_dataset)
+    except Exception as e:
+        print(f"Could not process dataset {dataset_id}: {e}")
+
+# Fit label encoder
+label_encoder.fit(all_labels)
+num_classes = len(label_encoder.classes_)
+print(f"Number of unique labels: {num_classes}")
+
+if num_classes > 10:
+    print("Warning: Number of unique labels exceeds the number of classes. Adjusting the dataset or model is required.")
+    exit()
+
+# Transform labels in each dataset
+for dataset in tokenized_datasets:
+    for split in dataset.keys():
+        if 'labels' in dataset[split].features:
+            dataset[split] = dataset[split].map(
+                lambda examples: {'labels': label_encoder.transform(examples['labels'])},
+                batched=True
+            )
+
+# Prepare DataLoaders
+def prepare_dataloader(dataset_splits, split_name, batch_size=2):
+    dataloaders = []
+    for dataset in dataset_splits:
+        if split_name in dataset:
+            dataset_split = dataset[split_name]
+            dataset_split.set_format(type='torch', columns=['input_ids', 'labels'])
+            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
+            dataloaders.append(dataloader)
+    return dataloaders
+
+train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
+val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
+
+# Initialize the model
 model = FourDimensionalTransformer(
     num_layers=16,
     embed_dim=7,
     num_heads=1,
     num_extra_tokens=16,
-    num_classes=10 # Adjust based on your specific task
-)
+    num_classes=10 # Using 10 classes as per your model
+).to(device)
 
 # Loss function and optimizer
 criterion = nn.CrossEntropyLoss()
@@ -79,22 +117,29 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
     for epoch in range(num_epochs):
         model.train()
         total_loss = 0
-        # Iterate over each training dataloader (from each dataset)
+        total_batches = 0
        for dataloader in train_dataloaders:
             for batch in dataloader:
                 input_ids = batch['input_ids']
-                attention_mask = batch['attention_mask']
-                labels = batch['label']
+                labels = batch['labels']
+
+                # Reshape input_ids and move to device
+                input_ids = input_ids[:, :48] # Ensure length is 48
+                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
+
+                # Convert labels to torch.long and move to device
+                labels = labels.type(torch.long).to(device)
 
                 optimizer.zero_grad()
-                outputs = model(input_ids, attention_mask)
+                outputs = model(input_ids)
                 loss = criterion(outputs, labels)
                 loss.backward()
                 optimizer.step()
+
                 total_loss += loss.item()
+                total_batches += 1
 
-        # Use the last dataloader's length to compute average loss
-        avg_loss = total_loss / len(dataloader)
+        avg_loss = total_loss / total_batches if total_batches > 0 else 0
         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
 
         # Validation loop
@@ -105,16 +150,19 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
         for dataloader in val_dataloaders:
             for batch in dataloader:
                 input_ids = batch['input_ids']
-                attention_mask = batch['attention_mask']
-                labels = batch['label']
-                outputs = model(input_ids, attention_mask)
+                labels = batch['labels']
+
+                input_ids = input_ids[:, :48] # Ensure length is 48
+                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
+                labels = labels.type(torch.long).to(device)
+
+                outputs = model(input_ids)
                 _, predicted = torch.max(outputs, 1)
                 total_correct += (predicted == labels).sum().item()
                 total_samples += labels.size(0)
         accuracy = total_correct / total_samples if total_samples > 0 else 0
         print(f'Validation Accuracy: {accuracy:.4f}')
 
-    # Save the trained model
     torch.save(model.state_dict(), 'trained_model.pth')
 
 # Start training
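
Note on the reshape added in train(): the new tokenizer call pads/truncates to max_length=48 because 48 = 3 × 4 × 4, so the first 48 token ids of each example can be viewed as the 4-D input the model consumes. A minimal, illustrative sketch of that shape arithmetic (not part of the commit; the vocabulary size below is only a stand-in for bert-base-uncased):

# Illustrative sketch, independent of the training script.
import torch

batch_size = 2
input_ids = torch.randint(0, 30522, (batch_size, 48))  # 30522 ~ bert-base-uncased vocab size

x = input_ids[:, :48].view(-1, 3, 4, 4).float()  # same reshape as in the train/validation loops
print(x.shape)  # torch.Size([2, 3, 4, 4])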
 
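
Note on the label handling: LabelEncoder is fit once over the labels pooled from every dataset (all_labels), and each split is then remapped with transform, so class ids stay consistent across datasets and line up with the model's 10-way output head (the script exits early if more than 10 unique labels are found). A small, self-contained sketch with hypothetical label values:

# Illustrative sketch; the label strings below are made up.
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
all_labels = ["chosen", "rejected", "chosen", "neutral"]  # pooled from all splits
label_encoder.fit(all_labels)

print(list(label_encoder.classes_))                     # ['chosen', 'neutral', 'rejected']
print(label_encoder.transform(["rejected", "chosen"]))  # [2 0]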