import json
import logging
import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset, DataLoader
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# --- Data windowing and batching -------------------------------------------
sequence_length = 16  # context words per training sample
batch_size = 32

# --- Optimisation ------------------------------------------------------------
num_epochs = 1
learning_rate = 1e-5

# --- Model architecture ------------------------------------------------------
# These are used to rebuild the model before loading the saved weights, so
# they must match the checkpoint or load_state_dict() will reject it.
embedding_dim = 256
hidden_dim = 512
num_layers = 2
class LSTMModel(nn.Module):
    """Next-word model: token embedding -> stacked LSTM -> linear vocabulary head.

    The submodule attribute names ``embedding``, ``lstm`` and ``fc`` become the
    state-dict keys, so renaming them would break loading saved checkpoints.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return unnormalised logits of shape (batch, vocab_size).

        ``x`` is a (batch, seq_len) tensor of token indices (batch_first=True).
        """
        hidden_states, _ = self.lstm(self.embedding(x))
        # Only the output at the final time step is used for the prediction.
        last_step = hidden_states[:, -1]
        return self.fc(last_step)
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    vocabulary files produced by a trusted training run.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
word2idx = _load_pickle('word2idx.pkl')
idx2word = _load_pickle('idx2word.pkl')

# Rebuild the architecture with the same sizes, then restore the saved weights.
vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.train()
logging.info('Model and vocabulary loaded successfully.')

total_params = sum(param.numel() for param in model.parameters())
logging.info(f'Total number of parameters: {total_params}')
logging.info('Reading the text file...')
# Decode explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
# Windows) can fail on, or silently mis-decode, non-ASCII words. JSON text is
# UTF-8 per RFC 8259.
with open('text.txt', 'r', encoding='utf-8') as file:
    text = file.read()
logging.info('Text file read successfully.')
logging.info('Preprocessing the text...')
# NOTE(review): text.txt is parsed as JSON and windowed by slicing, so it is
# presumably a pre-tokenised JSON array of words. Confirm against its producer.
words = json.loads(text)
if not isinstance(words, list):
    # A JSON string would otherwise be sliced into single characters and
    # trained on silently before the checkpoint is overwritten.
    raise ValueError(
        f'Expected text.txt to contain a JSON array of words, got {type(words).__name__}'
    )
# Sliding window: each sample is `sequence_length` consecutive words paired
# with the word that immediately follows them as the prediction target.
sequences = [
    (words[i:i + sequence_length], words[i + sequence_length])
    for i in range(len(words) - sequence_length)
]
if not sequences:
    logging.warning(
        f'Need more than {sequence_length} words to build a training sequence, '
        f'got {len(words)}; no training batches will run.'
    )
logging.info(f'Number of sequences: {len(sequences)}')
class TextDataset(Dataset):
    """Turn (context words, next word) pairs into index tensors for training.

    Args:
        sequences: indexable collection of ``(seq, label)`` pairs, where
            ``seq`` is an iterable of words and ``label`` is the target word.
        word2idx: mapping from word to vocabulary index. Words missing from it
            map to the index of ``'<UNK>'``.
    """

    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def _to_index(self, word):
        """Return the index of *word*, falling back to '<UNK>' for unknown words.

        Raises:
            KeyError: if *word* is unknown and the vocabulary has no '<UNK>'.
        """
        idx = self.word2idx.get(word)
        if idx is None:
            # Resolve '<UNK>' only on a miss. The previous
            # ``get(word, word2idx['<UNK>'])`` evaluated its default eagerly,
            # so it raised KeyError for every word, known ones included,
            # whenever the vocabulary lacked '<UNK>'.
            idx = self.word2idx.get('<UNK>')
            if idx is None:
                raise KeyError(
                    f"word {word!r} is not in the vocabulary and there is no '<UNK>' token"
                )
        return idx

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        seq_idx = [self._to_index(word) for word in seq]
        label_idx = self._to_index(label)
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)
logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(
    dataset,
    shuffle=True,  # reshuffle sample order every epoch
    batch_size=batch_size,
)
# Multi-class loss over the model's vocabulary logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
def _train_step(inputs, targets):
    """Run one forward/backward/update pass on a batch and return its loss."""
    optimizer.zero_grad()
    batch_loss = criterion(model(inputs), targets)
    batch_loss.backward()
    optimizer.step()
    return batch_loss


logging.info('Starting continued training...')
for epoch in range(num_epochs):
    # Each batch is the (inputs, targets) pair yielded by the DataLoader.
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        loss = _train_step(inputs, targets)
        # Report progress every 10th batch (batch_idx is 0-based).
        if not batch_idx % 10:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')
def _atomic_write(path, write_to):
    """Replace *path* with the file produced by ``write_to(tmp_path)``.

    The data is written to ``<path>.tmp`` in the same directory and then moved
    over *path* with ``os.replace``, which is atomic on POSIX. An interruption
    mid-write (crash, Ctrl-C, full disk) therefore leaves the previous file
    intact instead of a truncated one.
    """
    tmp_path = f'{path}.tmp'
    try:
        write_to(tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Drop the partial temp file and re-raise; *path* was not touched.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise


def _dump_pickle(obj, path):
    """Pickle *obj* to the file at *path*."""
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


logging.info('Saving the updated model...')
# These are the same files read at start-up. Overwriting them in place would
# destroy the only copy if the write were interrupted, so write atomically.
_atomic_write('lstm_model.safetensors', lambda p: save_file(model.state_dict(), p))
_atomic_write('word2idx.pkl', lambda p: _dump_pickle(word2idx, p))
_atomic_write('idx2word.pkl', lambda p: _dump_pickle(idx2word, p))
logging.info('Updated model and vocabulary saved successfully.')