import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file, save_file
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters
sequence_length = 16
batch_size = 32
num_epochs = 1  # Continue training for 1 more epoch
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 512
num_layers = 2

# LSTM model: embedding -> multi-layer LSTM -> linear projection to vocabulary logits
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        # Only the hidden state at the last timestep is used to predict the next word
        logits = self.fc(lstm_out[:, -1, :])
        return logits

# Load the model weights and vocabulary saved by the previous training run
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)
vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.train()
logging.info('Model and vocabulary loaded successfully.')

# Output the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
logging.info(f'Total number of parameters: {total_params}')

# Read the text file (expected to contain a JSON-encoded list of words)
logging.info('Reading the text file...')
with open('text.txt', 'r') as file:
    text = file.read()
logging.info('Text file read successfully.')

# Preprocess the text: build (input sequence, next-word label) pairs with a sliding window
logging.info('Preprocessing the text...')
words = json.loads(text)
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))
logging.info(f'Number of sequences: {len(sequences)}')

# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        # Words missing from the vocabulary fall back to the index of the empty-string
        # token, which is assumed to exist in the loaded word2idx mapping
        seq_idx = [self.word2idx.get(word, self.word2idx['']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx[''])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)

logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Continue training with cross-entropy loss and Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

logging.info('Starting continued training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Save the updated model weights and the (unchanged) vocabulary mappings
logging.info('Saving the updated model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)
logging.info('Updated model and vocabulary saved successfully.')
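
# --- Optional usage sketch (illustrative, not part of the original training flow) ---
# A minimal greedy-decoding example showing how the continued-training model could be
# used to generate text. It reuses `model`, `word2idx`, `idx2word`, and
# `sequence_length` defined above; the `generate_text` helper and the seed words in
# the commented example are hypothetical and only meant as a sketch.
def generate_text(model, seed_words, num_words=20):
    model.eval()
    words_out = list(seed_words)
    with torch.no_grad():
        for _ in range(num_words):
            # Use up to the last `sequence_length` words as context, mapping unknown
            # words to the same empty-string fallback index used during training
            context = words_out[-sequence_length:]
            idxs = [word2idx.get(w, word2idx['']) for w in context]
            inputs = torch.tensor([idxs], dtype=torch.long)
            logits = model(inputs)
            next_idx = int(torch.argmax(logits, dim=-1).item())
            words_out.append(idx2word[next_idx])
    model.train()
    return ' '.join(words_out)

# Example call with a hypothetical seed:
# print(generate_text(model, ['once', 'upon', 'a', 'time']))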