Neural Network-Based Language Model for Next Token Prediction
Overview
This project is a midterm assignment that develops a neural network-based language model for next token prediction. The model was trained on a custom dataset in two languages, English and Amharic. It uses an LSTM rather than a transformer, demonstrating a non-transformer approach to language modeling.
Project Objectives
The main objectives of this project were to:
- Develop a neural network-based model for next token prediction without using transformers or encoder-decoder architectures.
- Experiment with multiple languages to observe model performance.
- Implement checkpointing to save model progress and generate text during different training stages.
- Present a video demo showcasing the model's performance in generating text in both English and Amharic.
Project Details
1. Training Languages
The model was trained on English and Amharic datasets. The data was cleaned, then tokenized and embedded before training.
2. Tokenizer
A custom tokenizer was created using Byte Pair Encoding (BPE). The tokenizer was trained on five languages (English, Amharic, Sanskrit, Nepali, and Hindi), although the model in this project used only English and Amharic.
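To make the BPE idea concrete, here is a minimal sketch of how merge rules can be learned from a corpus. It is not the project's tokenizer: the function names (`get_pair_counts`, `merge_pair`, `train_bpe`) and the toy corpus are illustrative assumptions, and it works on Unicode characters split by whitespace, whereas a production tokenizer may operate on bytes and use a different pre-tokenization step.

```python
from collections import Counter


def get_pair_counts(words):
    """Count adjacent symbol pairs across all words, weighted by word frequency."""
    counts = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            counts[(a, b)] += freq
    return counts


def merge_pair(words, pair):
    """Replace every occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


def train_bpe(corpus, num_merges):
    """Learn up to `num_merges` merge rules from a whitespace-split corpus."""
    # Each word starts as a tuple of single characters (hypothetical setup).
    words = dict(Counter(tuple(w) for w in corpus.split()))
    merges = []
    for _ in range(num_merges):
        counts = get_pair_counts(words)
        if not counts:
            break
        best = max(counts, key=counts.get)  # most frequent adjacent pair
        merges.append(best)
        words = merge_pair(words, best)
    return merges


# Toy corpus, not the project's data.
merges = train_bpe("low low lower lowest newest widest", num_merges=3)
print(merges)
```

Because the sketch splits words into Unicode characters, the same code runs unchanged on Amharic text; only the learned merges differ.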
3. Embedding Model
A custom embedding model was employed to convert tokens into vector representations, allowing the neural network to better understand the structure and meaning of the input data.
4. Model Architecture
The project uses an LSTM (Long Short-Term Memory) neural network to predict the next token in a sequence. LSTMs are well-suited for sequential data and are a popular choice for language modeling due to their ability to capture long-term dependencies.
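The embedding and LSTM stages described above can be sketched in PyTorch as follows. This is a minimal illustration, not the project's actual model: the class name `NextTokenLSTM`, the layer sizes, the vocabulary size, and the random batch are all assumptions chosen for the example.

```python
import torch
import torch.nn as nn


class NextTokenLSTM(nn.Module):
    """Embedding -> LSTM -> linear head over the vocabulary (illustrative sizes)."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token_ids, state=None):
        x = self.embedding(token_ids)      # (batch, seq, embed_dim)
        out, state = self.lstm(x, state)   # (batch, seq, hidden_dim)
        return self.head(out), state       # logits: (batch, seq, vocab_size)


vocab_size = 1000                          # assumed value for the example
model = NextTokenLSTM(vocab_size)

# Random token IDs stand in for a tokenized batch of 4 sequences of length 16.
batch = torch.randint(0, vocab_size, (4, 16))
inputs, targets = batch[:, :-1], batch[:, 1:]   # predict token t+1 from tokens <= t

logits, _ = model(inputs)
loss = nn.functional.cross_entropy(
    logits.reshape(-1, vocab_size), targets.reshape(-1)
)
print(logits.shape, loss.item())
```

Shifting the batch by one position gives the input/target pairs for next token prediction: the logits at each position are scored against the token that follows it.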
Results and Evaluation
Training Curve and Loss
The model's training and validation loss over time are documented and included in the repository (loss_values.csv). The training curve demonstrates the model's learning progress, with explanations provided for key observations in the loss trends.
Checkpoint Implementation
Checkpointing was implemented to save model states at different training stages, allowing for partial model evaluations and text generation demos. Checkpoints are included in the repository for reference.
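One common way to implement this kind of checkpointing in PyTorch is to save the model and optimizer state dictionaries together with the training progress. The sketch below is an assumption about how it could be done, not the project's code: the helper names `save_checkpoint` and `load_checkpoint`, the dictionary keys, the file name, and the small stand-in LSTM are all hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch, loss):
    """Write model and optimizer state plus training progress to `path`."""
    torch.save(
        {
            "epoch": epoch,
            "loss": loss,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore state in place and return the saved epoch and loss."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"], ckpt["loss"]


# Stand-in model for the example; the real model would be used instead.
model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_epoch_3.pt")  # hypothetical file name
    save_checkpoint(path, model, optimizer, epoch=3, loss=2.5)

    # A freshly initialized model picks up the saved weights.
    restored = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
    restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
    epoch, loss = load_checkpoint(path, restored, restored_opt)

print(epoch, loss)
```

Saving the optimizer state alongside the weights lets training resume from a checkpoint, while loading only the model state is enough for generating text at a given training stage.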
Perplexity Score
The model's perplexity score, calculated during training, is available in the perplexity.csv file. Perplexity tracks how well the model predicts the next token over time: a lower score means the model assigns higher probability to the tokens that actually follow.
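For reference, perplexity is conventionally defined as the exponential of the average negative log-likelihood per token, which for a model trained with cross-entropy loss (in nats) is simply the exponential of the mean loss. The sketch below shows that standard definition; the function name `perplexity` and the sample values are illustrative, and the exact computation used to produce the project's file may differ.

```python
import math


def perplexity(token_nlls):
    """Exponential of the mean per-token negative log-likelihood (in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that is uniformly unsure among 4 tokens has NLL = ln(4) per token,
# so its perplexity is 4 (up to floating-point rounding).
uniform_ppl = perplexity([math.log(4)] * 10)

# A model that always assigns probability 1 to the correct token has perplexity 1.
perfect_ppl = perplexity([0.0] * 10)

print(uniform_ppl, perfect_ppl)
```

Read this way, a perplexity of k means the model is about as uncertain as if it were choosing uniformly among k tokens at each step.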
Demonstration
A video demo, linked below, demonstrates:
- Random initialization text generation in English.
- Text generation using the trained model in both English and Amharic, with English translations provided using Google Translate.
Video Demo Link: YouTube Demo
Instructions for Reproducing the Results
- Install dependencies (Python, PyTorch, and other required libraries).
- Load the .ipynb notebook and run cells sequentially to replicate training and evaluation.
- Refer to the Hugging Face documentation for instructions on downloading the model and tokenizer files.
Note: The data for this project was taken from saillab/taco-datasets.