from transformer import Transformer  # Transformer model class defined in the local transformer.py
import torch
import numpy as np
import chardet
import matplotlib.pyplot as plt
from torch import nn
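# Character-level English -> Marathi translation with a Transformer.
# Pipeline: load a parallel corpus, build character vocabularies, filter out sentence pairs
# that are too long or contain out-of-vocabulary characters, train with teacher forcing and
# a padding-masked cross-entropy loss, then greedily decode a sample sentence.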
english_file = r'C:\Users\haris\Downloads\eng_marathi\train.en'  # English side of the parallel corpus; only the first TOTAL_SENTENCES lines are used (see below)
marathi_file = r'C:\Users\haris\Downloads\eng_marathi\train.mr'  # Marathi side of the parallel corpus; only the first TOTAL_SENTENCES lines are used (see below)
# Generated this by filtering Appendix code
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
marathi_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ',
'ँ', 'ఆ', 'ఇ', 'ా', 'ి', 'ీ', 'ు', 'ూ',
'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ॠ', 'ऌ', 'ऎ', 'ए', 'ऐ', 'ऒ', 'ओ', 'औ',
'क', 'ख', 'ग', 'घ', 'ङ',
'च', 'छ', 'ज', 'झ', 'ञ',
'ट', 'ठ', 'ड', 'ढ', 'ण',
'त', 'थ', 'द', 'ध', 'न',
'प', 'फ', 'ब', 'भ', 'म',
'य', 'र', 'ऱ', 'ल', 'ळ', 'व', 'श', 'ष', 'स', 'ह',
'़', 'ऽ', 'ा', 'ि', 'ी', 'ु', 'ू', 'ृ', 'ॄ', 'ॅ', 'े', 'ै', 'ॉ', 'ो', 'ौ', '्', 'ॐ', '।', '॥', '॰', 'ॱ', PADDING_TOKEN, END_TOKEN]
english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
':', '<', '=', '>', '?', '@',
'[', '\\', ']', '^', '_', '`',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
'y', 'z',
'{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]
index_to_marathi = {k:v for k,v in enumerate(marathi_vocabulary)}
marathi_to_index = {v:k for k,v in enumerate(marathi_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}
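# The model works at the character level: the maps above convert between characters and
# integer ids in both directions, e.g. index_to_english[english_to_index['a']] == 'a'.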
# Open the file in binary mode to detect its encoding
with open(marathi_file, 'rb') as file:
raw_data = file.read(10000) # Read some bytes to check the encoding
result = chardet.detect(raw_data)
encoding = result['encoding']
print(f"Detected encoding: {encoding}")
# Correct way to open the Marathi file with the right encoding
with open(marathi_file, 'r', encoding=encoding) as file:
marathi_sentences = file.readlines()
# If you are reusing the same file, ensure you specify the encoding every time.
with open(english_file, 'r', encoding='utf-8') as file:
english_sentences = file.readlines()
# Now process the sentences as needed
TOTAL_SENTENCES = 20000
english_sentences = english_sentences[:TOTAL_SENTENCES]
marathi_sentences = marathi_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
marathi_sentences = [sentence.rstrip('\n') for sentence in marathi_sentences]
max_sequence_length = 200
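# A sentence pair is kept only if both sides fit within max_sequence_length - 1 characters
# (leaving room for the END token) and the Marathi side contains only vocabulary characters.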
def is_valid_tokens(sentence, vocab):
for token in list(set(sentence)):
if token not in vocab:
return False
return True
def is_valid_length(sentence, max_sequence_length):
return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space
valid_sentence_indicies = []
for index in range(len(marathi_sentences)):
marathi_sentence, english_sentence = marathi_sentences[index], english_sentences[index]
if is_valid_length(marathi_sentence, max_sequence_length) \
and is_valid_length(english_sentence, max_sequence_length) \
and is_valid_tokens(marathi_sentence, marathi_vocabulary):
valid_sentence_indicies.append(index)
print(f"Number of sentences: {len(marathi_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")
marathi_sentences = [marathi_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]
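# Model hyper-parameters: d_model, ffn_hidden, num_heads and drop_prob match the base
# configuration from "Attention Is All You Need"; num_layers is reduced from 6 to 4 here.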
d_model = 512
batch_size = 64
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 4
max_sequence_length = 200
mr_vocab_size = len(marathi_vocabulary)
transformer = Transformer(d_model,
ffn_hidden,
num_heads,
drop_prob,
num_layers,
max_sequence_length,
mr_vocab_size,
english_to_index,
marathi_to_index,
START_TOKEN,
END_TOKEN,
PADDING_TOKEN)
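# Transformer (defined in transformer.py, not shown here) is expected to tokenize the raw
# string batches itself using the char-to-index maps and special tokens passed above; its
# decoder exposes sentence_embedding.batch_tokenize, which is reused below to build the
# label tensor.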
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, english_sentences, marathi_sentences):
self.english_sentences = english_sentences
self.marathi_sentences = marathi_sentences
def __len__(self):
return len(self.english_sentences)
def __getitem__(self, idx):
return self.english_sentences[idx], self.marathi_sentences[idx]
dataset = TextDataset(english_sentences, marathi_sentences)
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)
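# The DataLoader batches the raw English/Marathi string pairs directly (no custom collate_fn,
# no shuffling); tokenization, and presumably padding to max_sequence_length, happen inside
# the model.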
# reduction='none' keeps the per-token losses so padding positions can be masked out below;
# ignore_index additionally zeroes the loss wherever the label is the padding token.
criterian = nn.CrossEntropyLoss(ignore_index=marathi_to_index[PADDING_TOKEN],
                                reduction='none')
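# Xavier-uniform initialization for every weight matrix (parameters with more than one
# dimension); biases keep their default init. Adam with a fixed learning rate of 1e-4.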
for params in transformer.parameters():
if params.dim() > 1:
nn.init.xavier_uniform_(params)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
NEG_INFTY = -1e9
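# Attention masks are additive: NEG_INFTY marks positions that should be blocked (the model
# is expected to add these masks to the attention scores before the softmax, so blocked
# positions get ~zero weight). The decoder self-attention mask combines a look-ahead
# (upper-triangular) mask with per-sentence padding masks; the encoder and cross-attention
# masks only hide padding positions.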
def create_masks(eng_batch, mr_batch):
num_sentences = len(eng_batch)
look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
for idx in range(num_sentences):
eng_sentence_length, mr_sentence_length = len(eng_batch[idx]), len(mr_batch[idx])
eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
mr_chars_to_padding_mask = np.arange(mr_sentence_length + 1, max_sequence_length)
encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
decoder_padding_mask_self_attention[idx, :, mr_chars_to_padding_mask] = True
decoder_padding_mask_self_attention[idx, mr_chars_to_padding_mask, :] = True
decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
decoder_padding_mask_cross_attention[idx, mr_chars_to_padding_mask, :] = True
encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
decoder_self_attention_mask = torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask
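# Illustrative shape check (commented out so it does not affect training); the sentences
# below are arbitrary examples:
#   enc_m, dec_self_m, dec_cross_m = create_masks(("hi",), ("नमस्कार",))
#   enc_m.shape        # torch.Size([1, 200, 200])
#   dec_self_m[0, 0]   # position 0 may attend only to itself; later columns are NEG_INFTY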
transformer.train()
transformer.to(device)
num_epochs = 100
epoch_losses = []
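# Teacher forcing: the decoder input is the Marathi sentence wrapped in START/END tokens,
# while the labels are the same characters followed by END only, so the output at position t
# is trained to predict the character one step ahead. The per-token losses are then averaged
# over non-padding positions only.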
for epoch in range(num_epochs):
print(f"Epoch {epoch}")
total_loss = 0
count_batches = 0
iterator = iter(train_loader)
for batch_num, batch in enumerate(iterator):
transformer.train()
eng_batch, mr_batch = batch
encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, mr_batch)
optim.zero_grad()
mr_predictions = transformer(eng_batch,
mr_batch,
encoder_self_attention_mask.to(device),
decoder_self_attention_mask.to(device),
decoder_cross_attention_mask.to(device),
enc_start_token=False,
enc_end_token=False,
dec_start_token=True,
dec_end_token=True)
labels = transformer.decoder.sentence_embedding.batch_tokenize(mr_batch, start_token=False, end_token=True)
loss = criterian(
mr_predictions.view(-1, mr_vocab_size).to(device),
labels.view(-1).to(device)
).to(device)
valid_indicies = torch.where(labels.view(-1) == marathi_to_index[PADDING_TOKEN], False, True)
loss = loss.sum() / valid_indicies.sum()
loss.backward()
optim.step()
total_loss += loss.item()
count_batches += 1
#train_losses.append(loss.item())
if batch_num % 100 == 0:
print(f"Iteration {batch_num} : {loss.item()}")
print(f"English: {eng_batch[0]}")
print(f"marathi Translation: {mr_batch[0]}")
mr_sentence_predicted = torch.argmax(mr_predictions[0], axis=1)
predicted_sentence = ""
for idx in mr_sentence_predicted:
if idx == marathi_to_index[END_TOKEN]:
break
predicted_sentence += index_to_marathi[idx.item()]
print(f"marathi Prediction: {predicted_sentence}")
average_loss = total_loss / count_batches
epoch_losses.append(average_loss)
print(f"Average Loss for Epoch {epoch}: {average_loss}")
transformer.eval()
mr_sentence = ("",)
eng_sentence = ("should we go to the mall?",)
for word_counter in range(max_sequence_length):
encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, mr_sentence)
predictions = transformer(eng_sentence,
mr_sentence,
encoder_self_attention_mask.to(device),
decoder_self_attention_mask.to(device),
decoder_cross_attention_mask.to(device),
enc_start_token=False,
enc_end_token=False,
dec_start_token=True,
dec_end_token=False)
next_token_prob_distribution = predictions[0][word_counter] # not actual probs
next_token_index = torch.argmax(next_token_prob_distribution).item()
next_token = index_to_marathi[next_token_index]
mr_sentence = (mr_sentence[0] + next_token, )
if next_token == END_TOKEN:
break
print(f"Evaluation translation (should we go to the mall?) : {mr_sentence}")
print("-------------------------------------------")