VQA / vocab.py
ZubairAhmed777's picture
Create vocab.py
c3e07b2 verified
raw
history blame
1.53 kB
import json
import os
import re
from collections import defaultdict
import glob
import numpy as np
import time
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
class Vocabulary:
    """Bidirectional word <-> index lookup backed by a plain-text vocabulary file.

    The file is expected to hold one token per line; the zero-based line
    position of a token is its index.

    Attributes:
        vocabulary: List of tokens in file order (index -> word).
        vocabulary2idx: Dict mapping each token to its index (word -> index).
            If a token appears more than once in the file, the last
            occurrence's index is the one stored here.
        vocabulary_size: Number of entries in ``vocabulary``.
    """

    def __init__(self, vocabulary_file_path):
        """Load the vocabulary file and build the lookup tables.

        Args:
            vocabulary_file_path: Path to a text file with one token per line.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # Index -> word table, in file order.
        self.vocabulary = self._load_vocabulary(vocabulary_file_path)
        # Word -> index table, derived from the list positions.
        self.vocabulary2idx = {
            word: idx for idx, word in enumerate(self.vocabulary)
        }
        # Cached so callers do not have to call len() on the list.
        self.vocabulary_size = len(self.vocabulary)

    def _load_vocabulary(self, vocabulary_file_path):
        """Read the vocabulary file and return its tokens as a list.

        Each line is stripped of leading/trailing whitespace (including the
        newline). Blank lines are kept as empty-string entries so that a
        token's index always equals its line position in the file.

        Args:
            vocabulary_file_path: Path to a text file with one token per line.

        Returns:
            List of stripped lines, in file order.
        """
        # Explicit UTF-8: without it, decoding depends on the platform/locale
        # default and can mangle or reject non-ASCII tokens.
        with open(vocabulary_file_path, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def word2idx(self, word):
        """Return the index of ``word``, falling back to the ``'<unk>'`` index.

        Args:
            word: Token to look up.

        Returns:
            The index of ``word`` if it is in the vocabulary; otherwise the
            index of ``'<unk>'``. If ``'<unk>'`` is not in the vocabulary
            either, ``None`` is returned.
        """
        unknown_idx = self.vocabulary2idx.get('<unk>')
        return self.vocabulary2idx.get(word, unknown_idx)

    def idx2word(self, idx):
        """Return the token stored at position ``idx``.

        Args:
            idx: Position in the vocabulary list. Standard list indexing
                applies, so negative values count from the end.

        Returns:
            The token at that position.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        return self.vocabulary[idx]