File size: 4,891 Bytes
37e9e4f
 
 
 
 
 
 
 
 
 
 
fdb53d2
37e9e4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb53d2
 
 
 
 
 
 
37e9e4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import logging
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, UnidentifiedImageError
from torchvision import transforms, models

from kan_linear import KANLinear

# Configure the root logger at INFO so the logging.info/error calls below are
# emitted (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")

# Definisikan model ResNet Anda
class CustomResNetKAN(nn.Module):
    def __init__(self, num_classes=1):  # Set num_classes to 1 for binary classification
        super(CustomResNetKAN, self).__init__()
        self.model = models.resnet50(pretrained=False)
        self.model.fc = KANLinear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

def load_model(weights_path, device):
    """Build a ``CustomResNetKAN`` on ``device``, load weights, and set eval mode.

    Any leading ``'module.'`` prefix is stripped from checkpoint keys
    (presumably left by saving a DataParallel-wrapped model — confirm).

    Args:
        weights_path: Path to a checkpoint holding a plain state dict.
        device: ``torch.device`` used both for the model and ``map_location``.

    Returns:
        The model in eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (missing file,
        key or shape mismatch under the default strict loading).
    """
    net = CustomResNetKAN().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    raw_weights = torch.load(weights_path, map_location=device)

    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in raw_weights.items()
    }

    net.load_state_dict(cleaned)
    net.eval()
    return net

class CustomImageLoadingError(Exception):
    """Raised by ``load_image_from_url`` when an image cannot be fetched,
    validated, or decoded."""

def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    The URL *path* must end in an accepted image extension, and the server
    must answer with an ``image/*`` Content-Type.

    Args:
        url: HTTP(S) URL of the image. Query strings and fragments are ignored
            when checking the file extension.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled server would block the app forever.

    Returns:
        PIL.Image.Image converted to RGB.

    Raises:
        CustomImageLoadingError: for every failure (bad extension, non-image
            content type, HTTP/network error, undecodable data). The underlying
            exception, if any, is chained as ``__cause__``.
    """
    # NOTE(review): this fetches an arbitrary user-supplied URL from the server
    # and reads the whole body into memory (SSRF / large-download risk if the
    # app is deployed where internal hosts are reachable).
    try:
        logging.info(f"Loading image from URL: {url}")

        # Check the extension of the URL path only, so "photo.jpg?w=800" passes.
        valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
        url_path = urlparse(url).path
        file_extension = os.path.splitext(url_path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")

        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn 4xx/5xx responses into HTTPError

        # A missing header is reported as "not an image" rather than a KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")

        # MIME types are case-insensitive and of the form "type/subtype".
        if not content_type.lower().startswith('image/'):
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")

        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a precise message; do not
        # re-wrap them as "unexpected" errors below.
        raise
    except requests.HTTPError as e:
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        # Connection errors, timeouts, invalid/missing URL schemes, ...
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e

def preprocess_image(image):
    """Turn a PIL image into a ``(1, C, 224, 224)`` float tensor for the model.

    The image is resized to 224x224 and converted with ``ToTensor``; a batch
    dimension is then prepended.
    NOTE(review): no mean/std normalisation is applied — confirm this matches
    the preprocessing used when the weights were trained.
    """
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    tensor = pipeline(image)
    return tensor.unsqueeze(dim=0)

# Streamlit app
# Streamlit re-executes this whole script on every widget interaction, so
# anything expensive here must be cached explicitly.
st.title("Cat and Dog Classification with ResNet-KAN")

st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

WEIGHTS_PATH = 'weights/best_model_resnet50_KAN.pth'


@st.cache_resource
def _get_model(weights_path, device_type):
    """Load the classifier once and reuse it across reruns.

    Without caching, the ResNet-50 checkpoint would be re-read from disk on
    every interaction. The device is passed as a plain string so Streamlit can
    hash the arguments.
    """
    return load_model(weights_path, torch.device(device_type))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_model(WEIGHTS_PATH, device.type)

img = None

if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except UnidentifiedImageError as e:
        # A corrupt or non-image upload would otherwise crash the whole page.
        st.sidebar.error(f"Cannot identify image file: {e}")
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")

# Author information shown in the sidebar footer
name = "Wayan Dadang"

st.sidebar.write("Follow me on:")
# Footer with profile links and copyright notice
st.sidebar.markdown(f"""
    [LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
    [GitHub](https://github.com/Wayan123)
    [Resume](https://wayan123.github.io/)
    © {name} - {2024}
    """, unsafe_allow_html=True)

if img is not None:
    # NOTE(review): use_column_width is deprecated in newer Streamlit releases
    # in favour of use_container_width; kept for compatibility.
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)

        with torch.no_grad():
            output = model(img_tensor)
            # Single logit -> probability of the positive class.
            prob = torch.sigmoid(output).item()

        st.write(f"Prediction: {prob:.4f}")

        # NOTE(review): assumes the positive class (label 1) was "dog" during
        # training — confirm against the training label mapping.
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")