Prositron committed
Commit 725379f · verified · 1 Parent(s): 8b932df

Update train_model.py

Files changed (1)
  1. train_model.py +90 -71
train_model.py CHANGED
@@ -11,39 +11,30 @@ from tensor_network import FourDimensionalTransformer # Adjust the import path

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-# List of dataset identifiers
+# List of dataset identifiers for reasoning and knowledge
 dataset_ids = [
-    "prithivMLmods/Deepthink-Reasoning",
-    "ewok-core/ewok-core-1.0",
-    "MuskumPillerum/General-Knowledge",
-    "fblgit/tree-of-knowledge",
-    "CohereForAI/aya_dataset",
-    "AtlasUnified/Atlas-Reasoning",
-    "livebench/reasoning",
-    "SkunkworksAI/reasoning-0.01",
-    "KingNish/reasoning-base-20k",
-    "RLHFlow/HH-RLHF-Helpful-standard",
-    "yitingxie/rlhf-reward-datasets"
+    "race/all",  # For reasoning
+    "squad"      # For general knowledge
 ]

+# Update possible keys
+possible_text_keys = ['question', 'sentence', 'query']
+possible_context_keys = ['context', 'article', 'passage']
+possible_label_keys = ['answer', 'answers', 'options']
+
 # Initialize tokenizer
 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

-def tokenize_function(examples):
-    possible_text_keys = ['text', 'content', 'question', 'passage', 'prompt', 'input']
-    possible_label_keys = ['label', 'answer', 'response', 'output', 'target']
-
-    text_key = next((k for k in possible_text_keys if k in examples), None)
-    if text_key is None:
-        text_key = list(examples.keys())[0]
-
-    label_key = next((k for k in possible_label_keys if k in examples), None)
-    if label_key is None:
-        labels = [0] * len(examples[text_key])  # Default label
-    else:
-        labels = examples[label_key]
+def tokenize_function_race(examples):
+    texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
+    labels = examples['answer']
+    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
+    tokenized_inputs['labels'] = labels
+    return tokenized_inputs

-    texts = [str(t) for t in examples[text_key]]
+def tokenize_function_squad(examples):
+    texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
+    labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
     tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
     tokenized_inputs['labels'] = labels
     return tokenized_inputs
@@ -52,47 +43,81 @@ def tokenize_function(examples):
 label_encoder = LabelEncoder()
 all_labels = []

-# Process each dataset individually
+# Process RACE dataset
+race_dataset = load_dataset('race', 'all')
 tokenized_datasets = []
-for dataset_id in dataset_ids:
-    try:
-        dataset = load_dataset(dataset_id)
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
-
-        # Collect labels for label encoding
-        for split in tokenized_dataset.keys():
-            if 'labels' in tokenized_dataset[split].features:
-                all_labels.extend(tokenized_dataset[split]['labels'])
-
-        tokenized_datasets.append(tokenized_dataset)
-    except Exception as e:
-        print(f"Could not process dataset {dataset_id}: {e}")
+for split in race_dataset.keys():
+    tokenized_race = race_dataset[split].map(
+        tokenize_function_race,
+        batched=True,
+        remove_columns=race_dataset[split].column_names,
+        load_from_cache_file=False,
+    )
+    tokenized_datasets.append({split: tokenized_race})
+    # Collect labels
+    all_labels.extend(tokenized_race['labels'])
+
+# Process SQuAD dataset
+squad_dataset = load_dataset('squad')
+for split in squad_dataset.keys():
+    tokenized_squad = squad_dataset[split].map(
+        tokenize_function_squad,
+        batched=True,
+        remove_columns=squad_dataset[split].column_names,
+        load_from_cache_file=False,
+    )
+    tokenized_datasets.append({split: tokenized_squad})
+    # Collect labels
+    all_labels.extend(tokenized_squad['labels'])

 # Fit label encoder
 label_encoder.fit(all_labels)
 num_classes = len(label_encoder.classes_)
 print(f"Number of unique labels: {num_classes}")

+# Limit the number of classes to top 10 frequent labels
 if num_classes > 10:
