from PIL import Image
from io import BytesIO
from matplotlib.figure import Figure
from torchvision import transforms
from tqdm import tqdm
from typing import Literal, Any
from urllib.request import urlopen
import gradio as gr
import matplotlib.pyplot as plt
import os
import spaces
import sys
import torch
import torch.nn.functional as F

# Class names shown on the chart, in the order of the model's output logits.
LABELS = [
    "Panoramic",
    "Feature",
    "Detail",
    "Enclosed",
    "Focal",
    "Ephemeral",
    "Canopied",
]

MODELFILE = "Litton-7type-visual-landscape-model.pth"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _download_model(url: str, dest: str) -> None:
    """Download *url* to *dest*, showing a tqdm progress bar on stderr.

    The payload is first written to ``<dest>.part`` and renamed into place
    only after the transfer finishes. This way an interrupted download never
    leaves a truncated checkpoint at *dest*, which the ``os.path.exists``
    check would otherwise treat as a valid cached model on every later start.
    """
    tmp_path = f"{dest}.part"
    with urlopen(url) as resp:
        # HTTPResponse is not subscriptable; headers live on ``resp.headers``.
        # Content-Length may be missing (e.g. chunked transfer), in which case
        # tqdm falls back to an open-ended counter.
        length = resp.headers.get("Content-Length")
        total = int(length) if length else None
        try:
            with open(tmp_path, "wb") as out, tqdm(
                total=total, desc="Downloading"
            ) as progress:
                # 64 KiB reads: far fewer syscalls than 1 KiB for a large file.
                while chunk := resp.read(1 << 16):
                    out.write(chunk)
                    progress.update(len(chunk))
        except BaseException:
            # Do not leave a half-written file behind on error or Ctrl-C.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic rename on the same filesystem.
    os.replace(tmp_path, dest)


if not os.path.exists(MODELFILE):
    model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
    print(f"fetch model from {model_url}...", file=sys.stderr)
    _download_model(model_url, MODELFILE)

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so the
# checkpoint file must come from a trusted source. The loaded object exposes
# the network as ``.module`` (presumably an nn.DataParallel wrapper -- confirm).
model = torch.load(MODELFILE, map_location=device, weights_only=False).module
model.eval()

# Resize the shorter side to 256, center-crop to 224x224, convert to a tensor,
# and normalize with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


