|
import matplotlib.pyplot as plt |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
import torch.nn.functional as F |
|
from typing import Literal, Any |
|
import gradio as gr |
|
import spaces |
|
from io import BytesIO |
|
|
|
|
|
class Classifier:
    """Landscape-type image classifier backed by a serialized PyTorch model.

    The model object is loaded from ``model_path`` with ``torch.load`` and
    used to score an image against the entries of ``LABELS``; the scores are
    rendered as a bar chart image.
    """

    # Class names, paired by position with the leading model outputs that
    # ``predict`` turns into probabilities.
    LABELS = [
        "Panoramic",
        "Feature",
        "Detail",
        "Enclosed",
        "Focal",
        "Ephemeral",
        "Canopied",
    ]

    # NOTE(review): ``spaces.GPU`` is kept on ``__init__`` as in the original;
    # confirm that attributes assigned here are still visible to ``predict``
    # under the deployment's GPU scheduling.
    @spaces.GPU(duration=60)
    def __init__(
        self, model_path="Litton-7type-visual-landscape-model.pth", device=None
    ):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path of the serialized model file.
            device: Device to run on. When ``None``, ``cuda:0`` is used if
                CUDA is available and the CPU otherwise.
        """
        if device is None:
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu"
            )
        # The original re-assigned ``self.device = device`` after choosing a
        # default, which reset the device to ``None``; assign exactly once.
        self.device = device

        # SECURITY: ``weights_only=False`` unpickles arbitrary objects, so
        # ``model_path`` must only ever point at a trusted file.
        self.model = torch.load(
            model_path, map_location=self.device, weights_only=False
        )
        # Unwrap containers that expose the real network as ``.module``
        # (e.g. a model saved while wrapped for multi-GPU training).
        if hasattr(self.model, "module"):
            self.model = self.model.module
        self.model.eval()

        # Resize, centre-crop to 224x224, convert to a tensor and normalize
        # per channel with the fixed mean/std below.
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @spaces.GPU(duration=60)
    def predict(self, image: Image.Image) -> Image.Image | None:
        """Classify ``image`` and return the probabilities as a chart image.

        Args:
            image: Input picture; it is converted to RGB before
                preprocessing. ``None`` (nothing uploaded) is tolerated.

        Returns:
            A PIL image of the bar chart with one bar per entry of
            ``LABELS`` (values in percent), or ``None`` when no input image
            was supplied.
        """
        if image is None:
            # The button can be clicked before anything is uploaded.
            return None

        rgb_image = image.convert("RGB")
        # Add a leading batch dimension of size one.
        batch = self.preprocess(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            # Only the outputs that have a label are turned into probabilities.
            probs = F.softmax(logits[:, : len(self.LABELS)], dim=1).cpu()

        # ``draw_bar_chart`` returns the pyplot module with the new chart as
        # its current figure.
        chart = draw_bar_chart(
            {
                "class": self.LABELS,
                "probs": (probs[0] * 100).tolist(),
            }
        )

        # An image output component cannot display the pyplot module, so the
        # figure is rendered to an in-memory PNG and handed back as a PIL
        # image. Closing the figure keeps figures from piling up per request.
        buffer = BytesIO()
        try:
            chart.savefig(buffer, format="png")
        finally:
            chart.close()
        buffer.seek(0)
        rendered = Image.open(buffer)
        rendered.load()  # Decode now so the image no longer needs the buffer.
        return rendered
|
|
|
|
|
def draw_bar_chart(data: dict[str, list[str | float]]):
    """Draw a labelled bar chart of class probabilities on a new figure.

    Args:
        data: Mapping with a ``"class"`` entry (the bar labels) and a
            ``"probs"`` entry (the bar heights, shown as percentages),
            aligned by position.

    Returns:
        The ``matplotlib.pyplot`` module; the chart just drawn is its
        current figure.
    """
    labels, heights = data["class"], data["probs"]

    plt.figure(figsize=(8, 6))
    axes = plt.gca()
    axes.bar(labels, heights, color="skyblue")

    axes.set_xlabel("Class")
    axes.set_ylabel("Probability (%)")
    axes.set_title("Class Probabilities")

    # Print each value just above its bar.
    for position, height in enumerate(heights):
        axes.text(
            position,
            height + 0.01,
            f"{height:.2f}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()

    return plt
|
|
|
|
|
def get_layout():
    """Build the Gradio interface for the landscape classifier.

    The page has a title with a citation link, an image upload next to the
    result image, a button that runs ``Classifier.predict`` on the uploaded
    image, and a footer.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    css = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    """
    theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Source Sans Pro"),
    ).set(
        background_fill_primary="*neutral_950",
        button_primary_background_fill="*primary_500",
        body_text_color="*neutral_200",
    )

    # The stylesheet and theme were previously built but never handed to
    # ``gr.Blocks``, so the classes used in the HTML below had no effect.
    with gr.Blocks(theme=theme, css=css) as demo:
        with gr.Column():
            gr.HTML(
                value=(
                    '<div class="main-title">Litton7景觀分類模型</div>'
                    '<div class="reference">引用資料:'
                    '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
                    "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
                    "</a>"
                    "</div>"
                ),
            )

            with gr.Row(equal_height=True):
                image_input = gr.Image(label="上傳影像", type="pil")
                chart = gr.Image(label="分類結果")

            start_button = gr.Button("開始分類", variant="primary")
            gr.HTML(
                '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
            )

        # One classifier instance is created when the layout is built and
        # reused for every click.
        classifier = Classifier()
        start_button.click(
            fn=classifier.predict,
            inputs=image_input,
            outputs=chart,
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the interface and serve it only when run as a script.
    app = get_layout()
    app.launch()
|
|