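"""Gradio Space: image captioning with a frozen EfficientNetB0 feature extractor and a
Transformer encoder/decoder head.

The architecture is rebuilt in code, the trained weights are restored with
load_weights("Checkpoint"), and the caption vocabulary ("vocabulary.npy") and training
captions ("text_data.npy") are loaded from the repository.
"""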
import re

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

IMAGE_SIZE = (299, 299)
VOCAB_SIZE = 8800
SEQ_LENGTH = 25
EMBED_DIM = 512
FF_DIM = 512
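
# Augmentation pipeline used during training. It is wired into the model below so the
# architecture matches training, but generate_caption() bypasses it at inference time.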
image_augmentation = keras.Sequential(
    [
        keras.layers.RandomFlip("horizontal"),
        keras.layers.RandomRotation(0.2),
        keras.layers.RandomContrast(0.3),
    ]
)
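
# Frozen EfficientNetB0 backbone: its feature map is reshaped to a
# (batch, positions, channels) sequence for the Transformer encoder to attend over.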
def get_cnn_model():
    base_model = keras.applications.efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    base_model.trainable = False
    base_model_out = base_model.output
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model
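
# Encoder: a single self-attention block over the image features
# (LayerNorm -> Dense projection -> multi-head self-attention -> residual + LayerNorm).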
class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training):
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            training=training,
        )
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1
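
# Token embedding plus a learned position embedding; mask_zero=True lets padded
# caption positions propagate a mask through the decoder.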
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim, mask_zero=True
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.add = layers.Add()

    def call(self, seq):
        seq = self.token_embeddings(seq)
        x = tf.range(tf.shape(seq)[1])
        x = x[tf.newaxis, :]
        x = self.position_embeddings(x)
        return self.add([seq, x])
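
# Decoder block: causal self-attention over the caption tokens, cross-attention over
# the encoded image, then a feed-forward head that outputs a softmax over the vocabulary.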
class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM,
            sequence_length=SEQ_LENGTH,
            vocab_size=VOCAB_SIZE,
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")
        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            training=training,
            use_causal_mask=True,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)
        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds
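
# End-to-end model: (image, caption) -> CNN features -> encoder -> decoder predictions.
# It is used to rebuild and restore the trained weights; generate_caption() below calls
# the sub-models directly instead of this call().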
class ImageCaptioningModel(keras.Model):
    def __init__(self, cnn_model, encoder, decoder, image_aug=None, **kwargs):
        super().__init__(**kwargs)
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug

    def call(self, inputs, training):
        img, caption = inputs
        if self.image_aug:
            img = self.image_aug(img)
        img_embed = self.cnn_model(img)
        encoder_out = self.encoder(img_embed, training=training)
        pred = self.decoder(caption, encoder_out, training=training)
        return pred
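
# Rebuild the model with the training hyperparameters, compile it, and restore the
# trained weights from the "Checkpoint" checkpoint.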
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)

loaded_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)
loaded_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=3e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
loaded_model.load_weights("Checkpoint")
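
# Artifacts saved at training time: the vocabulary (for mapping predicted token ids back
# to words) and the raw caption text used to re-adapt the TextVectorization layer at startup.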
vocab = np.load("vocabulary.npy")
index_lookup = dict(zip(range(len(vocab)), vocab))
data_txt = np.load("text_data.npy").tolist()
max_decoded_sentence_length = SEQ_LENGTH - 1
strip_chars = "!\"#$%&'()*+,-./:;=?@[\\]^_`{|}~"

def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    # Strip punctuation: the "[...]" brackets make this a character class, so each
    # character in strip_chars is removed individually.
    return tf.strings.regex_replace(lowercase, f"[{re.escape(strip_chars)}]", "")
vectorization = keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(data_txt)
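
# Greedy decoding: start from "startseq", repeatedly run the decoder on the caption so
# far, take the argmax token at the current position, and stop at "endseq" or after
# SEQ_LENGTH - 1 steps.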
def generate_caption(image):
    img = tf.constant(image)
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.expand_dims(img, 0)
    img = loaded_model.cnn_model(img)
    encoded_img = loaded_model.encoder(img, training=False)

    decoded_caption = "startseq "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = loaded_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "endseq":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("startseq ", "")
    decoded_caption = decoded_caption.replace(" endseq", "").strip()
    return decoded_caption
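
# Gradio UI: upload an image, get the generated caption back.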
demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.components.Image(),
    outputs=[gr.components.Textbox(label="Generated Caption", lines=3)],
)

demo.launch(share=True, debug=True)