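# Gradio demo for the Litton7 model, which classifies an input photo into one
# of seven Litton visual-landscape types.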
from PIL import Image
from matplotlib.figure import Figure
from torchvision import transforms
from tqdm import tqdm
from urllib.request import urlopen
import gradio as gr
import matplotlib.pyplot as plt
import os
import spaces
import sys
import torch
import torch.nn.functional as F
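# The seven Litton visual-landscape classes predicted by the model.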
LABELS = [
"Panoramic",
"Feature",
"Detail",
"Enclosed",
"Focal",
"Ephemeral",
"Canopied",
]
MODELFILE = "Litton-7type-visual-landscape-model.pth"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
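# Fetch the pretrained weights from the LCLab model zoo on first run.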
if not os.path.exists(MODELFILE):
model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
print(f"fetch model from {model_url}...", file=sys.stderr)
with urlopen(model_url) as resp:
        progress = tqdm(total=int(resp.headers["Content-Length"]), desc="Downloading")
with open(MODELFILE, "wb") as modelfile:
while True:
chunk = resp.read(1024)
if len(chunk) == 0:
break
modelfile.write(chunk)
progress.update(len(chunk))
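# Load the checkpoint and unwrap the inner module (the saved object appears to
# be wrapped, e.g. by DataParallel, hence the .module attribute).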
model = torch.load(
MODELFILE, map_location=device, weights_only=False
).module
model.eval()
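# Standard ImageNet-style preprocessing: resize, 224x224 center crop, and
# per-channel normalization.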
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
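# Classify a single image and return a bar chart of the class probabilities.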
@spaces.GPU
def predict(image: Image.Image) -> Figure:
image = image.convert("RGB")
input_tensor = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(input_tensor)
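        # Keep only the first seven logits (the Litton classes) before softmax.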
probs = F.softmax(logits[:, :7], dim=1).cpu()
return draw_bar_chart(
{
"class": LABELS,
"probs": probs[0] * 100,
}
)
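# Plot the per-class probabilities (in percent) as a labelled bar chart.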
def draw_bar_chart(data: dict) -> Figure:
classes = data["class"]
probabilities = data["probs"]
fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(classes, probabilities, color="skyblue")
ax.set_xlabel("Class")
ax.set_ylabel("Probability (%)")
ax.set_title("Class Probability")
for i, prob in enumerate(probabilities):
ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")
fig.tight_layout()
return fig
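# When an example thumbnail is clicked, load it, scale its longer side down to
# 512 px, and place it in the input image widget.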
def choose_example(imgpath: str) -> gr.Image:
img = Image.open(imgpath)
width, height = img.size
ratio = 512 / max(width, height)
img = img.resize((int(width * ratio), int(height * ratio)))
return gr.Image(value=img, label="輸入影像(不支援 SVG 格式)", type="pil")
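# Assemble the Gradio Blocks interface: custom CSS, dark theme, header with
# citation, input/example images, result plot, and event wiring.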
def get_layout():
css = """
.main-title {
font-size: 24px;
font-weight: bold;
text-align: center;
margin-bottom: 20px;
}
.reference {
text-align: center;
font-size: 1.2em;
color: #d1d5db;
margin-bottom: 20px;
}
.reference a {
color: #FB923C;
text-decoration: none;
}
.reference a:hover {
text-decoration: underline;
color: #FB923C;
}
.title {
border-bottom: 1px solid;
}
.footer {
text-align: center;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid #ddd;
color: #d1d5db;
font-size: 14px;
}
.example-image {
height: 220px;
padding: 25px;
}
"""
theme = gr.themes.Base(
primary_hue="orange",
secondary_hue="cyan",
neutral_hue="gray",
).set(
body_text_color='*neutral_100',
body_text_color_subdued='*neutral_600',
background_fill_primary='*neutral_950',
background_fill_secondary='*neutral_600',
border_color_accent='*secondary_800',
color_accent='*primary_50',
color_accent_soft='*secondary_800',
code_background_fill='*neutral_700',
block_background_fill_dark='*body_background_fill',
block_info_text_color='#6b7280',
block_label_text_color='*neutral_300',
block_label_text_weight='700',
block_title_text_color='*block_label_text_color',
block_title_text_weight='300',
panel_background_fill='*neutral_800',
table_text_color_dark='*secondary_800',
checkbox_background_color_selected='*primary_500',
checkbox_label_background_fill='*neutral_500',
checkbox_label_background_fill_hover='*neutral_700',
checkbox_label_text_color='*neutral_200',
input_background_fill='*neutral_700',
input_background_fill_focus='*neutral_600',
slider_color='*primary_500',
table_even_background_fill='*neutral_700',
table_odd_background_fill='*neutral_600',
table_row_focus='*neutral_800'
)
with gr.Blocks(css=css, theme=theme) as demo:
with gr.Column():
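            # Header: app title and citation (in Traditional Chinese).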
gr.HTML(
value=(
'<div class="main-title">Litton7景觀分類模型</div>'
'<div class="reference">引用資料:'
'<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
"何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
"</a>"
"</div>"
),
)
with gr.Row(equal_height=True):
with gr.Group():
img = gr.Image(label="上傳影像", type="pil", height="256px")
gr.Label("範例影像", show_label=False)
with gr.Row():
ex1 = gr.Image(
value="examples/beach.jpg",
show_label=False,
type="filepath",
elem_classes="example-image",
interactive=False,
show_download_button=False,
show_fullscreen_button=False,
show_share_button=False,
)
ex2 = gr.Image(
value="examples/field.jpg",
show_label=False,
type="filepath",
elem_classes="example-image",
interactive=False,
show_download_button=False,
show_fullscreen_button=False,
show_share_button=False,
)
ex3 = gr.Image(
value="examples/sky.jpg",
show_label=False,
type="filepath",
elem_classes="example-image",
interactive=False,
show_download_button=False,
show_fullscreen_button=False,
show_share_button=False,
)
chart = gr.Plot(label="分類結果")
start_button = gr.Button("開始", variant="primary")
gr.HTML(
'<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
)
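        # Wire up the start button and example thumbnails.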
start_button.click(
fn=predict,
inputs=img,
outputs=chart,
)
ex1.select(fn=choose_example, inputs=ex1, outputs=img)
ex2.select(fn=choose_example, inputs=ex2, outputs=img)
ex3.select(fn=choose_example, inputs=ex3, outputs=img)
return demo
if __name__ == "__main__":
get_layout().launch()