Update train_model.py

train_model.py (+90 -71) CHANGED
@@ -11,39 +11,30 @@ from tensor_network import FourDimensionalTransformer  # Adjust the import path
 11
 12   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 13
 14 - # List of dataset identifiers
 15   dataset_ids = [
 16 -     "
 17 -     "
 18 -     "MuskumPillerum/General-Knowledge",
 19 -     "fblgit/tree-of-knowledge",
 20 -     "CohereForAI/aya_dataset",
 21 -     "AtlasUnified/Atlas-Reasoning",
 22 -     "livebench/reasoning",
 23 -     "SkunkworksAI/reasoning-0.01",
 24 -     "KingNish/reasoning-base-20k",
 25 -     "RLHFlow/HH-RLHF-Helpful-standard",
 26 -     "yitingxie/rlhf-reward-datasets"
 27   ]
 28
 29   # Initialize tokenizer
 30   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 31
 32 - def tokenize_function(examples):
 33 -
 34 -
 35 -
 36 -
 37 -
 38 -     text_key = list(examples.keys())[0]
 39 -
 40 -     label_key = next((k for k in possible_label_keys if k in examples), None)
 41 -     if label_key is None:
 42 -         labels = [0] * len(examples[text_key])  # Default label
 43 -     else:
 44 -         labels = examples[label_key]
 45
 46 -
 47       tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
 48       tokenized_inputs['labels'] = labels
 49       return tokenized_inputs
@@ -52,47 +43,81 @@ def tokenize_function(examples):
 52   label_encoder = LabelEncoder()
 53   all_labels = []
 54
 55 - # Process
 56   tokenized_datasets = []
 57 - for
 58 -
 59 -
 60 -
 61 -
 62 -
 63 -
 64 -
 65 -
 66 -
 67 -
 68 -
 69 -
 70
 71   # Fit label encoder
 72   label_encoder.fit(all_labels)
 73   num_classes = len(label_encoder.classes_)
 74   print(f"Number of unique labels: {num_classes}")
 75
 76   if num_classes > 10:
 77 -     print("
 78 -
 79 -
 80 -
 81 -
 82 -     for
 83 -
 84 -
 85 -
 86 -
 87 -
 88
 89   # Prepare DataLoaders
 90 - def prepare_dataloader(dataset_splits, split_name, batch_size=2):
 91       dataloaders = []
 92 -     for
 93 -         if split_name in
 94 -             dataset_split =
 95 -             dataset_split.set_format(type='torch', columns=['input_ids', 'labels'])
 96               dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
 97               dataloaders.append(dataloader)
 98       return dataloaders
@@ -100,20 +125,11 @@ def prepare_dataloader(dataset_splits, split_name, batch_size=2):
100   train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
101   val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
102
103 - # Initialize the model
104 - model = FourDimensionalTransformer(
105 -     num_layers=16,
106 -     embed_dim=7,
107 -     num_heads=1,
108 -     num_extra_tokens=16,
109 -     num_classes=10  # Using 10 classes as per your model
110 - ).to(device)
111 -
112   # Loss function and optimizer
113   criterion = nn.CrossEntropyLoss()
114   optimizer = optim.Adam(model.parameters(), lr=1e-4)
115
116 - def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
117       for epoch in range(num_epochs):
118           model.train()
119           total_loss = 0
@@ -124,11 +140,11 @@ def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
124               labels = batch['labels']
125
126               # Reshape input_ids and move to device
127 -             input_ids = input_ids[:, :48]
128               input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
129
130               # Convert labels to torch.long and move to device
131 -             labels = labels.
132
133               optimizer.zero_grad()
134               outputs = model(input_ids)
@@ -152,9 +168,9 @@
152               input_ids = batch['input_ids']
153               labels = batch['labels']
154
155 -             input_ids = input_ids[:, :48]
156               input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
157 -             labels = labels.
158
159               outputs = model(input_ids)
160               _, predicted = torch.max(outputs, 1)
@@ -166,4 +182,7 @@
166   torch.save(model.state_dict(), 'trained_model.pth')
167
168   # Start training
169 -

 11
 12   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 13
 14 + # List of dataset identifiers for reasoning and knowledge
 15   dataset_ids = [
 16 +     "race/all",  # For reasoning
 17 +     "squad"      # For general knowledge
 18   ]
 19
 20 + # Update possible keys
 21 + possible_text_keys = ['question', 'sentence', 'query']
 22 + possible_context_keys = ['context', 'article', 'passage']
 23 + possible_label_keys = ['answer', 'answers', 'options']
 24 +
 25   # Initialize tokenizer
 26   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 27
 28 + def tokenize_function_race(examples):
 29 +     texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
 30 +     labels = examples['answer']
 31 +     tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
 32 +     tokenized_inputs['labels'] = labels
 33 +     return tokenized_inputs
 34
 35 + def tokenize_function_squad(examples):
 36 +     texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
 37 +     labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
 38       tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
 39       tokenized_inputs['labels'] = labels
 40       return tokenized_inputs

 43   label_encoder = LabelEncoder()
 44   all_labels = []
 45
 46 + # Process RACE dataset
 47 + race_dataset = load_dataset('race', 'all')
 48   tokenized_datasets = []
 49 + for split in race_dataset.keys():
 50 +     tokenized_race = race_dataset[split].map(
 51 +         tokenize_function_race,
 52 +         batched=True,
 53 +         remove_columns=race_dataset[split].column_names,
 54 +         load_from_cache_file=False,
 55 +     )
 56 +     tokenized_datasets.append({split: tokenized_race})
 57 +     # Collect labels
 58 +     all_labels.extend(tokenized_race['labels'])
 59 +
 60 + # Process SQuAD dataset
 61 + squad_dataset = load_dataset('squad')
 62 + for split in squad_dataset.keys():
 63 +     tokenized_squad = squad_dataset[split].map(
 64 +         tokenize_function_squad,
 65 +         batched=True,
 66 +         remove_columns=squad_dataset[split].column_names,
 67 +         load_from_cache_file=False,
 68 +     )
 69 +     tokenized_datasets.append({split: tokenized_squad})
 70 +     # Collect labels
 71 +     all_labels.extend(tokenized_squad['labels'])
 72
 73   # Fit label encoder
 74   label_encoder.fit(all_labels)
 75   num_classes = len(label_encoder.classes_)
 76   print(f"Number of unique labels: {num_classes}")
 77
 78 + # Limit the number of classes to top 10 frequent labels
 79   if num_classes > 10:
 80 +     print("Number of classes exceeds 10. Reducing to top 10 classes.")
 81 +     from collections import Counter
 82 +     label_counter = Counter(all_labels)
 83 +     top_10_labels = [label for label, _ in label_counter.most_common(10)]
 84 +     print(f"Top 10 labels: {top_10_labels}")
 85 +     label_mapping = {label: i for i, label in enumerate(top_10_labels)}
 86 +     label_mapping['other'] = len(top_10_labels)
 87 +     num_classes = len(top_10_labels) + 1
 88 + else:
 89 +     label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}
 90 +
 91 + # Update model with correct num_classes
 92 + model = FourDimensionalTransformer(
 93 +     num_layers=16,
 94 +     embed_dim=7,
 95 +     num_heads=1,
 96 +     num_extra_tokens=16,
 97 +     num_classes=num_classes
 98 + ).to(device)
 99 +
100 + def map_labels(labels):
101 +     return [label_mapping.get(label, label_mapping['other']) for label in labels]
102 +
103 + # Process datasets
104 + for tokenized_dataset in tokenized_datasets:
105 +     for split in tokenized_dataset.keys():
106 +         tokenized_dataset[split] = tokenized_dataset[split].map(
107 +             lambda examples: {'labels': map_labels(examples['labels'])},
108 +             batched=True
109 +         )
110 +         tokenized_dataset[split] = tokenized_dataset[split].filter(
111 +             lambda example: example['labels'] < num_classes
112 +         )
113 +         tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])
114
115   # Prepare DataLoaders
116 + def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
117       dataloaders = []
118 +     for tokenized_dataset in tokenized_datasets:
119 +         if split_name in tokenized_dataset:
120 +             dataset_split = tokenized_dataset[split_name]
121               dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
122               dataloaders.append(dataloader)
123       return dataloaders

125   train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
126   val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
127
128   # Loss function and optimizer
129   criterion = nn.CrossEntropyLoss()
130   optimizer = optim.Adam(model.parameters(), lr=1e-4)
131
132 + def train(model, train_dataloaders, val_dataloaders, num_epochs=10):  # change number of epochs to your liking
133       for epoch in range(num_epochs):
134           model.train()
135           total_loss = 0

140               labels = batch['labels']
141
142               # Reshape input_ids and move to device
143 +             input_ids = input_ids[:, :48]
144               input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
145
146               # Convert labels to torch.long and move to device
147 +             labels = labels.to(device).long()
148
149               optimizer.zero_grad()
150               outputs = model(input_ids)

168               input_ids = batch['input_ids']
169               labels = batch['labels']
170
171 +             input_ids = input_ids[:, :48]
172               input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
173 +             labels = labels.to(device).long()
174
175               outputs = model(input_ids)
176               _, predicted = torch.max(outputs, 1)

182   torch.save(model.state_dict(), 'trained_model.pth')
183
184   # Start training
185 + if train_dataloaders and val_dataloaders:
186 +     train(model, train_dataloaders, val_dataloaders)
187 + else:
188 +     print("No data loaders available for training. Please check the datasets and preprocessing steps.")
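For context, a minimal inference sketch (not part of this commit) showing how the saved checkpoint could be used. It assumes the tensor_network module is importable, that training ended up with 11 classes (top 10 labels plus 'other'), and that the label_mapping built during training is still available for interpreting the predicted index; the question/context strings are illustrative.

import torch
from transformers import AutoTokenizer
from tensor_network import FourDimensionalTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Rebuild the model with the same hyperparameters used in training
# (num_classes=11 is an assumption: top 10 labels + 'other').
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=11,
).to(device)
model.load_state_dict(torch.load('trained_model.pth', map_location=device))
model.eval()

# Tokenize one question/context pair exactly as in training:
# max_length=48 lets the ids reshape cleanly into (batch, 3, 4, 4), since 3 * 4 * 4 = 48.
question = "What is the capital of France?"   # illustrative input
context = "Paris is the capital of France."   # illustrative input
enc = tokenizer(question + " " + context, padding='max_length',
                truncation=True, max_length=48, return_tensors='pt')
input_ids = enc['input_ids'][:, :48].view(-1, 3, 4, 4).float().to(device)

with torch.no_grad():
    logits = model(input_ids)
predicted_class = logits.argmax(dim=1).item()
print(predicted_class)  # index into the label_mapping built during training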