File size: 2,407 Bytes
49bb575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt


# Path to the pre-trained segmentation model (passed to Keras' load_model below)
model_path = "saved_model"

# Load the pre-trained segmentation model once, at import time, so every
# request handled by segment_image() reuses the same model object
segmentation_model = tf.keras.models.load_model(model_path)

# Size every input image is resized to (via cv2.resize) before prediction
TARGET_SHAPE = (256, 256)

# Define image segmentation function
def segment_image(img: np.ndarray) -> np.ndarray:
    """Predict a per-pixel class-label mask for ``img``.

    The image is resized to ``TARGET_SHAPE``, passed through
    ``segmentation_model`` as a batch of one, reduced to class labels with an
    arg-max over the last axis, and resized back to the input's height and
    width.

    Args:
        img: Image array, either 2-D (grayscale) or with a trailing channel
            axis. A 2-D input is expanded to three channels first.

    Returns:
        A ``uint8`` label array with the same height and width as ``img``.
    """
    # Remember the input size so the mask can be mapped back onto it
    original_height, original_width = img.shape[:2]

    # Check if the image is RGB and convert if not
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    # Resize the image to TARGET_SHAPE and add a batch dimension
    # NOTE(review): pixel values are passed to the model unscaled — confirm
    # the saved model does its own normalisation.
    batch = np.expand_dims(cv2.resize(img, TARGET_SHAPE), axis=0)

    # Predict the segmentation mask and remove the batch dimension
    # (assumes the model returns one score per class on the last axis)
    scores = np.squeeze(segmentation_model.predict(batch), axis=0)

    # Convert per-class scores to uint8 labels
    mask = np.argmax(scores, axis=-1).astype(np.uint8)

    # Resize to the original image shape. Nearest-neighbour is required for
    # label maps: the default bilinear interpolation averages neighbouring
    # class ids and would invent labels along class boundaries (e.g. a 0/2
    # edge would produce 1).
    return cv2.resize(
        mask,
        (original_width, original_height),
        interpolation=cv2.INTER_NEAREST,
    )

def overlay_mask(img, mask, alpha=0.5):
    """Blend a colour-coded segmentation mask over ``img``.

    Args:
        img: Image array to draw on; either 2-D (grayscale) or with three
            colour channels. A 2-D image is expanded to three channels so
            the colours can be applied.
        mask: Integer label array with the same height and width as ``img``.
        alpha: Weight of the coloured overlay in the blend; the image itself
            is weighted ``1 - alpha``.

    Returns:
        The blended three-channel image produced by ``cv2.addWeighted``.
    """
    # Define color mapping (labels without an entry stay black in the overlay)
    colors = {
        0: [255, 0, 0],   # Class 0 - Red
        1: [0, 255, 0],   # Class 1 - Green
        2: [0, 0, 255]    # Class 2 - Blue
        # Add more colors for additional classes if needed
    }

    # A grayscale image has no channel axis to receive a 3-value colour, so
    # the per-class assignment below would fail; expand it the same way
    # segment_image() does.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    # Create a blank colored overlay image
    overlay = np.zeros_like(img)

    # Map each mask value to the corresponding color
    for class_id, color in colors.items():
        overlay[mask == class_id] = color

    # Blend the overlay with the original image
    return cv2.addWeighted(img, 1 - alpha, overlay, alpha, 0)


# The main function
def transform(img):
    """Segment ``img`` and return it with the class mask blended on top.

    Args:
        img: Input image array as delivered by the Gradio image input, or
            ``None`` when no image was supplied.

    Returns:
        The image with the segmentation overlay, or ``None`` if ``img`` is
        ``None``.
    """
    # Gradio calls the function with None when the user submits without an
    # image; return None instead of failing on ``img.shape``.
    if img is None:
        return None

    mask = segment_image(img)
    return overlay_mask(img, mask)


# Input and output widgets shown in the demo
input_image = gr.Image(label="Input Image")
output_image = gr.Image(label="Image with Segmentation Overlay")

# Sample images offered to the user in the UI
example_paths = [
    "example_images/img1.jpg",
    "example_images/img2.jpg",
    "example_images/img3.jpg",
]

# Assemble the Gradio interface around transform()
app = gr.Interface(
    fn=transform,
    inputs=input_image,
    outputs=output_image,
    examples=example_paths,
    title="Image Segmentation on Pet Images",
    description="Segment image of a pet animal into three classes: background, pet, and boundary.",
)

# Start serving the app
app.launch()