# Guest / train_model.py
# Last change: "Update train_model.py" by Prositron (commit dff6ebd, verified), 6.05 kB.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
# Import your model from tensor_network.py
from tensor_network import FourDimensionalTransformer # Adjust the import path as needed
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Dataset identifiers handed to load_dataset, processed in this order.
dataset_ids = [
    "prithivMLmods/Deepthink-Reasoning",
    "ewok-core/ewok-core-1.0",
    "MuskumPillerum/General-Knowledge",
    "fblgit/tree-of-knowledge",
    "CohereForAI/aya_dataset",
    "AtlasUnified/Atlas-Reasoning",
    "livebench/reasoning",
    "SkunkworksAI/reasoning-0.01",
    "KingNish/reasoning-base-20k",
    "RLHFlow/HH-RLHF-Helpful-standard",
    "yitingxie/rlhf-reward-datasets",
]

# Shared tokenizer (loaded once via from_pretrained) used to encode every dataset's text.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def tokenize_function(examples, max_length=48, tok=None):
    """Tokenize a batch of examples and attach a ``labels`` entry.

    Written for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to equal-length lists of values.

    Column selection:
      * label column: the first of ``possible_label_keys`` present; when none
        is present every example gets the label ``0``.
      * text column: the first of ``possible_text_keys`` present; otherwise
        the first column that is not the label column, so labels are never
        fed back in as model input. The label column is used as text only if
        it is the sole column.

    Args:
        examples: Mapping of column name -> list of values for one batch.
        max_length: Pad/truncate length in tokens. The default of 48 matches
            the 3 x 4 x 4 input reshape used during training.
        tok: Tokenizer callable; defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer output with an added ``'labels'`` entry holding the raw
        (not yet integer-encoded) label values.
    """
    possible_text_keys = ['text', 'content', 'question', 'passage', 'prompt', 'input']
    possible_label_keys = ['label', 'answer', 'response', 'output', 'target']
    # Resolve the default lazily so callers may inject another tokenizer.
    tok = tokenizer if tok is None else tok

    columns = list(examples.keys())
    label_key = next((k for k in possible_label_keys if k in examples), None)
    text_key = next((k for k in possible_text_keys if k in examples), None)
    if text_key is None:
        # Skip the label column when guessing the text column to avoid label leakage.
        text_key = next((k for k in columns if k != label_key), columns[0])

    if label_key is None:
        labels = [0] * len(examples[text_key])  # Default label
    else:
        labels = examples[label_key]

    texts = [str(t) for t in examples[text_key]]
    tokenized_inputs = tok(texts, padding='max_length', truncation=True, max_length=max_length)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
# Initialize LabelEncoder
label_encoder = LabelEncoder()
all_labels = []
# Process each dataset individually
tokenized_datasets = []
for dataset_id in dataset_ids:
    try:
        dataset = load_dataset(dataset_id)
        tokenized_dataset = dataset.map(tokenize_function, batched=True)
        # Collect this dataset's labels as strings. Datasets without a label
        # column get the int 0 from tokenize_function while others carry text,
        # and LabelEncoder rejects mixed str/int input. Gathering into a local
        # list first means a dataset that fails halfway contributes no labels.
        dataset_labels = [
            str(label)
            for split in tokenized_dataset.keys()
            if 'labels' in tokenized_dataset[split].features
            for label in tokenized_dataset[split]['labels']
        ]
    except Exception as e:
        # Best-effort: a dataset that cannot be downloaded or tokenized is skipped.
        print(f"Could not process dataset {dataset_id}: {e}")
        continue
    all_labels.extend(dataset_labels)
    tokenized_datasets.append(tokenized_dataset)

if not tokenized_datasets:
    raise SystemExit("No datasets could be loaded; nothing to train on.")

# Fit label encoder
label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")
# 10 is the output size the model is built with (num_classes=10); encoded
# labels must stay below it for CrossEntropyLoss.
if num_classes > 10:
    print("Warning: Number of unique labels exceeds the number of classes. Adjusting the dataset or model is required.")
    # Non-zero exit status so callers/CI see this as a failure.
    raise SystemExit(1)


def _encode_labels(examples):
    """Replace raw labels in a batch with their integer class ids."""
    # str() must mirror the conversion applied when fitting the encoder.
    return {'labels': label_encoder.transform([str(label) for label in examples['labels']])}


