from PIL import Image
from io import BytesIO
from matplotlib.figure import Figure
from torchvision import transforms
from tqdm import tqdm
from typing import Literal, Any
from urllib.request import urlopen
import gradio as gr
import matplotlib.pyplot as plt
import os
import spaces
import sys
import torch
import torch.nn.functional as F

# Class names, in the order of the model's output logits.
LABELS = [
    "Panoramic",
    "Feature",
    "Detail",
    "Enclosed",
    "Focal",
    "Ephemeral",
    "Canopied",
]
MODELFILE = "Litton-7type-visual-landscape-model.pth"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _download_model(url: str, dest: str, chunk_size: int = 1 << 16) -> None:
    """Stream ``url`` into the file ``dest``, showing a tqdm progress bar.

    The body is written to ``dest + ".part"`` first and renamed to ``dest``
    only after the whole response has been read. An interrupted download
    therefore never leaves a truncated file at ``dest``, which the
    ``os.path.exists`` check below would otherwise accept as a complete model.
    A stale ``.part`` file from a failed attempt is simply overwritten on the
    next run.

    Args:
        url: Source URL.
        dest: Final path of the downloaded file.
        chunk_size: Number of bytes requested per ``read`` call.
    """
    tmp_path = f"{dest}.part"
    with urlopen(url) as resp:
        # The response object is not subscriptable; headers live on
        # ``.headers``. Content-Length may be missing (e.g. chunked
        # transfer), in which case tqdm just shows a running byte count.
        length = resp.headers.get("Content-Length")
        total = int(length) if length else None
        with tqdm(total=total, desc="Downloading", unit="B", unit_scale=True) as progress:
            with open(tmp_path, "wb") as modelfile:
                while chunk := resp.read(chunk_size):
                    modelfile.write(chunk)
                    progress.update(len(chunk))
    os.replace(tmp_path, dest)


if not os.path.exists(MODELFILE):
    model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
    print(f"fetch model from {model_url}...", file=sys.stderr)
    _download_model(model_url, MODELFILE)

# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file. It is needed because the checkpoint holds
# a whole model object (``.module`` / ``.eval()`` are called on it), not a
# state_dict. That is only acceptable while MODELFILE comes from the project's
# own model zoo; never point it at an untrusted source.
# ``.module`` unwraps the stored object, presumably a DataParallel-style
# wrapper (TODO confirm).
model = torch.load(
    MODELFILE, map_location=device, weights_only=False
).module
model.eval()

# Resize the shorter side to 256 px, center-crop 224x224, then normalize with
# the standard ImageNet channel mean/std (presumably what the network was
# trained with).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


@spaces.GPU
def predict(image: Image.Image) -> Figure:
    """Classify ``image`` into the landscape types in ``LABELS`` and plot it.

    Args:
        image: Input picture. It is converted to RGB first, so grayscale,
            RGBA and palette images are accepted.

    Returns:
        A bar chart (from ``draw_bar_chart``) of each class's softmax
        probability, in percent.
    """
    image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_tensor)
    # Only the first len(LABELS) outputs are used as class scores, and the
    # softmax is taken over just those. NOTE(review): the slice suggests the
    # model may emit extra outputs; confirm.
    probs = F.softmax(logits[:, : len(LABELS)], dim=1).cpu()
    return draw_bar_chart(
        {
            "class": LABELS,
            "probs": probs[0] * 100,
        }
    )


def draw_bar_chart(data: dict[str, Any]) -> Figure:
    """Render a bar chart of per-class probabilities.

    Args:
        data: Mapping with ``"class"`` (sequence of label strings) and
            ``"probs"`` (same-length sequence of percentages; lists, numpy
            arrays and torch tensors are all accepted).

    Returns:
        A standalone matplotlib ``Figure``. It is built directly rather than
        through ``pyplot``, so it is never registered with pyplot's global
        figure manager. ``plt.subplots`` left one open figure behind per
        request, which leaks memory in a long-running server.
    """
    classes = list(data["class"])
    # Plain floats so matplotlib and the format spec below never see tensors.
    probabilities = [float(p) for p in data["probs"]]

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.bar(classes, probabilities, color="skyblue")
    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")
    # Annotate each bar with its value just above the bar top.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")
    fig.tight_layout()
    return fig
def choose_example(imgpath: str) -> gr.Image: img = Image.open(imgpath) width, height = img.size ratio = 512 / max(width, height) img = img.resize((int(width * ratio), int(height * ratio))) return gr.Image(value=img, label="輸入影像(不支援 SVG 格式)", type="pil") def get_layout(): css = """ .main-title { font-size: 24px; font-weight: bold; text-align: center; margin-bottom: 20px; } .reference { text-align: center; font-size: 1.2em; color: #d1d5db; margin-bottom: 20px; } .reference a { color: #FB923C; text-decoration: none; } .reference a:hover { text-decoration: underline; color: #FB923C; } .title { border-bottom: 1px solid; } .footer { text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #ddd; color: #d1d5db; font-size: 14px; } .example-image { height: 220px; padding: 25px; } """ theme = gr.themes.Base( primary_hue="orange", secondary_hue="cyan", neutral_hue="gray", ).set( body_text_color='*neutral_100', body_text_color_subdued='*neutral_600', background_fill_primary='*neutral_950', background_fill_secondary='*neutral_600', border_color_accent='*secondary_800', color_accent='*primary_50', color_accent_soft='*secondary_800', code_background_fill='*neutral_700', block_background_fill_dark='*body_background_fill', block_info_text_color='#6b7280', block_label_text_color='*neutral_300', block_label_text_weight='700', block_title_text_color='*block_label_text_color', block_title_text_weight='300', panel_background_fill='*neutral_800', table_text_color_dark='*secondary_800', checkbox_background_color_selected='*primary_500', checkbox_label_background_fill='*neutral_500', checkbox_label_background_fill_hover='*neutral_700', checkbox_label_text_color='*neutral_200', input_background_fill='*neutral_700', input_background_fill_focus='*neutral_600', slider_color='*primary_500', table_even_background_fill='*neutral_700', table_odd_background_fill='*neutral_600', table_row_focus='*neutral_800' ) with gr.Blocks(css=css, theme=theme) as demo: with gr.Column(): 
gr.HTML( value=( '