import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import save_file
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters
sequence_length = 64
batch_size = 16
num_epochs = 1
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 800
num_layers = 4

# Read the corpus (a JSON-encoded list of words)
logging.info('Reading the text file...')
with open('fulltext.json', 'r', encoding='utf-8') as file:
    text = file.read()
logging.info('Text file read successfully.')

# Preprocess the text
logging.info('Preprocessing the text...')
words = json.loads(text)
# Sort the vocabulary so word-to-index assignments are deterministic across
# runs (set iteration order varies with Python's string hash randomization).
# '<pad>' is reserved for padding; '<UNK>' covers out-of-vocabulary words.
vocab = sorted(set(words) | {'<pad>', '<UNK>'})
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(vocab)

logging.info(f'Vocabulary size: {vocab_size}')
#logging.info(f'Word to index mapping: {word2idx}')

# Create sequences
logging.info('Creating sequences...')
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))
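# Each training example pairs a sliding window of sequence_length consecutive
# words with the single word that follows it as the prediction target.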

logging.info(f'Number of sequences: {len(sequences)}')

# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        seq_idx = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx['<UNK>'])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)

logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
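# Each batch yields inputs of shape (batch_size, sequence_length) and
# targets of shape (batch_size,), both torch.long.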

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)            # (batch, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embeds)       # (batch, seq_len, hidden_dim)
        # Predict the next word from the hidden state at the final timestep
        logits = self.fc(lstm_out[:, -1, :])
        return logits

logging.info('Initializing the LSTM model...')
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
logging.info('Starting training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Save the model
logging.info('Saving the model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)

logging.info('Model and vocabulary saved successfully.')
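
# A minimal sketch of how the saved artifacts could be loaded back for a
# greedy next-word prediction; it reuses the names defined above, plus
# load_file, safetensors' counterpart to save_file.
from safetensors.torch import load_file

inference_model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
inference_model.load_state_dict(load_file('lstm_model.safetensors'))
inference_model.eval()

# Use the first sequence_length words of the corpus as a seed context.
seed = words[:sequence_length]
seed_idx = torch.tensor([[word2idx.get(w, word2idx['<UNK>']) for w in seed]])
with torch.no_grad():
    next_idx = inference_model(seed_idx).argmax(dim=-1).item()
logging.info(f'Predicted next word after seed: {idx2word[next_idx]}')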