import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder

from tensor_network import FourDimensionalTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

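# This script trains the custom FourDimensionalTransformer as a small text
# classifier over two question-answering datasets (RACE and SQuAD), treating
# each question/context pair as one example and its answer as the class label.
# The field-name lists below document how different QA datasets name their
# question, context, and answer columns; they are kept here for reference.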
dataset_ids = [
    "race/all",
    "squad",
]

possible_text_keys = ['question', 'sentence', 'query']
possible_context_keys = ['context', 'article', 'passage']
possible_label_keys = ['answer', 'answers', 'options']

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

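# Each example is tokenized as "<question> <context>" and padded/truncated to
# exactly 48 tokens, so that a sequence can later be reshaped into a 3x4x4
# block (48 = 3 * 4 * 4) for the model's 4-D input.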
def tokenize_function_race(examples):
    # RACE: the passage lives in 'article' and the label is the option letter.
    texts = [q + " " + p for q, p in zip(examples['question'], examples['article'])]
    labels = examples['answer']
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs


def tokenize_function_squad(examples):
    # SQuAD: the label is the first gold answer string (falls back to '' when
    # no answer text is present).
    texts = [q + " " + c for q, c in zip(examples['question'], examples['context'])]
    labels = [ans['text'][0] if ans['text'] else '' for ans in examples['answers']]
    tokenized_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=48)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs


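# Collect every raw label from both datasets so a single global label space
# can be built before training.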
label_encoder = LabelEncoder()
all_labels = []

race_dataset = load_dataset('race', 'all')
tokenized_datasets = []
for split in race_dataset.keys():
    tokenized_race = race_dataset[split].map(
        tokenize_function_race,
        batched=True,
        remove_columns=race_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_race})
    all_labels.extend(tokenized_race['labels'])

squad_dataset = load_dataset('squad')
for split in squad_dataset.keys():
    tokenized_squad = squad_dataset[split].map(
        tokenize_function_squad,
        batched=True,
        remove_columns=squad_dataset[split].column_names,
        load_from_cache_file=False,
    )
    tokenized_datasets.append({split: tokenized_squad})
    all_labels.extend(tokenized_squad['labels'])

label_encoder.fit(all_labels)
num_classes = len(label_encoder.classes_)
print(f"Number of unique labels: {num_classes}")

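# SQuAD answers are free text, so the combined label space is huge. To keep
# the classification head small, only the 10 most frequent labels are kept and
# everything else is folded into a single 'other' class.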
if num_classes > 10:
    print("Number of classes exceeds 10. Reducing to top 10 classes.")
    from collections import Counter
    label_counter = Counter(all_labels)
    top_10_labels = [label for label, _ in label_counter.most_common(10)]
    print(f"Top 10 labels: {top_10_labels}")
    label_mapping = {label: i for i, label in enumerate(top_10_labels)}
    label_mapping['other'] = len(top_10_labels)
    num_classes = len(top_10_labels) + 1
else:
    label_mapping = {label: i for i, label in enumerate(label_encoder.classes_)}

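# Instantiate the custom model. The hyperparameters below (num_layers,
# embed_dim, num_heads, num_extra_tokens) are passed straight through to
# FourDimensionalTransformer; their exact meaning is defined by that
# implementation in tensor_network.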
model = FourDimensionalTransformer(
    num_layers=16,
    embed_dim=7,
    num_heads=1,
    num_extra_tokens=16,
    num_classes=num_classes,
).to(device)

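# Map raw labels to integer class ids; anything outside the kept label set is
# routed to the 'other' bucket when that bucket exists.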
def map_labels(labels):
    # Look up the fallback lazily so this does not raise a KeyError when the
    # label space was small enough to skip the 'other' bucket.
    fallback = label_mapping.get('other')
    return [label_mapping.get(label, fallback) for label in labels]


for tokenized_dataset in tokenized_datasets:
    for split in tokenized_dataset.keys():
        tokenized_dataset[split] = tokenized_dataset[split].map(
            lambda examples: {'labels': map_labels(examples['labels'])},
            batched=True,
        )
        # Defensive filter: after mapping, every label should already be a
        # valid class id below num_classes.
        tokenized_dataset[split] = tokenized_dataset[split].filter(
            lambda example: example['labels'] < num_classes
        )
        tokenized_dataset[split].set_format(type='torch', columns=['input_ids', 'labels'])

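# Build one DataLoader per source dataset for a given split, so the training
# loop can iterate RACE and SQuAD batches back to back within each epoch.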
def prepare_dataloader(tokenized_datasets, split_name, batch_size=4):
    dataloaders = []
    for tokenized_dataset in tokenized_datasets:
        if split_name in tokenized_dataset:
            dataset_split = tokenized_dataset[split_name]
            dataloader = DataLoader(dataset_split, batch_size=batch_size, shuffle=True)
            dataloaders.append(dataloader)
    return dataloaders


train_dataloaders = prepare_dataloader(tokenized_datasets, 'train')
val_dataloaders = prepare_dataloader(tokenized_datasets, 'validation')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

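# Training loop: each epoch runs over all train dataloaders, then measures
# accuracy on the validation dataloaders. Token ids are truncated to 48
# positions and reshaped to (batch, 3, 4, 4) floats to match the model's 4-D
# input layout before being fed to the network.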
def train(model, train_dataloaders, val_dataloaders, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_batches = 0
        for dataloader in train_dataloaders:
            for batch in dataloader:
                input_ids = batch['input_ids']
                labels = batch['labels']

                # 48 token ids per example, viewed as a 3x4x4 block of floats.
                input_ids = input_ids[:, :48]
                input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                labels = labels.to(device).long()

                optimizer.zero_grad()
                outputs = model(input_ids)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_batches += 1

        avg_loss = total_loss / total_batches if total_batches > 0 else 0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Per-epoch validation pass.
        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for dataloader in val_dataloaders:
                for batch in dataloader:
                    input_ids = batch['input_ids']
                    labels = batch['labels']

                    input_ids = input_ids[:, :48]
                    input_ids = input_ids.view(-1, 3, 4, 4).float().to(device)
                    labels = labels.to(device).long()

                    outputs = model(input_ids)
                    _, predicted = torch.max(outputs, 1)
                    total_correct += (predicted == labels).sum().item()
                    total_samples += labels.size(0)
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        print(f'Validation Accuracy: {accuracy:.4f}')

    # Persist the trained weights once all epochs have completed.
    torch.save(model.state_dict(), 'trained_model.pth')


if train_dataloaders and val_dataloaders:
    train(model, train_dataloaders, val_dataloaders)
else:
    print("No data loaders available for training. Please check the datasets and preprocessing steps.")