@spaces.GPU
def predict(image: Image.Image) -> Figure:
    """Classify *image* into the Litton landscape types and chart the result.

    Args:
        image: PIL image from the Gradio upload component; converted to RGB
            so the three-channel normalization applies. Gradio passes ``None``
            when nothing has been uploaded.

    Returns:
        A matplotlib Figure with one bar per entry in ``LABELS`` showing the
        softmax probability in percent.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("請先上傳影像")
    image = image.convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_tensor)
    # Keep only the first len(LABELS) logits and normalize over them; the
    # model head may emit more outputs than that (TODO confirm).
    probs = F.softmax(logits[:, : len(LABELS)], dim=1).cpu()
    return draw_bar_chart(
        {
            "class": LABELS,
            # Plain floats rather than 0-d tensors, matching the declared type.
            "probs": (probs[0] * 100).tolist(),
        }
    )


def draw_bar_chart(data: dict[str, list[str | float]]) -> Figure:
    """Draw a labelled bar chart of class probabilities.

    Args:
        data: mapping with ``"class"`` (bar labels) and ``"probs"`` (bar
            heights, in percent), both the same length.

    Returns:
        A new matplotlib Figure holding the chart.
    """
    classes = data["class"]
    probabilities = data["probs"]

    # Build the Figure directly rather than via pyplot. pyplot registers every
    # figure in a global manager, and nothing here ever closes them, so a
    # long-running server would accumulate one figure per request.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.bar(classes, probabilities, color="skyblue")
    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")
    # Annotate each bar with its value just above the bar top.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")
    fig.tight_layout()
    return fig
def choose_example(imgpath: str) -> gr.Image:
    """Load an example image and return an input component that displays it.

    The image is rescaled so its longer side is 512 px (aspect ratio kept).
    It is then wrapped in a new ``gr.Image``, which Gradio uses to update the
    upload component.

    Args:
        imgpath: Filesystem path of the clicked example image.

    Returns:
        A ``gr.Image`` update carrying the resized PIL image.
    """
    # Context manager releases the file handle deterministically; ``resize``
    # returns an independent in-memory image, so it stays valid after close.
    with Image.open(imgpath) as source:
        width, height = source.size
        ratio = 512 / max(width, height)
        # Clamp each side to at least 1 px so very elongated images never
        # request a zero-sized result.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        resized = source.resize(new_size)
    return gr.Image(value=resized, label="輸入影像(不支援 SVG 格式)", type="pil")


def _example_image(path: str) -> gr.Image:
    """Create a non-interactive, clickable thumbnail for one example image."""
    return gr.Image(
        value=path,
        show_label=False,
        type="filepath",
        elem_classes="example-image",
        interactive=False,
        show_download_button=False,
        show_fullscreen_button=False,
        show_share_button=False,
    )


def get_layout():
    """Build the Gradio Blocks UI for the Litton7 classifier.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    css = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    .example-image {
        height: 220px;
        padding: 25px;
    }
    """
    theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="cyan",
        neutral_hue="gray",
    ).set(
        body_text_color='*neutral_100',
        body_text_color_subdued='*neutral_600',
        background_fill_primary='*neutral_950',
        background_fill_secondary='*neutral_600',
        border_color_accent='*secondary_800',
        color_accent='*primary_50',
        color_accent_soft='*secondary_800',
        code_background_fill='*neutral_700',
        block_background_fill_dark='*body_background_fill',
        block_info_text_color='#6b7280',
        block_label_text_color='*neutral_300',
        block_label_text_weight='700',
        block_title_text_color='*block_label_text_color',
        block_title_text_weight='300',
        panel_background_fill='*neutral_800',
        table_text_color_dark='*secondary_800',
        checkbox_background_color_selected='*primary_500',
        checkbox_label_background_fill='*neutral_500',
        checkbox_label_background_fill_hover='*neutral_700',
        checkbox_label_text_color='*neutral_200',
        input_background_fill='*neutral_700',
        input_background_fill_focus='*neutral_600',
        slider_color='*primary_500',
        table_even_background_fill='*neutral_700',
        table_odd_background_fill='*neutral_600',
        table_row_focus='*neutral_800'
    )
    example_paths = (
        "examples/beach.jpg",
        "examples/field.jpg",
        "examples/sky.jpg",
    )

    with gr.Blocks(css=css, theme=theme) as demo:
        with gr.Column():
            # Title and citation; class names match the CSS defined above.
            gr.HTML(
                value=(
                    '<div class="main-title">Litton7景觀分類模型</div>'
                    '<div class="reference">引用資料:'
                    # TODO(review): the citation link target was lost from
                    # this source; restore the original href on this anchor.
                    "<a>"
                    "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
                    "</a>"
                    "</div>"
                ),
            )
            with gr.Row(equal_height=True):
                with gr.Group():
                    img = gr.Image(label="上傳影像", type="pil", height="256px")
                    gr.Label("範例影像", show_label=False)
                    with gr.Row():
                        examples = [_example_image(path) for path in example_paths]
                chart = gr.Plot(label="分類結果")
            start_button = gr.Button("開始", variant="primary")
            # NOTE(review): footer markup appears to be missing from this
            # source; kept as an empty block.
            gr.HTML(
                '',
            )

        start_button.click(
            fn=predict,
            inputs=img,
            outputs=chart,
        )
        # Clicking a thumbnail loads that example into the upload component.
        for example in examples:
            example.select(fn=choose_example, inputs=example, outputs=img)
    return demo


if __name__ == "__main__":
    get_layout().launch()