# Transform labels in each dataset
for dataset in tokenized_datasets:
    # Snapshot split names because entries are reassigned inside the loop.
    for split in list(dataset.keys()):
        if 'labels' in dataset[split].features:
            dataset[split] = dataset[split].map(_encode_labels, batched=True)
# Prepare DataLoaders
def prepare_dataloader(dataset_splits, split_name, batch_size=2, shuffle=True):
    """Build one DataLoader per dataset collection that contains ``split_name``.

    Args:
        dataset_splits: Iterable of mapping-like collections of splits (for
            example ``DatasetDict`` objects). Entries lacking ``split_name``
            are skipped.
        split_name: Split to select, such as ``'train'`` or ``'validation'``.
        batch_size: Samples per batch.
        shuffle: Whether each loader reshuffles every epoch. Defaults to True
            for backward compatibility; pass False for evaluation to get a
            deterministic order.

    Returns:
        list of DataLoader, in the same order as ``dataset_splits``.

    Note:
        ``set_format`` is called on every selected split, changing that split
        object's output format to torch tensors of ``input_ids`` and
        ``labels`` only.
    """
    loaders = []
    for splits in dataset_splits:
        if split_name not in splits:
            continue
        split = splits[split_name]
        # Expose only the model inputs and targets, as torch tensors.
        split.set_format(type='torch', columns=['input_ids', 'labels'])
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=shuffle))
    return loaders
# One list of DataLoaders per split, built in train-then-validation order.
train_dataloaders, val_dataloaders = (
    prepare_dataloader(tokenized_datasets, split_name)
    for split_name in ('train', 'validation')
)

# Build the classifier, then move it to the selected device.
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=10,  # output size; encoded labels must be < 10
)
model = model.to(device)

# Cross-entropy over the class scores, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
def _batch_to_model_inputs(batch, target_device):
    """Turn a DataLoader batch into ``(inputs, labels)`` tensors on ``target_device``.

    Keeps the first 48 token ids of each row and reshapes them into a
    ``(batch, 3, 4, 4)`` float tensor (3 * 4 * 4 == 48). The batch dimension
    is taken explicitly from ``input_ids`` instead of inferred with ``-1``, so
    rows shorter than 48 tokens raise a shape error instead of being silently
    regrouped into a different number of samples.
    """
    input_ids = batch['input_ids']
    batch_size = input_ids.size(0)
    inputs = input_ids[:, :48].reshape(batch_size, 3, 4, 4).float().to(target_device)
    labels = batch['labels'].long().to(target_device)
    return inputs, labels


def train(model, train_dataloaders, val_dataloaders, num_epochs=10,
          loss_fn=None, opt=None, run_device=None,
          checkpoint_path='trained_model.pth'):
    """Train ``model`` and report validation accuracy after every epoch.

    Each epoch makes one pass over every loader in ``train_dataloaders`` (in
    order) and prints the mean training loss per batch. It then computes
    accuracy over all ``val_dataloaders`` and writes ``model.state_dict()`` to
    ``checkpoint_path``, overwriting the previous epoch's file so an
    interrupted run keeps the latest weights.

    Args:
        model: Module expected to map ``(batch, 3, 4, 4)`` float inputs to
            class scores of shape ``(batch, num_classes)``.
        train_dataloaders: Iterables of batches with ``'input_ids'`` and
            ``'labels'`` tensors.
        val_dataloaders: Same format as ``train_dataloaders``; may be empty,
            in which case accuracy is reported as 0.
        num_epochs: Number of passes over the training loaders.
        loss_fn: Loss callable; defaults to the module-level ``criterion``.
        opt: Optimizer over ``model``'s parameters; defaults to the
            module-level ``optimizer``.
        run_device: Device for the batch tensors; defaults to the
            module-level ``device``.
        checkpoint_path: File the state dict is saved to after each epoch.
    """
    # Fallbacks are resolved at call time so the original call signature
    # (which relied on these module-level objects) keeps working.
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    run_device = device if run_device is None else run_device

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                inputs, labels = _batch_to_model_inputs(batch, run_device)
                opt.zero_grad()
                loss = loss_fn(model(inputs), labels)
                loss.backward()
                opt.step()
                total_loss += loss.item()
                total_batches += 1
        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Validation loop
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    inputs, labels = _batch_to_model_inputs(batch, run_device)
                    predicted = model(inputs).argmax(dim=1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')
        torch.save(model.state_dict(), checkpoint_path)
# Entry point: train the model on the loaders prepared above.
train(model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)