import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder

# Import your model from tensor_network.py
from tensor_network import FourDimensionalTransformer  # Adjust the import path as needed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# List of dataset identifiers
dataset_ids = [
    "prithivMLmods/Deepthink-Reasoning",
    "ewok-core/ewok-core-1.0",
    "MuskumPillerum/General-Knowledge",
    "fblgit/tree-of-knowledge",
    "CohereForAI/aya_dataset",
    "AtlasUnified/Atlas-Reasoning",
    "livebench/reasoning",
    "SkunkworksAI/reasoning-0.01",
    "KingNish/reasoning-base-20k",
    "RLHFlow/HH-RLHF-Helpful-standard",
    "yitingxie/rlhf-reward-datasets"
]

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    possible_text_keys = ['text', 'content', 'question', 'passage', 'prompt', 'input']
    possible_label_keys = ['label', 'answer', 'response', 'output', 'target']

    text_key = next((k for k in possible_text_keys if k in examples), None)
    if text_key is None:
        text_key = list(examples.keys())[0]

    label_key = next((k for k in possible_label_keys if k in examples), None)
    if label_key is None:
        labels = ['0'] * len(examples[text_key])  # Default label when no label-like column exists
    else:
        # Cast labels to strings so LabelEncoder sees a consistent type across datasets
        labels = [str(l) for l in examples[label_key]]

    texts = [str(t) for t in examples[text_key]]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

# Initialize LabelEncoder
label_encoder = LabelEncoder()
all_labels = []

# Process each dataset individually
tokenized_datasets = []
for dataset_id in dataset_ids:
    try:
        dataset = load_dataset(dataset_id)
        tokenized_dataset = dataset.map(tokenize_function, batched=True)
        # Collect labels for label encoding
        for split in tokenized_dataset.keys():
            if 'labels' in tokenized_dataset[split].features:
                all_labels.extend(tokenized_dataset[split]['labels'])
        tokenized_datasets.append(tokenized_dataset)
    except Exception as e:
        print(f"Could not process dataset {dataset_id}: {e}")

# Fit label encoder
label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")

if num_classes > 10:
    print("Warning: the number of unique labels exceeds the model's 10 output classes. "
          "Adjust the datasets or the model before training.")
    exit()

# Transform labels in each dataset
for dataset in tokenized_datasets:
    for split in dataset.keys():
        if 'labels' in dataset[split].features:
            dataset[split] = dataset[split].map(
                lambda examples: {'labels': label_encoder.transform(examples['labels'])},
                batched=True
            )

# Prepare DataLoaders
def prepare_dataloader(dataset_splits, split_name, batch_size=2):
    dataloaders = []
    for dataset in dataset_splits:
        if split_name in dataset:
            dataset_split = dataset[split_name]
            dataset_split.set_format(type='torch', columns=['input_ids', 'labels'])
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders

train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')

# Initialize the model
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=10  # Using 10 classes as per your model
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']

                # Reshape input_ids and move to device
                input_ids = input_ids[:, :48]  # Ensure length is 48 (= 3 * 4 * 4)
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)

                # Convert labels to torch.long and move to device
                labels = labels.type(torch.long).to(device)

                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_batches += 1

        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Validation loop
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']

                    input_ids = input_ids[:, :48]  # Ensure length is 48
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.type(torch.long).to(device)

                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)

        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')

    torch.save(model.state_dict(), 'trained_model.pth')

# Start training
train(model, train_dataloaders, val_dataloaders)
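
# --- Optional inference sketch (not part of the original script) ---
# A minimal example of reusing the saved checkpoint for a single prediction. It assumes
# the same FourDimensionalTransformer constructor arguments and the (batch, 3, 4, 4)
# input layout used during training; the predict() helper and the sample text below are
# illustrative, and inverse_transform only succeeds if the predicted index maps to a
# label that the encoder was fitted on.
def predict(text):
    inference_model = FourDimensionalTransformer(
        num_layers=16,
        embed_dim=7,
        num_heads=1,
        num_extra_tokens=16,
        num_classes=10
    ).to(device)
    inference_model.load_state_dict(torch.load('trained_model.pth', map_location=device))
    inference_model.eval()

    # Tokenize exactly as in training: pad/truncate to 48 tokens, then reshape to (1, 3, 4, 4)
    encoded = tokenizer(text, padding='max_length', truncation=True, max_length=48,
                        return_tensors='pt')
    input_ids = encoded['input_ids'][:, :48].view(-1, 3, 4, 4).float().to(device)

    with torch.no_grad():
        logits = inference_model(input_ids)
    predicted_idx = logits.argmax(dim=1).item()

    # Map the class index back to the original label value
    return label_encoder.inverse_transform([predicted_idx])[0]

# Example usage (hypothetical input)
print(predict("What is the capital of France?"))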