Litton7 / app.py
lichih's picture
Update app.py
e23356b verified
raw
history blame
4.68 kB
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from typing import Literal, Any
import gradio as gr
import spaces
from io import BytesIO
class Classifier:
    """Image classifier wrapping a serialized landscape-type model.

    The model object is loaded from disk once at construction time.
    ``predict`` preprocesses a PIL image, runs the model, and hands the
    per-class percentages to ``draw_bar_chart``.
    """

    # Names used to label the first ``len(LABELS)`` model outputs.
    # NOTE(review): assumed to be in the model's output order — confirm
    # against the training code.
    LABELS = [
        "Panoramic",
        "Feature",
        "Detail",
        "Enclosed",
        "Focal",
        "Ephemeral",
        "Canopied",
    ]

    @spaces.GPU(duration=60)
    def __init__(
        self, model_path="Litton-7type-visual-landscape-model.pth", device=None
    ):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path of the serialized model passed to ``torch.load``.
            device: Device (string or ``torch.device``) to run on. When
                ``None``, ``cuda:0`` is used if CUDA is available, else CPU.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Assign exactly once: the original code re-assigned the raw
        # ``device`` argument afterwards, which reset the attribute to None
        # whenever the default was used.
        self.device = torch.device(device)

        # NOTE: weights_only=False unpickles arbitrary Python objects, so
        # only trusted checkpoint files must ever be loaded here.
        self.model = torch.load(
            model_path, map_location=self.device, weights_only=False
        )
        # Unwrap objects that expose the real network as ``.module``
        # (e.g. a checkpoint saved from a parallel wrapper).
        if hasattr(self.model, "module"):
            self.model = self.model.module
        self.model.eval()

        # Resize, centre-crop to 224x224, convert to a tensor and normalise
        # each channel; the mean/std values match the ImageNet statistics
        # commonly used with torchvision models.
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @spaces.GPU(duration=60)
    def predict(self, image: Image.Image) -> Any:
        """Classify ``image`` and return the probability chart.

        Args:
            image: Input picture; it is converted to RGB before preprocessing.

        Returns:
            Whatever ``draw_bar_chart`` returns for the class labels and
            their probabilities expressed as percentages.
        """
        rgb_image = image.convert("RGB")
        # Add a leading batch dimension and move the input to the model device.
        batch = self.preprocess(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            # Only as many outputs as there are labels are turned into
            # probabilities, keeping the slice in sync with LABELS.
            probs = F.softmax(logits[:, : len(self.LABELS)], dim=1).cpu()

        return draw_bar_chart(
            {
                "class": self.LABELS,
                "probs": probs[0] * 100,
            }
        )
def draw_bar_chart(data: dict[str, list[str | float]]) -> "Image.Image":
    """Render class probabilities as a bar chart image.

    The chart is drawn on its own figure, rendered to PNG in memory and
    returned as a PIL image. Returning an image (instead of the global
    ``pyplot`` module) gives image output components a value they can
    display, and lets the figure be closed so figures do not accumulate
    across calls.

    Args:
        data: Mapping with two parallel sequences: ``"class"`` holds the bar
            labels and ``"probs"`` holds the matching values in percent.

    Returns:
        A fully loaded ``PIL.Image.Image`` containing the rendered chart.
    """
    classes = data["class"]
    # Plain floats keep plotting and label formatting independent of the
    # container type the caller used (list, array, tensor, ...).
    probabilities = [float(p) for p in data["probs"]]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.bar(classes, probabilities, color="skyblue")
        ax.set_xlabel("Class")
        ax.set_ylabel("Probability (%)")
        ax.set_title("Class Probabilities")
        # Print each value just above its bar.
        for i, prob in enumerate(probabilities):
            ax.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
        fig.tight_layout()

        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy; force the decode while the buffer is still alive.
    image.load()
    return image
def get_layout():
    """Build and return the Gradio Blocks interface for the classifier.

    The page consists of a title/reference header, an input image next to
    the result image, a button that triggers classification, and a footer.
    """
    # NOTE: the stylesheet and theme below are constructed but are not
    # passed to gr.Blocks() at the moment; the styled variant would be
    # gr.Blocks(css=stylesheet, theme=dark_theme).
    stylesheet = """
.main-title {
font-size: 24px;
font-weight: bold;
text-align: center;
margin-bottom: 20px;
}
.reference {
text-align: center;
font-size: 1.2em;
color: #d1d5db;
margin-bottom: 20px;
}
.reference a {
color: #FB923C;
text-decoration: none;
}
.reference a:hover {
text-decoration: underline;
color: #FB923C;
}
.title {
border-bottom: 1px solid;
}
.footer {
text-align: center;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid #ddd;
color: #d1d5db;
font-size: 14px;
}
"""
    base_theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Source Sans Pro"),
    )
    dark_theme = base_theme.set(
        background_fill_primary="*neutral_950",  # main background (near black)
        button_primary_background_fill="*primary_500",  # button colour (orange)
        body_text_color="*neutral_200",  # text colour (light)
    )

    # Title plus a link to the paper describing the model.
    header_html = "".join(
        [
            '<div class="main-title">Litton7景觀分類模型</div>',
            '<div class="reference">引用資料:',
            '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">',
            "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)",
            "</a>",
            "</div>",
        ]
    )

    with gr.Blocks() as demo:
        with gr.Column():
            gr.HTML(value=header_html)

            # Uploaded picture on the left, classification chart on the right.
            with gr.Row(equal_height=True):
                upload = gr.Image(label="上傳影像", type="pil")
                result_view = gr.Image(label="分類結果")

            run_button = gr.Button("開始分類", variant="primary")
            gr.HTML(
                '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
            )

        # One classifier instance is created here and shared by every click.
        classifier = Classifier()
        run_button.click(
            fn=classifier.predict,
            inputs=upload,
            outputs=result_view,
        )

    return demo
if __name__ == "__main__":
    # Build the interface first, then start serving it.
    app = get_layout()
    app.launch()