"""Gradio demo that overlays a predicted segmentation mask on a pet image."""

import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Path to the pre-trained segmentation model
model_path = "saved_model"

# Load the pre-trained segmentation model
segmentation_model = tf.keras.models.load_model(model_path)

# Target image shape
TARGET_SHAPE = (256, 256)


def _ensure_rgb(img: np.ndarray) -> np.ndarray:
    """Return *img* with three channels, expanding 2-D grayscale input."""
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return img


# Define image segmentation function
def segment_image(img: np.ndarray) -> np.ndarray:
    """Predict a per-pixel class-label mask for *img*.

    The image is resized to ``TARGET_SHAPE``, run through
    ``segmentation_model``, and the arg-max over the last axis of the
    prediction is taken as the label of each pixel.

    Args:
        img: Input image, either 2-D grayscale or with a trailing channel axis.

    Returns:
        A ``uint8`` array of class ids with the same height and width as the
        input image.
    """
    # Remember the input size so the mask can be scaled back to it.
    original_height, original_width = img.shape[:2]

    # Check if the image is RGB and convert if not
    img = _ensure_rgb(img)

    # Resize the image to TARGET_SHAPE
    img = cv2.resize(img, TARGET_SHAPE)

    # Add a batch dimension
    # NOTE(review): pixel values are passed to the model unscaled — confirm
    # the saved model performs its own normalisation.
    batch = np.expand_dims(img, axis=0)

    # Predict the segmentation mask
    mask = segmentation_model.predict(batch)

    # Remove the batch dimension
    mask = np.squeeze(mask, axis=0)

    # Convert to labels
    mask = np.argmax(mask, axis=-1)

    # Convert to uint8
    mask = mask.astype(np.uint8)

    # Resize to original image shape. Nearest-neighbour interpolation is
    # required for label data: the default bilinear mode would blend
    # neighbouring class ids into values that belong to a different class.
    mask = cv2.resize(
        mask,
        (original_width, original_height),
        interpolation=cv2.INTER_NEAREST,
    )

    return mask


def overlay_mask(img, mask, alpha=0.5):
    """Blend a colour-coded rendering of *mask* over *img*.

    Args:
        img: Image to draw on; 2-D grayscale input is expanded to three
            channels so the colours can be applied.
        mask: Array of class ids with the same height and width as *img*.
        alpha: Weight of the colour overlay in the blend; the image itself is
            weighted by ``1 - alpha``.

    Returns:
        The blended image as produced by ``cv2.addWeighted``.
    """
    # A grayscale image cannot hold the three-component colours below.
    img = _ensure_rgb(img)

    # Define color mapping
    colors = {
        0: [255, 0, 0],  # Class 0 - Red
        1: [0, 255, 0],  # Class 1 - Green
        2: [0, 0, 255],  # Class 2 - Blue
        # Add more colors for additional classes if needed
    }

    # Create a blank colored overlay image
    overlay = np.zeros_like(img)

    # Map each mask value to the corresponding color; class ids without an
    # entry in ``colors`` stay black in the overlay.
    for class_id, color in colors.items():
        overlay[mask == class_id] = color

    # Blend the overlay with the original image
    output = cv2.addWeighted(img, 1 - alpha, overlay, alpha, 0)

    return output


# The main function
def transform(img):
    """Segment *img* and return it with the segmentation overlay applied."""
    mask = segment_image(img)
    blended_img = overlay_mask(img, mask)
    return blended_img


# Create the Gradio app
app = gr.Interface(
    fn=transform,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Image with Segmentation Overlay"),
    title="Image Segmentation on Pet Images",
    description="Segment image of a pet animal into three classes: background, pet, and boundary.",
    examples=[
        "example_images/img1.jpg",
        "example_images/img2.jpg",
        "example_images/img3.jpg",
    ],
)

# Run the app
app.launch()