"""Generate text with a trained word-level LSTM language model.

Running this file loads the model weights and the vocabulary mappings from
the current working directory, then logs one sampled continuation of an
example prompt.
"""

import logging
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters
# The layer sizes are used to rebuild the network before the saved weights are
# loaded into it, so they have to agree with the shapes stored in the file.
embedding_dim = 8
hidden_dim = 16
num_layers = 1
sequence_length = 64
temp = 1.0  # Temperature parameter
top_k = 10  # Top-k sampling parameter


# LSTM Model
class LSTMModel(nn.Module):
    """Embedding -> LSTM -> linear head that scores the next token."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        """Build the three layers.

        Args:
            vocab_size: Number of distinct tokens (embedding rows and output
                classes).
            embedding_dim: Size of each token embedding.
            hidden_dim: Size of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits for a batch of token-index sequences.

        Args:
            x: Integer tensor of token indices laid out as (batch, seq_len),
                matching the ``batch_first=True`` LSTM.

        Returns:
            Tensor of shape (batch, vocab_size) with unnormalised scores.
        """
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        # Only the output at the last time step is used for the prediction.
        logits = self.fc(lstm_out[:, -1, :])
        return logits


# Load the model and vocabulary
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
# NOTE: pickle.load can execute arbitrary code while deserialising; only load
# vocabulary files that come from a trusted source.
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)
vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.eval()
logging.info('Model and vocabulary loaded successfully.')


# Function to predict the next word with temperature and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    """Sample the word that follows ``sequence``.

    The prompt is split on whitespace, mapped to indices, truncated to its
    last ``sequence_length`` tokens and passed through ``model``. The logits
    are divided by ``temp``, turned into probabilities, and one of the
    ``top_k`` most probable words is drawn in proportion to its probability.

    Args:
        model: Callable module returning logits of shape (1, vocab_size) for
            an index tensor of shape (1, seq_len).
        word2idx: Mapping from word to index.
        idx2word: Mapping from index to word.
        sequence: Whitespace-separated prompt text.
        sequence_length: Maximum number of trailing tokens fed to the model.
        temp: Softmax temperature; must be positive.
        top_k: Number of candidates to sample from; values larger than the
            number of output classes are clamped to that number.

    Returns:
        The sampled word, looked up in ``idx2word``.

    Raises:
        ValueError: If ``temp`` is not positive, ``top_k`` is below 1, or
            ``sequence`` contains no words.
        KeyError: If a word is missing from ``word2idx`` and the vocabulary
            defines no fallback entry.
    """
    if temp <= 0:
        raise ValueError('temp must be a positive number')
    if top_k < 1:
        raise ValueError('top_k must be at least 1')

    # NOTE(review): the fallback key is an empty string, which looks like a
    # special token whose text was lost somewhere -- confirm against the
    # vocabulary used for training.
    # Fetched once and without indexing, so a vocabulary lacking this key
    # only fails when an unknown word actually shows up.
    unk_idx = word2idx.get('')
    seq_idx = []
    for word in sequence.split():
        idx = word2idx.get(word, unk_idx)
        if idx is None:
            raise KeyError(
                f'word {word!r} is not in the vocabulary and no fallback token is defined'
            )
        seq_idx.append(idx)
    if not seq_idx:
        raise ValueError('sequence must contain at least one word')
    seq_idx = seq_idx[-sequence_length:]  # Ensure the sequence length is correct

    model.eval()
    with torch.no_grad():
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor) / temp  # Apply temperature
        # squeeze(0) drops only the batch axis, so a one-word vocabulary
        # still yields a 1-D tensor.
        probs = F.softmax(outputs, dim=-1).squeeze(0)
        # topk raises when k exceeds the number of classes.
        k = min(top_k, probs.size(0))
        top_k_probs, top_k_idx = torch.topk(probs, k)
        # multinomial accepts unnormalised weights, so the truncated
        # probabilities need no renormalisation.
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
    return predicted_word


# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length=50):
    """Extend ``start_sequence`` one sampled word at a time.

    Args:
        model: Model passed through to ``predict_next_word``.
        word2idx: Mapping from word to index.
        idx2word: Mapping from index to word.
        start_sequence: Prompt text the generation starts from.
        sequence_length: Maximum number of trailing tokens fed to the model.
        temp: Softmax temperature.
        top_k: Number of candidates to sample from at each step.
        max_length: Maximum number of words to append.

    Returns:
        The prompt followed by the generated words, space-separated. A stop
        word that ends generation is included in the result.
    """
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        sentence += ' ' + next_word
        # NOTE(review): the empty-string stop token looks like a special
        # marker whose text was lost -- confirm against the vocabulary.
        if next_word in ('', 'User'):
            break
    return sentence


# Example usage
start_sequence = "User : What is the capital of France ? Bot :"
temp = 0.5  # Adjust temperature
top_k = 32  # Adjust top-k
logging.info(f'Starting sequence: {start_sequence}')
logging.info(f'Temperature: {temp}, Top-k: {top_k}')
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
logging.info(f'Generated sentence: {generated_sentence}')