Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| from patchify import patchify, unpatchify | |
| import numpy as np | |
| from skimage.io import imshow, imsave | |
| import tensorflow | |
| import tensorflow as tf | |
| from tensorflow.keras import backend as K | |
def jacard(y_true, y_pred):
    """Return the smoothed Jaccard index (soft IoU) of two tensors.

    Both inputs are flattened, so any shape works as long as they have the
    same number of elements. With soft counts A = y_true and B = y_pred:

        (sum(A*B) + 1) / (sum(A) + sum(B) - sum(A*B) + 1)

    The +1 terms keep the ratio defined (equal to 1) when both are empty.
    """
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    union = K.sum(flat_true) + K.sum(flat_pred) - overlap
    return (overlap + 1.0) / (union + 1.0)
def bce_dice(y_true, y_pred):
    """Combined segmentation loss: binary cross-entropy minus log(Jaccard).

    Despite the name, the overlap term is the Jaccard index from ``jacard``
    (IoU), not the Dice coefficient. The ``-log`` term grows as the
    predicted and true masks overlap less.
    """
    cross_entropy = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    overlap_penalty = -K.log(jacard(y_true, y_pred))
    return cross_entropy + overlap_penalty
# size: side length (px) every input image is resized to before tiling.
# pach_size: side length (px) of each square tile passed to the model.
size, pach_size = 1024, 256
def predict_2(image):
    """Segment an image and return a colour-coded class mask for the UI.

    The input is resized to ``size`` x ``size`` and cut into square
    ``pach_size`` tiles with a step of ``pach_size // stride`` pixels. All
    tiles are classified by ``model`` in a single batched ``predict`` call.
    The per-pixel argmax class ids are stitched back together with
    ``unpatchify`` and mapped to RGB through ``targets_classes_colors``.

    Args:
        image: Image array as delivered by ``gr.Image``, or ``None`` when the
            button is clicked before an image has been selected.

    Returns:
        A ``(label_text, mask)`` tuple matching the ``gr.Label`` and
        ``gr.Image`` outputs. ``mask`` is a ``(size, size, 3)`` uint8 array,
        or ``None`` when no image was provided.
    """
    if image is None:
        return 'Please select a source image first', None

    # Force 3 channels: RGBA or greyscale uploads would not match the
    # (pach_size, pach_size, 3) window used by patchify below.
    image = np.array(Image.fromarray(image).convert('RGB').resize((size, size)))

    stride = 2
    # NOTE: step = pach_size // stride, i.e. 128 for stride 2. Neighbouring
    # tiles therefore overlap by half; they do NOT tile without overlap.
    # How unpatchify resolves the overlap is up to patchify (presumably later
    # tiles overwrite earlier ones - confirm). Use step = pach_size for
    # non-overlapping tiles.
    step = pach_size // stride
    patches_img = patchify(image, (pach_size, pach_size, 3), step=step)
    # Drop the singleton channel-window axis -> (rows, cols, pach_size, pach_size, 3).
    patches_img = patches_img[:, :, 0, :, :, :]
    rows, cols = patches_img.shape[:2]

    # One batched predict call instead of one call per tile. Row-major
    # reshape keeps the same (row, col) order as a nested i/j loop.
    batch = patches_img.reshape(-1, pach_size, pach_size, 3) / 255
    pred = model.predict(batch)
    class_ids = np.argmax(pred, axis=-1)  # (rows * cols, pach_size, pach_size)

    patched_prediction = class_ids.reshape(rows, cols, pach_size, pach_size)
    unpatched_prediction = unpatchify(patched_prediction, image.shape[:2])
    # Map class ids to colours. uint8 is the dtype image widgets/PIL expect;
    # an int64 RGB array is rejected or rescaled during conversion.
    mask = targets_classes_colors[unpatched_prediction].astype(np.uint8)
    return 'Predicted Masked Image', mask
