import matplotlib.pyplot as plt import torch from PIL import Image from torchvision import transforms import torch.nn.functional as F from typing import Literal, Any import gradio as gr import spaces from io import BytesIO class Classifier: LABELS = [ "Panoramic", "Feature", "Detail", "Enclosed", "Focal", "Ephemeral", "Canopied", ] @spaces.GPU(duration=60) def __init__( self, model_path="Litton-7type-visual-landscape-model.pth", device=None ): if device is None: self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") else: self.device = device self.device = device self.model = torch.load( model_path, map_location=self.device, weights_only=False ) if hasattr(self.model, "module"): self.model = self.model.module self.model.eval() self.preprocess = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) @spaces.GPU(duration=60) def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]: image = image.convert("RGB") input_tensor = self.preprocess(image).unsqueeze(0).to(self.device) with torch.no_grad(): logits = self.model(input_tensor) probs = F.softmax(logits[:, :7], dim=1).cpu() return draw_bar_chart( { "class": self.LABELS, "probs": probs[0] * 100, } ) def draw_bar_chart(data: dict[str, list[str | float]]): classes = data["class"] probabilities = data["probs"] plt.figure(figsize=(8, 6)) plt.bar(classes, probabilities, color="skyblue") plt.xlabel("Class") plt.ylabel("Probability (%)") plt.title("Class Probabilities") for i, prob in enumerate(probabilities): plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom") plt.tight_layout() return plt def get_layout(): css = """ .main-title { font-size: 24px; font-weight: bold; text-align: center; margin-bottom: 20px; } .reference { text-align: center; font-size: 1.2em; color: #d1d5db; margin-bottom: 20px; } .reference a { color: #FB923C; text-decoration: none; } .reference a:hover { text-decoration: underline; color: #FB923C; } .title { border-bottom: 1px solid; } .footer { text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #ddd; color: #d1d5db; font-size: 14px; } """ theme = gr.themes.Base( primary_hue="orange", secondary_hue="orange", neutral_hue="gray", font=gr.themes.GoogleFont("Source Sans Pro"), ).set( background_fill_primary="*neutral_950", # 主背景色(深黑) button_primary_background_fill="*primary_500", # 按鈕顏色(橘色) body_text_color="*neutral_200", # 文字顏色(淺色) ) # with gr.Blocks(css=css, theme=theme) as demo: with gr.Blocks() as demo: with gr.Column(): gr.HTML( value=( '