|
import tensorflow as tf |
|
import numpy as np |
|
from tensorflow import keras |
|
import os |
|
from typing import Dict, List, Any |
|
import pickle |
|
from PIL import Image |
|
|
|
class PreTrainedPipeline:
    """Caption-generation inference wrapper around a saved Keras model.

    Judging from the attributes used below, the saved model is expected to
    expose ``tokenizer.get_vocabulary()`` and a ``feature_extractor``
    callable, and to be callable on ``(image_features, token_ids)``.
    """

    # NOTE(review): assumed (height, width, channels) input size of the image
    # feature extractor. The original code referenced an undefined IMAGE_SHAPE;
    # (224, 224, 3) is an assumption -- confirm it matches the feature
    # extractor the model was trained with, or pass ``image_shape``.
    IMAGE_SHAPE = (224, 224, 3)

    def __init__(self, path: str, image_shape=None):
        """Load the Keras model from ``<path>/model`` and build token lookups.

        Args:
            path: Directory that contains the saved model under ``model``.
            image_shape: Optional ``(height, width, channels)`` override for
                the size input images are resized to. Defaults to
                ``IMAGE_SHAPE``.
        """
        self.model = keras.models.load_model(os.path.join(path, "model"))
        self.image_shape = (
            tuple(image_shape) if image_shape is not None else self.IMAGE_SHAPE
        )

        # Fetch the vocabulary once and share it, so the forward and inverse
        # lookup tables are built from exactly the same token list.
        vocabulary = self.model.tokenizer.get_vocabulary()
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary)
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary,
            invert=True)

    @staticmethod
    def load_image(img, image_shape=IMAGE_SHAPE):
        """Decode JPEG-encoded bytes into a resized RGB float tensor.

        Args:
            img: Scalar string tensor (or ``bytes``) holding JPEG data.
            image_shape: ``(height, width, channels)``; only height and width
                are used for the resize.

        Returns:
            A ``(height, width, 3)`` tensor produced by ``tf.image.resize``.
        """
        img = tf.io.decode_jpeg(img, channels=3)
        return tf.image.resize(img, image_shape[:-1])

    def _preprocess(self, image):
        """Convert an already-decoded PIL image into a resized RGB tensor."""
        # The PIL image is already decoded, so it must NOT go through
        # decode_jpeg. Force 3 channels so grayscale/RGBA inputs also work.
        img_array = tf.keras.utils.img_to_array(image.convert("RGB"))
        return tf.image.resize(img_array, self.image_shape[:-1])

    def __call__(self, inputs: "Image.Image", temperature: float = 0.0,
                 max_length: int = 50) -> str:
        """Generate a caption for an image.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL. No transformation is
                assumed; all necessary preprocessing happens here.
            temperature: ``0`` picks the highest-scoring token at every step
                (deterministic). A positive value samples the next token
                with ``tf.random.categorical`` from the scores divided by
                ``temperature``.
            max_length: Maximum number of tokens to generate.

        Returns:
            The caption as a single space-separated string, without the
            ``[START]``/``[END]`` marker tokens.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError("temperature must be non-negative")

        image = self._preprocess(inputs)
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        end_id = int(self.word_to_index("[END]"))
        tokens = self.word_to_index([["[START]"]])  # (batch=1, sequence=1)
        ended = False

        for _ in range(max_length):
            # NOTE(review): assumes the model returns (batch, sequence, vocab)
            # scores usable as logits -- confirm against the model definition.
            preds = self.model((img_features, tokens))[:, -1, :]
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]
            else:
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)
            tokens = tf.concat([tokens, next_token], axis=1)

            if int(next_token[0, 0]) == end_id:
                ended = True
                break

        # Always drop [START]; drop the final token only when it really is
        # [END], so a caption that hits max_length keeps its last word.
        token_ids = tokens[0, 1:-1] if ended else tokens[0, 1:]
        words = self.index_to_word(token_ids)
        result = tf.strings.reduce_join(words, axis=-1, separator=" ")
        return result.numpy().decode()
|
|