-    print("Warning: Number of unique labels exceeds the number of classes. Adjusting the dataset or model is required.")
-    exit()
-
-# Transform labels in each dataset
-for dataset in tokenized_datasets:
-    for split in dataset.keys():
-        if 'labels' in dataset[split].features:
-            dataset[split] = dataset[split].map(
-                lambda examples: {'labels': label_encoder.transform(examples['labels'])},
-                batched=True
-            )
+    print("Number of classes exceeds 10. Reducing to top 10 classes.")
+    from collections import Counter
+    label_counter = Counter(all_labels)
+    top_10_labels = [label for label, _ in label_counter.most_common(10)]
+    print(f"Top 10 labels: {top_10_labels}")
+    label_mapping = {label: i for i, label in enumerate(top_10_labels)}
+    label_mapping['other'] = len(top_10_labels)
+    num_classes = len(top_10_labels) + 1
+else:
+    label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}
+
+# Update model with correct num_classes
+model = FourDimensionalTransformer(
+    num_layers=16,
+    embed_dim=7,
+    num_heads=1,
+    num_extra_tokens=16,
+    num_classes=num_classes
+).to(device)
+
+def map_labels(labels):
+    return [label_mapping.get(label, label_mapping['other']) for label in labels]
+
+# Process datasets
+for tokenized_dataset in tokenized_datasets:
+    for split in tokenized_dataset.keys():
+        tokenized_dataset[split] = tokenized_dataset[split].map(
+            lambda examples: {'labels': map_labels(examples['labels'])},
+            batched=True
+        )
+        tokenized_dataset[split] = tokenized_dataset[split].filter(
+            lambda example: example['labels'] < num_classes
+        )
+        tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])

 # Prepare DataLoaders
-def prepare_dataloader(dataset_splits, split_name, batch_size=2):
+def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
     dataloaders = []
-    for dataset in dataset_splits:
-        if split_name in dataset:
-            dataset_split = dataset[split_name]
-            dataset_split.set_format(type='torch', columns=['input_ids', 'labels'])
+    for tokenized_dataset in tokenized_datasets:
+        if split_name in tokenized_dataset:
+            dataset_split = tokenized_dataset[split_name]
             dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
             dataloaders.append(dataloader)
     return dataloaders
@@ -100,20 +125,11 @@ def prepare_dataloader(dataset_splits, split_name, batch_size=2):
 train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
 val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')

-# Initialize the model
-model = FourDimensionalTransformer(
-    num_layers=16,
-    embed_dim=7,
-    num_heads=1,
-    num_extra_tokens=16,
-    num_classes=10  # Using 10 classes as per your model
-).to(device)
-
 # Loss function and optimizer
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(model.parameters(), lr=1e-4)

-def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
+def train(model, train_dataloaders, val_dataloaders, num_epochs=10): #change number of Epochs to your liking
     for epoch in range(num_epochs):
         model.train()
         total_loss = 0
@@ -124,11 +140,11 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
                 labels = batch['labels']

                 # Reshape input_ids and move to device
-                input_ids = input_ids[:, :48]  # Ensure length is 48
+                input_ids = input_ids[:, :48]
                 input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)

                 # Convert labels to torch.long and move to device
-                labels = labels.type(torch.long).to(device)
+                labels = labels.to(device).long()

                 optimizer.zero_grad()
                 outputs = model(input_ids)
@@ -152,9 +168,9 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
                 input_ids = batch['input_ids']
                 labels = batch['labels']

-                input_ids = input_ids[:, :48]  # Ensure length is 48
+                input_ids = input_ids[:, :48]
                 input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
-                labels = labels.type(torch.long).to(device)
+                labels = labels.to(device).long()

                 outputs = model(input_ids)
                 _, predicted = torch.max(outputs, 1)
@@ -166,4 +182,7 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
 torch.save(model.state_dict(), 'trained_model.pth')

 # Start training
-train(model, train_dataloaders, val_dataloaders)
+if train_dataloaders and val_dataloaders:
+    train(model, train_dataloaders, val_dataloaders)
+else:
+    print("No data loaders available for training. Please check the datasets and preprocessing steps.")