|
from PIL import Image |
|
from io import BytesIO |
|
from matplotlib.figure import Figure |
|
from torchvision import transforms |
|
from tqdm import tqdm |
|
from typing import Literal, Any |
|
from urllib.request import urlopen |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import os |
|
import spaces |
|
import sys |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
# Display names for the classifier's classes. predict() pairs these
# index-by-index with the model's softmax outputs, so the order must match
# the model's output order (presumably the training label order — confirm).
LABELS = ["Panoramic", "Feature", "Detail", "Enclosed", "Focal", "Ephemeral", "Canopied"]

# Local filename of the serialized model; also used as the final path
# component of the download URL.
MODELFILE = "Litton-7type-visual-landscape-model.pth"
|
|
|
|
|
# Run inference on CUDA when PyTorch reports it as available, else on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
def _download_file(url: str, dest: str, chunk_size: int = 64 * 1024) -> None:
    """Stream ``url`` into ``dest`` while showing a tqdm progress bar.

    The body is written to ``dest + ".part"`` and only renamed to ``dest``
    once it has been read completely. An interrupted download therefore
    never leaves a truncated file at ``dest``, which the existence check
    below would otherwise mistake for a complete model.

    Args:
        url: Source URL passed to ``urlopen``.
        dest: Final path of the downloaded file.
        chunk_size: Number of bytes requested per ``read`` call.

    Raises:
        Any exception from ``urlopen``, reading, or writing is re-raised
        after the partial ``.part`` file has been removed.
    """
    tmp_path = f"{dest}.part"
    try:
        with urlopen(url) as resp, open(tmp_path, "wb") as out:
            # The response object is not subscriptable; the header lives on
            # resp.headers. It may also be absent (e.g. chunked transfer), in
            # which case tqdm shows a running count without a percentage.
            length = resp.headers.get("Content-Length")
            total = int(length) if length else None
            with tqdm(
                total=total, desc="Downloading", unit="B", unit_scale=True
            ) as progress:
                while chunk := resp.read(chunk_size):
                    out.write(chunk)
                    progress.update(len(chunk))
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers KeyboardInterrupt, the typical "interrupted download".
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Fetch the model once; later starts reuse the local copy.
if not os.path.exists(MODELFILE):
    model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"

    print(f"fetch model from {model_url}...", file=sys.stderr)

    _download_file(model_url, MODELFILE)