# RGB colour for each class id; row k is the colour painted for class k.
# Stored as uint8 so masks built by indexing this table are already in the
# dtype image widgets/PIL expect (all values fit in 0-255).
# NOTE(review): 24 colours here vs 23 classes in class_weights — confirm
# which count matches the model's output channels.
targets_classes_colors = np.array([[  0,   0,   0],
                                   [128,  64, 128],
                                   [130,  76,   0],
                                   [  0, 102,   0],
                                   [112, 103,  87],
                                   [ 28,  42, 168],
                                   [ 48,  41,  30],
                                   [  0,  50,  89],
                                   [107, 142,  35],
                                   [ 70,  70,  70],
                                   [102, 102, 156],
                                   [254, 228,  12],
                                   [254, 148,  12],
                                   [190, 153, 153],
                                   [153, 153, 153],
                                   [255,  22,  96],
                                   [102,  51,   0],
                                   [  9, 143, 150],
                                   [119,  11,  32],
                                   [ 51,  51,   0],
                                   [190, 250, 190],
                                   [112, 150, 146],
                                   [  2, 135, 115],
                                   [255,   0,   0]], dtype=np.uint8)
# Per-class loss weights, keyed by class id 0..22 (position in the list
# below = class id). Within this file they are only read to build
# ``weight_list``, which weighted_categorical_crossentropy consumes.
class_weights = dict(enumerate([
    1.0,                  # 0
    1.0,                  # 1
    2.171655596616696,    # 2
    1.0,                  # 3
    1.0,                  # 4
    2.2101197049812593,   # 5
    11.601519937899578,   # 6
    7.99072122367673,     # 7
    1.0,                  # 8
    1.0,                  # 9
    2.5426918173402457,   # 10
    11.187574445057574,   # 11
    241.57620214903147,   # 12
    9.234779790464515,    # 13
    1077.2745952165694,   # 14
    7.396021659003857,    # 15
    855.6730643687165,    # 16
    6.410869993189135,    # 17
    42.0186736125025,     # 18
    2.5648760196752947,   # 19
    4.089194047656931,    # 20
    27.984593442818955,   # 21
    2.0509251319694712,   # 22
]))
# Weights ordered by class id, so index k is the weight of class k.
weight_list = [class_weights[class_id] for class_id in sorted(class_weights)]
def weighted_categorical_crossentropy(weights):
    """Build a class-weighted version of the ``bce_dice`` loss.

    Args:
        weights: Sequence of per-class weights (e.g. ``weight_list``). Entry
            k is the weight of class k; its length must equal the size of the
            last axis of ``y_true`` / ``y_pred``.

    Returns:
        A Keras-compatible loss ``wcce(y_true, y_pred)``. It returns
        ``bce_dice(y_true, y_pred)`` scaled by ``sum(y_true * weights,
        axis=-1)``, i.e. the weight of the true class at each position
        when ``y_true`` is one-hot.
    """
    # Previously the argument was overwritten with the global weight_list,
    # silently ignoring the caller's weights; it is now honoured.
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        # Accept plain arrays as well as tensors.
        if not tf.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        # Match dtypes so the elementwise products below are well-defined.
        y_true = K.cast(y_true, y_pred.dtype)
        return bce_dice(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce
# Load the trained segmentation model once at start-up.
# custom_objects resolves names stored with the model (presumably the "wcce"
# loss and "jacard" metric it was compiled with — confirm). "wcce" must map to
# the loss *function* built by the factory, not to the factory itself.
# compile=False: this app only calls model.predict, which needs no restored
# optimizer/loss/metrics, so loading does not depend on deserialising them.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        "jacard": jacard,
        "wcce": weighted_categorical_crossentropy(weight_list),
    },
    compile=False,
)
# Create a user interface for the model: one tab with an image picker and a
# button on the left, and a status label plus the predicted mask on the right.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
    # predict_2 returns (label text, mask image), in the same order as outputs.
    source_image_loader.click(
        predict_2,
        inputs=[img_source],
        outputs=[output_label, img_output],
    )

# debug=True keeps launch() blocking, so close() only runs after the server stops.
my_app.launch(debug=True, share=True)
my_app.close()