import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder

# Import your model from tensor_network.py
from tensor_network import FourDimensionalTransformer  # Adjust the import path as needed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# List of dataset identifiers for reasoning and knowledge
dataset_ids = [
    "race/all",  # For reasoning
    "squad"      # For general knowledge
]

# Update possible keys
possible_text_keys = ['question', 'sentence', 'query']
possible_context_keys = ['context', 'article', 'passage']
possible_label_keys = ['answer', 'answers', 'options']

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


def tokenize_function_race(examples):
    texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
    labels = examples['answer']
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs


def tokenize_function_squad(examples):
    texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
    labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs


# Initialize LabelEncoder
label_encoder = LabelEncoder()
all_labels = []

# Process RACE dataset
race_dataset = load_dataset('race', 'all')
tokenized_datasets = []
for split in race_dataset.keys():
    tokenized_race = race_dataset[split].map(
        tokenize_function_race,
        batched=True,
        remove_columns=race_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_race})
    # Collect labels
    all_labels.extend(tokenized_race['labels'])

# Process SQuAD dataset
squad_dataset = load_dataset('squad')
for split in squad_dataset.keys():
    tokenized_squad = squad_dataset[split].map(
        tokenize_function_squad,
        batched=True,
        remove_columns=squad_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_squad})
    # Collect labels
    all_labels.extend(tokenized_squad['labels'])

# Fit label encoder
label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")

# Limit the number of classes to the top 10 most frequent labels
if num_classes > 10:
    print("Number of classes exceeds 10. Reducing to top 10 classes.")
    from collections import Counter
    label_counter = Counter(all_labels)
    top_10_labels = [label for label, _ in label_counter.most_common(10)]
    print(f"Top 10 labels: {top_10_labels}")
    label_mapping = {label: i for i, label in enumerate(top_10_labels)}
    label_mapping['other'] = len(top_10_labels)
    num_classes = len(top_10_labels) + 1
else:
    label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}

# Update model with correct num_classes
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=num_classes
).to(device)


def map_labels(labels):
    # Map raw labels to class indices; unseen labels fall back to the 'other'
    # bucket, which only exists when the top-10 reduction above was applied.
    other_idx = label_mapping.get('other', 0)
    return [label_mapping.get(label, other_idx) for label in labels]


# Map labels to class indices and keep only in-range examples
for tokenized_dataset in tokenized_datasets:
    for split in tokenized_dataset.keys():
        tokenized_dataset[split] = tokenized_dataset[split].map(
            lambda examples: {'labels': map_labels(examples['labels'])},
            batched=True
        )
        tokenized_dataset[split] = tokenized_dataset[split].filter(
            lambda example: example['labels'] < num_classes
        )
        tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])


# Prepare DataLoaders
def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
    dataloaders = []
    for tokenized_dataset in tokenized_datasets:
        if split_name in tokenized_dataset:
            dataset_split = tokenized_dataset[split_name]
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders


train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


def train(model, train_dataloaders, val_dataloaders, num_epochs=10):  # Change the number of epochs to your liking
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']

                # Reshape input_ids and move to device
                input_ids = input_ids[:, :48]
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)

                # Convert labels to torch.long and move to device
                labels = labels.to(device).long()

                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_batches += 1

        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Validation loop
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']

                    input_ids = input_ids[:, :48]
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.to(device).long()

                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)

        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')

    torch.save(model.state_dict(), 'trained_model.pth')


# Start training
if train_dataloaders and val_dataloaders:
    train(model, train_dataloaders, val_dataloaders)
else:
    print("No data loaders available for training. Please check the datasets and preprocessing steps.")
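

# ---------------------------------------------------------------------------
# Optional inference sketch (not called anywhere above). A minimal example of
# reloading 'trained_model.pth' and classifying one question/context pair. It
# assumes the same FourDimensionalTransformer constructor arguments and the
# (batch, 3, 4, 4) input layout used in train(); the helper name
# predict_label is illustrative and not part of tensor_network.py.
# ---------------------------------------------------------------------------
def predict_label(question, context, model_path='trained_model.pth'):
    clf = FourDimensionalTransformer(
        num_layers=16,
        embed_dim=7,
        num_heads=1,
        num_extra_tokens=16,
        num_classes=num_classes
    ).to(device)
    clf.load_state_dict(torch.load(model_path, map_location=device))
    clf.eval()

    # Tokenize exactly as in training: question and context concatenated,
    # padded/truncated to 48 tokens, then reshaped to (1, 3, 4, 4).
    encoded = tokenizer(question + " " + context, padding='max_length',
                        truncation=True, max_length=48, return_tensors='pt')
    input_ids = encoded['input_ids'][:, :48].view(-1, 3, 4, 4).float().to(device)

    with torch.no_grad():
        logits = clf(input_ids)
    # Returns the predicted class index under the label_mapping built above
    return int(torch.argmax(logits, dim=1).item())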