import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

from tensor_network import FourDimensionalTransformer
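
# NOTE (assumption): these corpora do not share a schema. Several expose
# columns such as 'prompt'/'response' or 'instruction' rather than 'text',
# and not all of them ship a 'validation' split, so they are normalized
# below before tokenization.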
dataset_ids = [
    "prithivMLmods/Deepthink-Reasoning",
    "ewok-core/ewok-core-1.0",
    "MuskumPillerum/General-Knowledge",
    "fblgit/tree-of-knowledge",
    "CohereForAI/aya_dataset",
    "AtlasUnified/Atlas-Reasoning",
    "livebench/reasoning",
    "SkunkworksAI/reasoning-0.01",
    "KingNish/reasoning-base-20k",
    "RLHFlow/HH-RLHF-Helpful-standard",
    "yitingxie/rlhf-reward-datasets"
]

datasets = [load_dataset(dataset_id) for dataset_id in dataset_ids]
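
# Minimal normalization sketch (an assumption, not part of the original
# pipeline): coerce each dataset to a single 'text' column by joining its
# string-valued fields so that one tokenize_function applies to all of them.
# A production pipeline would map each dataset's columns explicitly.
def normalize_to_text(example):
    parts = [value for value in example.values() if isinstance(value, str)]
    return {'text': ' '.join(parts)}

datasets = [dataset.map(normalize_to_text) for dataset in datasets]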

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)

tokenized_datasets = [dataset.map(tokenize_function, batched=True) for dataset in datasets]

def prepare_dataloader(dataset, batch_size=32):
    # Expose only the tensors the training loop needs. Note: a 'label' column
    # is assumed here; most of the datasets above do not ship one natively,
    # so it would have to be attached during preprocessing.
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

train_dataloaders = [prepare_dataloader(dataset['train']) for dataset in tokenized_datasets]
# Not every dataset ships a 'validation' split; only build loaders where one exists.
val_dataloaders = [
    prepare_dataloader(dataset['validation'])
    for dataset in tokenized_datasets
    if 'validation' in dataset
]

model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=10
)
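
# Hedged addition: run on GPU when one is available; the training loop below
# moves each batch to the same device. nn.Module.to() updates parameter
# storage in place, so the optimizer created afterwards sees the moved tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)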

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

        # Average over every batch seen this epoch, not just the last
        # dataloader's batches.
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        model.eval()
        total_correct = 0
        total_examples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['label'].to(device)

                    outputs = model(input_ids, attention_mask)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_examples += labels.size(0)

        # Accuracy over all validation examples combined, not just the last
        # dataloader's dataset.
        accuracy = total_correct / max(total_examples, 1)
        print(f'Validation Accuracy: {accuracy:.4f}')

    # Save the trained weights once all epochs complete.
    torch.save(model.state_dict(), 'trained_model.pth')

train(model, train_dataloaders, val_dataloaders)
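
# Usage sketch (assumes the same constructor arguments reproduce the
# architecture): reload the saved weights for inference.
#   model.load_state_dict(torch.load('trained_model.pth'))
#   model.eval()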