# meow / app.py — by nahidalam
# Commit 6376db8 ("Update app.py"), 809 bytes
import gradio as gr
import numpy as np
import tensorflow as tf
import PIL
def normalize_img(img):
    """Rescale pixel values to float32 in the range [-1, 1].

    Args:
        img: image tensor or array-like accepted by ``tf.cast``; presumably
            pixel data in [0, 255] (TODO confirm against callers).

    Returns:
        A float32 ``tf.Tensor`` equal to ``img / 127.5 - 1``.
    """
    pixels = tf.cast(img, dtype=tf.float32)
    # 127.5 is half of 255, so [0, 255] -> [0, 2] -> [-1, 1].
    halved = pixels / 127.5
    return halved - 1.0
def predict_and_save(img, generator_model):
    """Run the generator on a batch and return the first output as a PIL image.

    Despite the name, nothing is written to disk; the image is returned.

    Args:
        img: batched input pixels with a leading batch axis; presumably
            values in [0, 255] (TODO confirm), normalized to [-1, 1] here.
        generator_model: Keras model (or compatible callable) invoked with
            ``training=False``; its output is presumably in [-1, 1].

    Returns:
        ``PIL.Image.Image`` built from the first prediction in the batch.
    """
    # ``import PIL`` alone does not guarantee that the ``PIL.Image`` submodule
    # is loaded, so import it explicitly instead of relying on another library
    # having imported it first.
    from PIL import Image

    img = normalize_img(img)
    prediction = generator_model(img, training=False)[0].numpy()
    # Map [-1, 1] back to [0, 255]. Clip before the uint8 cast so that values
    # slightly outside [-1, 1] saturate instead of wrapping around.
    prediction = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
    return Image.fromarray(prediction)
# Directory of the saved Keras generator, loaded lazily by _get_model().
_MODEL_DIR = 'pretrained'
# Loaded models keyed by directory, so each one is read from disk only once.
_model_cache = {}


def _get_model(model_dir=_MODEL_DIR):
    """Return the Keras model saved at ``model_dir``, loading it on first use."""
    model = _model_cache.get(model_dir)
    if model is None:
        model = tf.keras.models.load_model(model_dir)
        _model_cache[model_dir] = model
        print("Model loaded")
    return model


def run(image_path):
    """Gradio handler: translate one input image with the generator model.

    Args:
        image_path: despite its name, the image delivered by the Gradio
            image input component (no batch axis), presumably a numpy array
            of shape (256, 256, 3) — confirm; ``None`` if nothing was sent.

    Returns:
        The generated PIL image, or ``None`` when no image was provided.
    """
    if image_path is None:
        return None
    # Reuse the cached model rather than reloading it from disk per request.
    model = _get_model()
    # The model is called on a batch, so add a leading batch axis of size 1.
    img_array = tf.expand_dims(image_path, 0)
    im = predict_and_save(img_array, model)
    print("Prediction Done")
    return im
# Gradio UI: one input image (shape=(256, 256) asks Gradio to resize it) in,
# one generated image out.
# NOTE(review): ``gr.inputs`` is the legacy Gradio namespace (removed in
# Gradio 4, where ``gr.Image`` is used instead) — keep in sync with the
# pinned Gradio version.
iface = gr.Interface(run, gr.inputs.Image(shape=(256, 256)), "image")

# Launch only when executed as a script (e.g. ``python app.py``), so that
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch(share=True)