import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder

# Local module in this repository that defines the model used below.
from tensor_network import FourDimensionalTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_ids = [
    "prithivMLmods/Deepthink-Reasoning",
    "ewok-core/ewok-core-1.0",
    "MuskumPillerum/General-Knowledge",
    "fblgit/tree-of-knowledge",
    "CohereForAI/aya_dataset",
    "AtlasUnified/Atlas-Reasoning",
    "livebench/reasoning",
    "SkunkworksAI/reasoning-0.01",
    "KingNish/reasoning-base-20k",
    "RLHFlow/HH-RLHF-Helpful-standard",
    "yitingxie/rlhf-reward-datasets"
]
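
# These sources use different column names for their text and labels, so the
# tokenization below searches a list of candidate keys rather than hard-coding one.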

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    possible_text_keys = ['text', 'content', 'question', 'passage', 'prompt', 'input']
    possible_label_keys = ['label', 'answer', 'response', 'output', 'target']

    # Pick the first recognized text column, falling back to the first column.
    text_key = next((k for k in possible_text_keys if k in examples), None)
    if text_key is None:
        text_key = list(examples.keys())[0]

    # Pick the first recognized label column, falling back to a dummy label.
    label_key = next((k for k in possible_label_keys if k in examples), None)
    if label_key is None:
        labels = ['0'] * len(examples[text_key])
    else:
        # Cast every label to str so LabelEncoder later sees one homogeneous
        # type across all datasets (mixing ints and strings breaks fit()).
        labels = [str(l) for l in examples[label_key]]

    texts = [str(t) for t in examples[text_key]]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
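
# Illustrative only (hypothetical toy batch, not executed): a call like
#   tokenize_function({'question': ['What is 2 + 2?'], 'answer': [4]})
# picks 'question' as the text column and 'answer' as the label column, and
# returns ids padded/truncated to length 48 with the label cast to str ('4').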

label_encoder = LabelEncoder()
all_labels = []
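
# First pass: tokenize every dataset and pool its raw labels so that a single
# LabelEncoder can be fitted over the union of all label sets.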
tokenized_datasets = []
for dataset_id in dataset_ids:
    try:
        dataset = load_dataset(dataset_id)
        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        for split in tokenized_dataset.keys():
            if 'labels' in tokenized_dataset[split].features:
                all_labels.extend(tokenized_dataset[split]['labels'])

        tokenized_datasets.append(tokenized_dataset)
    except Exception as e:
        print(f"Could not process dataset {dataset_id}: {e}")
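
# Some of these hub datasets may require a config name or authentication; the
# try/except above simply skips any source that fails to load or tokenize.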

label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")

if num_classes > 10:
    print(f"Error: {num_classes} unique labels found, but the model is configured "
          "for at most 10 classes. Filter the datasets or enlarge the output layer.")
    sys.exit(1)
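
# Note: sources whose label column is free text (e.g. full answers) can easily
# produce far more than 10 unique labels and will trip the guard above.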

# Second pass: replace the raw labels with the encoder's integer ids.
for dataset in tokenized_datasets:
    for split in dataset.keys():
        if 'labels' in dataset[split].features:
            dataset[split] = dataset[split].map(
                lambda examples: {'labels': label_encoder.transform(examples['labels'])},
                batched=True
            )

def prepare_dataloader(dataset_splits, split_name, batch_size=2):
    dataloaders = []
    for dataset in dataset_splits:
        if split_name in dataset:
            dataset_split = dataset[split_name]
            # Expose only the columns the training loop consumes, as tensors.
            dataset_split.set_format(type='torch', columns=['input_ids', 'labels'])
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders

train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
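
# Not every dataset ships a 'validation' split, so val_dataloaders may cover
# only a subset of the training sources (or be empty).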

model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=10
).to(device)
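
# FourDimensionalTransformer's interface is not shown here; the rest of the
# script assumes it maps a (batch, 3, 4, 4) float tensor to (batch, num_classes)
# logits, which the commented-out check below would verify:
# dummy = torch.zeros(2, 3, 4, 4, device=device)
# assert model(dummy).shape == (2, 10)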

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
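
# CrossEntropyLoss expects raw, unnormalized logits, so the model's forward
# pass is assumed not to end in a softmax.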

def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']

                # Keep the first 48 token ids and repack each sequence as a
                # (3, 4, 4) tensor, the input shape the model expects.
                input_ids = input_ids[:, :48]
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)

                labels = labels.type(torch.long).to(device)

                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_batches += 1

        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Evaluate on the validation splits after every epoch.
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']

                    input_ids = input_ids[:, :48]
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.type(torch.long).to(device)

                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')

    # Persist the trained weights once all epochs have finished.
    torch.save(model.state_dict(), 'trained_model.pth')

train(model, train_dataloaders, val_dataloaders)
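
# Illustrative inference sketch (not executed; assumes the same tokenizer
# settings and the 48-token -> (batch, 3, 4, 4) packing used above):
# model.load_state_dict(torch.load('trained_model.pth', map_location=device))
# model.eval()
# enc = tokenizer(["example input"], padding='max_length', truncation=True,
#                 max_length=48, return_tensors='pt')
# x = enc['input_ids'][:, :48].view(-1, 3, 4, 4).float().to(device)
# with torch.no_grad():
#     pred = model(x).argmax(dim=1)
# print(label_encoder.inverse_transform(pred.cpu().numpy()))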