|
from __future__ import print_function |
|
from __future__ import absolute_import |
|
__author__ = 'Tony Beltramelli - www.tonybeltramelli.com' |
|
|
|
from .Vocabulary import * |
|
from .Utils import * |
|
|
|
|
|
class Sampler:
    """Generate a token sequence from an image, one token at a time.

    The sampler keeps a fixed-length window of the most recent tokens (the
    "context"), feeds it to a model together with the input image, and
    appends the highest-scoring token at every step.
    """

    def __init__(self, voc_path, input_shape, output_size, context_length):
        """Set up the vocabulary and remember the decoding dimensions.

        :param voc_path: passed to ``Vocabulary.retrieve`` to populate the
            vocabulary (presumably a path to a stored vocabulary -- confirm
            against ``Vocabulary``).
        :param input_shape: stored and printed only; not otherwise used by
            this class.
        :param output_size: length of the one-hot vectors built when sparse
            labels are requested in ``predict_greedy``.
        :param context_length: number of tokens in the sliding context
            window handed to the model at each step.
        """
        self.voc = Vocabulary()
        self.voc.retrieve(voc_path)

        self.input_shape = input_shape
        self.output_size = output_size

        print("Vocabulary size: {}".format(self.voc.size))
        print("Input shape: {}".format(self.input_shape))
        print("Output size: {}".format(self.output_size))

        self.context_length = context_length

    def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=False):
        """Decode greedily: at each step keep the highest-scoring token.

        :param model: object exposing ``predict(input_img, context)`` where
            ``context`` is a NumPy array holding a single context window.
        :param input_img: forwarded untouched to ``model.predict``.
        :param require_sparse_label: when true, context entries are one-hot
            vectors of length ``output_size``; otherwise they are plain
            token indices.
        :param sequence_length: maximum number of decoding steps.
        :param verbose: print a progress line before every step.
        :return: ``(predictions, out_probas)`` where ``predictions`` is
            ``START_TOKEN`` followed by every predicted token concatenated
            without separator (``END_TOKEN`` included when it is reached),
            and ``out_probas`` is the list of raw ``model.predict`` outputs,
            one per executed step.
        """
        # Initial window: placeholders everywhere except the last slot,
        # which holds the start token.
        current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
        current_context.append(self.voc.vocabulary[START_TOKEN])
        if require_sparse_label:
            # NOTE(review): assumes Utils.sparsify returns one one-hot row per
            # context entry, in order -- confirm against Utils.
            current_context = Utils.sparsify(current_context, self.output_size)

        # Collect tokens and join once at the end instead of growing a string
        # with += on every step.
        tokens = [START_TOKEN]
        out_probas = []

        for i in range(sequence_length):
            if verbose:
                print("predicting {}/{}...".format(i, sequence_length))

            probas = model.predict(input_img, np.array([current_context]))
            # argmax works on the flattened output, so any singleton batch
            # dimension in probas does not matter.
            prediction = np.argmax(probas)
            out_probas.append(probas)

            # Slide the window: drop the oldest entry, keep entries
            # 1 .. context_length - 1, then append the new prediction.
            new_context = list(current_context[1:self.context_length])
            if require_sparse_label:
                sparse_label = np.zeros(self.output_size)
                sparse_label[prediction] = 1
                new_context.append(sparse_label)
            else:
                new_context.append(prediction)
            current_context = new_context

            token = self.voc.token_lookup[prediction]
            tokens.append(token)

            # The end token is kept in the output, then decoding stops.
            if token == END_TOKEN:
                break

        return "".join(tokens), out_probas
|
|
|
|