import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
# Import your model from tensor_network.py
from tensor_network import FourDimensionalTransformer # Adjust the import path as needed
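# Note: FourDimensionalTransformer is assumed to accept inputs of shape
# (batch, 3, 4, 4) -- the 48 token ids reshaped below -- and to return
# logits of shape (batch, num_classes).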
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Dataset identifiers: RACE (reading-comprehension reasoning) and SQuAD (general knowledge).
# Kept for reference; the datasets are loaded by name below.
dataset_ids = [
"race/all", # For reasoning
"squad" # For general knowledge
]
# Candidate field names across QA datasets (for reference only; the tokenize
# functions below use the RACE and SQuAD fields directly)
possible_text_keys = ['question', 'sentence', 'query']
possible_context_keys = ['context', 'article', 'passage']
possible_label_keys = ['answer', 'answers', 'options']
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
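# max_length is set to 48 because the training loop below reshapes each
# sequence of 48 token ids into a (3, 4, 4) tensor for the model.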
def tokenize_function_race(examples):
    # RACE: concatenate each question with its article; labels are the answer letters ('A'-'D')
    texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
    labels = examples['answer']
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
def tokenize_function_squad(examples):
    # SQuAD: concatenate each question with its context; the first answer text (or '') is the label
    texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
    labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
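# Both tokenize functions return fixed-length (48) BERT-style fields
# ('input_ids', 'attention_mask', 'token_type_ids') plus a string 'labels'
# column; only 'input_ids' and 'labels' are used downstream.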
# Initialize LabelEncoder
label_encoder = LabelEncoder()
all_labels = []
# Process RACE dataset
race_dataset = load_dataset('race', 'all')
tokenized_datasets = []
for split in race_dataset.keys():
    tokenized_race = race_dataset[split].map(
        tokenize_function_race,
        batched=True,
        remove_columns=race_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_race})
    # Collect labels
    all_labels.extend(tokenized_race['labels'])
# Process SQuAD dataset
squad_dataset = load_dataset('squad')
for split in squad_dataset.keys():
    tokenized_squad = squad_dataset[split].map(
        tokenize_function_squad,
        batched=True,
        remove_columns=squad_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_squad})
    # Collect labels
    all_labels.extend(tokenized_squad['labels'])
# Fit label encoder
label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")
# Limit the label space to the 10 most frequent labels; everything else maps to 'other'
if num_classes > 10:
    print("Number of classes exceeds 10. Reducing to top 10 classes.")
    from collections import Counter
    label_counter = Counter(all_labels)
    top_10_labels = [label for label, _ in label_counter.most_common(10)]
    print(f"Top 10 labels: {top_10_labels}")
    label_mapping = {label: i for i, label in enumerate(top_10_labels)}
    label_mapping['other'] = len(top_10_labels)
    num_classes = len(top_10_labels) + 1
else:
    label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}
    # Also reserve an 'other' index here so map_labels() below never hits a missing key
    label_mapping['other'] = num_classes
    num_classes += 1
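# After the reduction, label_mapping maps the ten most frequent answer strings
# (in practice the RACE letters 'A'-'D' plus a few very common SQuAD answers,
# though the exact set depends on the data) to indices 0-9, and every remaining
# answer to the 'other' class (index 10).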
# Instantiate the model with the final num_classes
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=num_classes
).to(device)
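# These hyperparameters (16 layers, embed_dim 7, 1 head, 16 extra tokens) are
# assumed to match the constructor defined in tensor_network.py; if the model
# definition changes, adjust them here as well.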
def map_labels(labels):
    # Map each raw label string to its class index; anything outside the mapping goes to 'other'
    return [label_mapping.get(label, label_mapping['other']) for label in labels]
# Process datasets
for tokenized_dataset in tokenized_datasets:
    for split in tokenized_dataset.keys():
        tokenized_dataset[split] = tokenized_dataset[split].map(
            lambda examples: {'labels': map_labels(examples['labels'])},
            batched=True
        )
        tokenized_dataset[split] = tokenized_dataset[split].filter(
            lambda example: example['labels'] < num_classes
        )
        tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])
# Prepare DataLoaders
def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
    dataloaders = []
    for tokenized_dataset in tokenized_datasets:
        if split_name in tokenized_dataset:
            dataset_split = tokenized_dataset[split_name]
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders
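# prepare_dataloader returns one DataLoader per dataset split that matches
# split_name, so the loops below iterate over RACE and SQuAD batches separately.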
train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
def train(model, train_dataloaders, val_dataloaders, num_epochs=10):  # adjust num_epochs as needed
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']
                # Keep the first 48 token ids, reshape to (batch, 3, 4, 4), and move to device
                input_ids = input_ids[:, :48]
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                # Convert labels to torch.long and move to device
                labels = labels.to(device).long()
                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_batches += 1
        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        # Validation loop (run once per epoch)
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']
                    input_ids = input_ids[:, :48]
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.to(device).long()
                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')
    # Save the trained weights once all epochs are done
    torch.save(model.state_dict(), 'trained_model.pth')
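# To reload the trained weights later (a sketch, assuming the same constructor
# arguments as above):
#   model = FourDimensionalTransformer(num_layers=16, embed_dim=7, num_heads=1,
#                                      num_extra_tokens=16, num_classes=num_classes)
#   model.load_state_dict(torch.load('trained_model.pth'))
#   model.eval()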
# Start training
if train_dataloaders and val_dataloaders:
    train(model, train_dataloaders, val_dataloaders)
else:
    print("No data loaders available for training. Please check the datasets and preprocessing steps.")