# LSTM-1225 / trainer.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import save_file
import logging
import json
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Hyperparameters
sequence_length = 64
batch_size = 16
num_epochs = 1
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 800
num_layers = 4
# Read the text file
logging.info('Reading the text file...')
with open('fulltext.json', 'r') as file:
    text = file.read()
logging.info('Text file read successfully.')
# Preprocess the text
logging.info('Preprocessing the text...')
words = json.loads(text)  # fulltext.json is expected to hold a JSON array of word tokens
vocab = set(words)
vocab.add('<pad>')  # padding token (reserved; unused here since every window has equal length)
vocab.add('<UNK>')  # fallback token for out-of-vocabulary words
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
logging.info(f'Vocabulary size: {vocab_size}')
#logging.info(f'Word to index mapping: {word2idx}')
# Create sequences
logging.info('Creating sequences...')
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))
logging.info(f'Number of sequences: {len(sequences)}')
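# Illustration (hypothetical words, sequence_length = 3): the sliding window above
# would turn words = ['the', 'cat', 'sat', 'on', 'the', 'mat'] into the pairs
#   (['the', 'cat', 'sat'], 'on'), (['cat', 'sat', 'on'], 'the'), (['sat', 'on', 'the'], 'mat'),
# i.e. each window of preceding words predicts the single word that follows it.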
# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        seq_idx = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx['<UNK>'])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)
logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
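# Illustrative note (not in the original script): each batch drawn from this DataLoader
# is a pair of tensors with shapes (batch_size, sequence_length) and (batch_size,), e.g.
#   inputs, targets = next(iter(dataloader))
#   inputs.shape   -> torch.Size([16, 64])
#   targets.shape  -> torch.Size([16])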
# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        logits = self.fc(lstm_out[:, -1, :])  # score the next word from the final time step only
        return logits
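# Shape walkthrough (illustrative, follows from the layer definitions above):
#   x        : (batch_size, sequence_length)                 token indices
#   embeds   : (batch_size, sequence_length, embedding_dim)
#   lstm_out : (batch_size, sequence_length, hidden_dim)
#   logits   : (batch_size, vocab_size)                      next-word scores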
logging.info('Initializing the LSTM model...')
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
logging.info('Starting training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')
# Save the model
logging.info('Saving the model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)
logging.info('Model and vocabulary saved successfully.')
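# --- Optional usage sketch (not part of the original training script) ---
# A minimal illustration of how the artifacts saved above could be reloaded for a
# single next-word prediction; the seed text is simply the first training window,
# a hypothetical choice made only for this example.
from safetensors.torch import load_file

loaded_model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
loaded_model.load_state_dict(load_file('lstm_model.safetensors'))
loaded_model.eval()

seed_words = words[:sequence_length]  # reuse the first window of the corpus as a seed
seed_idx = torch.tensor([[word2idx.get(w, word2idx['<UNK>']) for w in seed_words]], dtype=torch.long)
with torch.no_grad():
    next_logits = loaded_model(seed_idx)          # shape: (1, vocab_size)
predicted_word = idx2word[next_logits.argmax(dim=-1).item()]
logging.info(f'Sample next-word prediction: {predicted_word}')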