File size: 2,568 Bytes
f6a67e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
__author__ = 'Tony Beltramelli - www.tonybeltramelli.com'
import sys
import numpy as np
# Special tokens registered first by every Vocabulary, so they always occupy
# indices 0, 1 and 2 (see Vocabulary.__init__).
START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "
# Separates a token from its serialized one-hot vector in words.vocab.
SEPARATOR = '->'


class Vocabulary:
    """Bidirectional token <-> index mapping with one-hot vector encodings.

    Attributes:
        vocabulary: dict mapping each token to its integer index.
        token_lookup: dict mapping each index back to its token.
        binary_vocabulary: dict mapping each token to a one-hot numpy float
            vector of length ``size``; filled by create_binary_representation()
            or retrieve().
        size: number of registered tokens.
    """

    def __init__(self):
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register ``token`` under the next free index; duplicates are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """(Re)build ``binary_vocabulary`` as one-hot vectors of length ``size``."""
        # dict.items() exists on both Python 2 and 3, so no version switch is needed.
        for key, value in self.vocabulary.items():
            binary = np.zeros(self.size)
            binary[value] = 1
            self.binary_vocabulary[key] = binary

    def get_serialized_binary_representation(self):
        """Return the text stored in words.vocab.

        Each token produces one record: the token, SEPARATOR, then the entries
        of its one-hot vector as printed by numpy (comma-separated, brackets
        removed), followed by a newline.
        """
        # Rebuild when the cached vectors do not cover every token: never
        # built yet, or tokens were appended after the last build.
        if len(self.binary_vocabulary) != self.size:
            self.create_binary_representation()

        records = []
        for key, value in self.binary_vocabulary.items():
            # Disable both line wrapping and "..." summarization so every
            # vector is written completely on a single line; retrieve()
            # depends on that.
            array_as_string = np.array2string(value, separator=',',
                                              max_line_width=sys.maxsize,
                                              threshold=sys.maxsize)
            records.append("{}{}{}\n".format(key, SEPARATOR, array_as_string[1:-1]))
        return "".join(records)

    def save(self, path):
        """Write the serialized vocabulary to ``<path>/words.vocab``."""
        output_file_name = "{}/words.vocab".format(path)
        # Serialize before opening so a failure cannot leave a truncated file.
        content = self.get_serialized_binary_representation()
        with open(output_file_name, 'w') as output_file:
            output_file.write(content)

    def retrieve(self, path):
        """Load ``<path>/words.vocab`` (as written by save()) into this instance.

        Records read from the file are added to, or overwrite, the existing
        mappings; ``size`` is then recomputed from ``vocabulary``.
        """
        buffer = ""
        with open("{}/words.vocab".format(path), 'r') as input_file:
            for line in input_file:
                # The serialized vector (digits, '.', ',') never contains
                # SEPARATOR, so its last occurrence on the line is the real
                # one even when the token itself contains "->".
                separator_position = line.rfind(SEPARATOR)
                if separator_position == -1:
                    # The token contains a newline: keep accumulating its lines.
                    buffer += line
                    continue

                record = buffer + line
                separator_position += len(buffer)
                buffer = ""

                key = record[:separator_position]
                value = np.fromstring(record[separator_position + len(SEPARATOR):], sep=',')
                index = int(np.where(value == 1)[0][0])

                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key
        self.size = len(self.vocabulary)
|