|
|
|
# SECURITY NOTE(review): weights_only=False un-pickles arbitrary Python
# objects, so loading runs whatever code the file contains. The file comes
# from the remote download above (or whatever already sits at MODELFILE);
# keep this only while that source is trusted.
model = torch.load(MODELFILE, map_location=device, weights_only=False)
# The checkpoint is unwrapped via `.module`; presumably it was saved as a
# DataParallel-style wrapper (TODO confirm). Fall back to the loaded object
# itself so a checkpoint saved without such a wrapper still loads.
model = getattr(model, "module", model)
# Put the network in evaluation mode for inference.
model.eval()
|
# Input pipeline applied to every image before it reaches the model.
# The mean/std values are the widely used ImageNet channel statistics;
# presumably they match the model's training preprocessing (confirm).
preprocess = transforms.Compose([
    # Resize so the shorter side is 256 px (aspect ratio kept) ...
    transforms.Resize(256),
    # ... then keep the central 224x224 region.
    transforms.CenterCrop(224),
    # PIL image -> float tensor in [0, 1], channels first.
    transforms.ToTensor(),
    # Per-channel normalisation: positional arguments are (mean, std).
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
|
|
|
@spaces.GPU
def _infer_probabilities(image: Image.Image) -> list[float]:
    """Run the classifier on ``image`` and return class probabilities in percent.

    Only this step is decorated with ``spaces.GPU``, so a GPU is requested
    just for the forward pass, not for input validation or plotting. The
    result is a plain list of floats, one per entry of ``LABELS``.
    """
    input_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)

    # Softmax over only the first len(LABELS) outputs, so the probabilities
    # line up one-to-one with LABELS.
    probs = F.softmax(logits[:, : len(LABELS)], dim=1).cpu()
    return (probs[0] * 100).tolist()


def predict(image: Image.Image | None) -> Figure:
    """Classify ``image`` and return a bar chart of the class probabilities.

    Args:
        image: The uploaded PIL image, or ``None`` when the start button
            was pressed without an image.

    Returns:
        The figure produced by ``draw_bar_chart``, with one bar per
        label in ``LABELS`` (heights in percent).

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to
            the user instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("請先上傳影像")

    return draw_bar_chart(
        {
            "class": LABELS,
            "probs": _infer_probabilities(image),
        }
    )
|
|
|
|
|
def draw_bar_chart(data: dict[str, list[str | float]]) -> Figure:
    """Render class probabilities as a bar chart with a value label on each bar.

    Args:
        data: Mapping with ``"class"`` (bar labels) and ``"probs"`` (bar
            heights, in percent); the two sequences are expected to have
            the same length.

    Returns:
        A new matplotlib ``Figure`` containing the chart.
    """
    classes = data["class"]
    probabilities = data["probs"]

    # Build the Figure directly rather than through pyplot. pyplot keeps
    # every figure it creates alive in a global registry until plt.close(),
    # so a long-running server would leak one figure per request.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.bar(classes, probabilities, color="skyblue")

    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")

    # Categorical bars sit at x = 0, 1, 2, ...; put each percentage just
    # above its bar.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")

    fig.tight_layout()

    return fig
|
|
|
|
|
def choose_example(imgpath: str) -> gr.Image:
    """Load an example image and return it as an update for the upload widget.

    The image is scaled so that its longer side becomes 512 px, keeping the
    aspect ratio.

    Args:
        imgpath: Path of the example image file.

    Returns:
        A ``gr.Image`` carrying the resized PIL image.
    """
    # The context manager closes the file handle Image.open keeps.
    # resize() returns a fully loaded, independent image, so the result
    # stays valid after the file is closed.
    with Image.open(imgpath) as src:
        width, height = src.size
        ratio = 512 / max(width, height)
        # Clamp to 1 px: very thin images could otherwise round a side down
        # to 0, which Image.resize rejects.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        resized = src.resize(new_size)

    return gr.Image(value=resized, label="輸入影像(不支援 SVG 格式)", type="pil")
|
|
|
|
|
def get_layout() -> gr.Blocks:
    """Assemble the Litton7 demo UI.

    The page has:
    - a title banner with the paper reference;
    - a row holding the upload widget (with clickable example thumbnails
      beneath it) next to the result plot;
    - a start button;
    - a footer.

    Clicking the button runs ``predict`` on the uploaded image. Selecting
    a thumbnail loads it into the upload widget through ``choose_example``.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    css = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    .example-image {
        height: 220px;
        padding: 25px;
    }
    """

    theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="cyan",
        neutral_hue="gray",
    ).set(
        body_text_color='*neutral_100',
        body_text_color_subdued='*neutral_600',
        background_fill_primary='*neutral_950',
        background_fill_secondary='*neutral_600',
        border_color_accent='*secondary_800',
        color_accent='*primary_50',
        color_accent_soft='*secondary_800',
        code_background_fill='*neutral_700',
        block_background_fill_dark='*body_background_fill',
        block_info_text_color='#6b7280',
        block_label_text_color='*neutral_300',
        block_label_text_weight='700',
        block_title_text_color='*block_label_text_color',
        block_title_text_weight='300',
        panel_background_fill='*neutral_800',
        table_text_color_dark='*secondary_800',
        checkbox_background_color_selected='*primary_500',
        checkbox_label_background_fill='*neutral_500',
        checkbox_label_background_fill_hover='*neutral_700',
        checkbox_label_text_color='*neutral_200',
        input_background_fill='*neutral_700',
        input_background_fill_focus='*neutral_600',
        slider_color='*primary_500',
        table_even_background_fill='*neutral_700',
        table_odd_background_fill='*neutral_600',
        table_row_focus='*neutral_800'
    )

    # Example thumbnails shown under the upload widget, in display order.
    # Adding a path here adds another clickable example.
    example_paths = (
        "examples/beach.jpg",
        "examples/field.jpg",
        "examples/sky.jpg",
    )

    with gr.Blocks(css=css, theme=theme) as demo:
        with gr.Column():
            gr.HTML(
                value=(
                    '<div class="main-title">Litton7景觀分類模型</div>'
                    '<div class="reference">引用資料:'
                    '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
                    "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
                    "</a>"
                    "</div>"
                ),
            )

            with gr.Row(equal_height=True):
                with gr.Group():
                    img = gr.Image(label="上傳影像", type="pil", height="256px")
                    gr.Label("範例影像", show_label=False)
                    with gr.Row():
                        # Read-only thumbnails. type="filepath" hands the
                        # select handler a path, which choose_example opens.
                        examples = [
                            gr.Image(
                                value=path,
                                show_label=False,
                                type="filepath",
                                elem_classes="example-image",
                                interactive=False,
                                show_download_button=False,
                                show_fullscreen_button=False,
                                show_share_button=False,
                            )
                            for path in example_paths
                        ]
                chart = gr.Plot(label="分類結果")

            start_button = gr.Button("開始", variant="primary")
            gr.HTML(
                '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
            )

        start_button.click(
            fn=predict,
            inputs=img,
            outputs=chart,
        )

        # Selecting a thumbnail loads a resized copy into the upload widget.
        for example in examples:
            example.select(fn=choose_example, inputs=example, outputs=img)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    app = get_layout()
    app.launch()
|
|