File size: 4,891 Bytes
37e9e4f fdb53d2 37e9e4f fdb53d2 37e9e4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear
import logging
import os
# Configure root logging at INFO so the logging.info/logging.error calls in this
# app are emitted (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level="INFO")
# Define the ResNet-based classification model
class CustomResNetKAN(nn.Module):
    """ResNet-50 backbone whose final fully connected layer is a KANLinear head.

    Args:
        num_classes: Number of outputs of the KANLinear head. Defaults to 1,
            a single output intended for binary classification.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # weights=None builds the architecture with random initialization; the
        # trained parameters are loaded separately from a checkpoint. This is the
        # replacement for the `pretrained=False` flag, deprecated in torchvision 0.13.
        self.model = models.resnet50(weights=None)
        # Replace the stock classifier with a KAN layer of the same input width.
        self.model.fc = KANLinear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Run the wrapped ResNet on `x` and return the head's output.

        No activation is applied here. Assumes `x` is an (N, 3, H, W) image
        batch — TODO confirm against callers.
        """
        return self.model(x)
def load_model(weights_path, device):
    """Build a CustomResNetKAN on `device` and load trained weights into it.

    Any key starting with ``'module.'`` has that prefix stripped before loading.
    That prefix typically appears when a model was saved while wrapped in
    ``nn.DataParallel``.

    Args:
        weights_path: Path to a checkpoint file containing a state dict.
        device: torch.device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The loaded model, switched to eval mode.
    """
    prefix = 'module.'
    net = CustomResNetKAN().to(device)
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    net.load_state_dict(cleaned)
    net.eval()
    return net
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be fetched, validated, or decoded.

    Gives the UI a single exception type to catch and report for image-loading failures.
    """
def load_image_from_url(url, timeout=10):
    """Download an image from `url` and return it as an RGB PIL image.

    Before decoding, the URL path must end in an accepted extension and the
    server must respond with a Content-Type containing ``image``.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server, passed to ``requests.get`` so
            an unresponsive host cannot hang the app indefinitely.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        CustomImageLoadingError: on an unsupported extension, a non-image
            Content-Type, any network/HTTP failure, or undecodable image data.
    """
    from urllib.parse import urlparse

    valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
    try:
        logging.info(f"Loading image from URL: {url}")
        # Check the extension of the URL *path* only, so query strings and
        # fragments (e.g. "photo.jpg?w=200") do not become part of the extension.
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn 4xx/5xx responses into HTTPError
        # A missing header must not raise KeyError; treat it as "not an image".
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")
        if 'image' not in content_type.lower():
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a precise message; re-raise them
        # unchanged instead of letting the generic handler below re-wrap them.
        raise
    except requests.HTTPError as e:
        # Must precede RequestException, of which HTTPError is a subclass.
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
def preprocess_image(image):
    """Convert a PIL image into a single-image batch tensor for the model.

    The image is resized to 224x224 and converted with ``transforms.ToTensor``.

    NOTE(review): no mean/std normalization is applied here; confirm this
    matches the preprocessing used when the weights were trained.

    Returns:
        A tensor of shape (1, C, 224, 224).
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    tensor = transforms.Compose(steps)(image)
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
# Streamlit app.
# Streamlit re-executes this whole script on every widget interaction (including
# clicking "Predict"), so the model is loaded through a cached helper instead of
# being rebuilt and re-read from disk on each rerun.


@st.cache_resource
def _get_cached_model(weights_path, device_type):
    """Load the classifier once per (weights_path, device_type) and reuse it across reruns."""
    return load_model(weights_path, torch.device(device_type))


st.title("Cat and Dog Classification with ResNet-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass the device as a plain string ("cuda"/"cpu") so it is a simple cache key.
model = _get_cached_model('weights/best_model_resnet50_KAN.pth', device.type)
img = None
if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        # UnidentifiedImageError (not an image) and truncated-file errors are both
        # OSError subclasses; report them in the UI instead of crashing the page.
        logging.error(f"Could not read uploaded image: {e}")
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")
st.sidebar.write("-----")
# Information shown in the sidebar footer.
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")
# Footer section with profile links and copyright notice.
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)
if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)
        with torch.no_grad():
            output = model(img_tensor)
            # Single output -> sigmoid probability; < 0.5 is Cat, otherwise Dog.
            prob = torch.sigmoid(output).item()
        st.write(f"Prediction: {prob:.4f}")
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")
|