import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
# Import your model from tensor_network.py
from tensor_network import FourDimensionalTransformer # Adjust the import path as needed
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# List of dataset identifiers for reasoning and knowledge
dataset_ids = [
    "race/all",  # For reasoning
    "squad",     # For general knowledge
]
# Possible field names across datasets (kept for reference; the tokenize functions below use dataset-specific keys)
possible_text_keys = ['question', 'sentence', 'query']
possible_context_keys = ['context', 'article', 'passage']
possible_label_keys = ['answer', 'answers', 'options']
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
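# Note: the BERT tokenizer is used here only to turn text into integer token ids; those ids
# are later fed directly to the custom model below (no pretrained BERT weights are involved).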
def tokenize_function_race(examples):
    # RACE: concatenate each question with its article; keep the answer letter as the label
    texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
    labels = examples['answer']
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
def tokenize_function_squad(examples):
    # SQuAD: concatenate each question with its context; use the first answer text as the label
    texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
    labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs
# Initialize LabelEncoder
label_encoder = LabelEncoder()
all_labels = []
# Process RACE dataset
race_dataset = load_dataset('race', 'all')
tokenized_datasets = []
for split in race_dataset.keys():
    tokenized_race = race_dataset[split].map(
        tokenize_function_race,
        batched=True,
        remove_columns=race_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_race})
    # Collect labels
    all_labels.extend(tokenized_race['labels'])
# Process SQuAD dataset
squad_dataset = load_dataset('squad')
for split in squad_dataset.keys():
    tokenized_squad = squad_dataset[split].map(
        tokenize_function_squad,
        batched=True,
        remove_columns=squad_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_squad})
    # Collect labels
    all_labels.extend(tokenized_squad['labels'])
# Fit label encoder
label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")
# Limit the label space to the 10 most frequent labels (plus an 'other' bucket)
if num_classes > 10:
    print("Number of classes exceeds 10. Reducing to top 10 classes.")
    from collections import Counter
    label_counter = Counter(all_labels)
    top_10_labels = [label for label, _ in label_counter.most_common(10)]
    print(f"Top 10 labels: {top_10_labels}")
    label_mapping = {label: i for i, label in enumerate(top_10_labels)}
    label_mapping['other'] = len(top_10_labels)
    num_classes = len(top_10_labels) + 1
else:
    label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}
# Update model with correct num_classes
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=num_classes
).to(device)
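# NOTE (assumption): the 48-token sequences produced above are reshaped to (batch, 3, 4, 4)
# in the training loop below, which is taken to match the input shape FourDimensionalTransformer
# expects as defined in tensor_network.py.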
def map_labels(labels):
    # Map raw labels to class ids; labels outside the kept classes fall back to the 'other' bucket when it exists
    fallback = label_mapping.get('other', 0)
    return [label_mapping.get(label, fallback) for label in labels]
# Process datasets
for tokenized_dataset in tokenized_datasets:
    for split in tokenized_dataset.keys():
        tokenized_dataset[split] = tokenized_dataset[split].map(
            lambda examples: {'labels': map_labels(examples['labels'])},
            batched=True
        )
        tokenized_dataset[split] = tokenized_dataset[split].filter(
            lambda example: example['labels'] < num_classes
        )
        tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])
# Prepare DataLoaders
def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
    dataloaders = []
    for tokenized_dataset in tokenized_datasets:
        if split_name in tokenized_dataset:
            dataset_split = tokenized_dataset[split_name]
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders
train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
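# nn.CrossEntropyLoss expects raw (unnormalized) logits of shape (batch, num_classes) from the
# model and integer class ids as targets, which matches the label mapping built above.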
def train(model, train_dataloaders, val_dataloaders, num_epochs=10):  # change the number of epochs to your liking
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']
                # Keep the first 48 token ids and reshape them into the (3, 4, 4) grid the model expects
                input_ids = input_ids[:, :48]
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                # Convert labels to torch.long and move to device
                labels = labels.to(device).long()
                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_batches += 1
        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        # Validation loop
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']
                    input_ids = input_ids[:, :48]
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.to(device).long()
                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')
    # Save the trained weights once training is complete
    torch.save(model.state_dict(), 'trained_model.pth')
# Start training
if train_dataloaders and val_dataloaders:
    train(model, train_dataloaders, val_dataloaders)
else:
    print("No data loaders available for training. Please check the datasets and preprocessing steps.")