lichih commited on
Commit
9c15e66
·
verified ·
1 Parent(s): 36d6732

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -170
app.py CHANGED
@@ -1,176 +1,159 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
5
-
6
- def predict(image):
7
- predictions = pipeline(image)
8
- return {p["label"]: p["score"] for p in predictions}
9
-
10
- gr.Interface(
11
- predict,
12
- inputs=gr.Image(label="Upload hot dog candidate", type="filepath"),
13
- outputs=gr.Label(num_top_classes=2),
14
- title="Hot Dog? Or Not?",
15
- flagging_mode="manual"
16
- ).launch()
17
-
18
- # import matplotlib.pyplot as plt
19
- # import torch
20
- # from PIL import Image
21
- # from torchvision import transforms
22
- # import torch.nn.functional as F
23
- # from typing import Literal, Any
24
- # import gradio as gr
25
- # import spaces
26
- # from io import BytesIO
27
-
28
-
29
- # class Classifier:
30
- # LABELS = [
31
- # "Panoramic",
32
- # "Feature",
33
- # "Detail",
34
- # "Enclosed",
35
- # "Focal",
36
- # "Ephemeral",
37
- # "Canopied",
38
- # ]
39
-
40
- # @spaces.GPU(duration=60)
41
- # def __init__(
42
- # self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
43
- # ):
44
- # self.device = device
45
- # self.model = torch.load(
46
- # model_path, map_location=self.device, weights_only=False
47
- # )
48
- # if hasattr(self.model, "module"):
49
- # self.model = self.model.module
50
- # self.model.eval()
51
- # self.preprocess = transforms.Compose(
52
- # [
53
- # transforms.Resize(256),
54
- # transforms.CenterCrop(224),
55
- # transforms.ToTensor(),
56
- # transforms.Normalize(
57
- # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
58
- # ),
59
- # ]
60
- # )
61
-
62
- # @spaces.GPU(duration=60)
63
- # def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
64
- # image = image.convert("RGB")
65
- # input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
66
-
67
- # with torch.no_grad():
68
- # logits = self.model(input_tensor)
69
- # probs = F.softmax(logits[:, :7], dim=1).cpu()
70
-
71
- # return draw_bar_chart(
72
- # {
73
- # "class": self.LABELS,
74
- # "probs": probs[0] * 100,
75
- # }
76
- # )
77
-
78
-
79
- # def draw_bar_chart(data: dict[str, list[str | float]]):
80
- # classes = data["class"]
81
- # probabilities = data["probs"]
82
-
83
- # plt.figure(figsize=(8, 6))
84
- # plt.bar(classes, probabilities, color="skyblue")
85
-
86
- # plt.xlabel("Class")
87
- # plt.ylabel("Probability (%)")
88
- # plt.title("Class Probabilities")
89
-
90
- # for i, prob in enumerate(probabilities):
91
- # plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
92
-
93
- # plt.tight_layout()
94
-
95
- # return plt
96
-
97
-
98
- # def get_layout():
99
- # demo = gr.Interface(fn=Classifier().predict, inputs="image", outputs="plot")
100
- # return demo
101
- # css = """
102
- # .main-title {
103
- # font-size: 24px;
104
- # font-weight: bold;
105
- # text-align: center;
106
- # margin-bottom: 20px;
107
- # }
108
- # .reference {
109
- # text-align: center;
110
- # font-size: 1.2em;
111
- # color: #d1d5db;
112
- # margin-bottom: 20px;
113
- # }
114
- # .reference a {
115
- # color: #FB923C;
116
- # text-decoration: none;
117
- # }
118
- # .reference a:hover {
119
- # text-decoration: underline;
120
- # color: #FB923C;
121
- # }
122
- # .title {
123
- # border-bottom: 1px solid;
124
- # }
125
- # .footer {
126
- # text-align: center;
127
- # margin-top: 30px;
128
- # padding-top: 20px;
129
- # border-top: 1px solid #ddd;
130
- # color: #d1d5db;
131
- # font-size: 14px;
132
- # }
133
- # """
134
- # theme = gr.themes.Base(
135
- # primary_hue="orange",
136
- # secondary_hue="orange",
137
- # neutral_hue="gray",
138
- # font=gr.themes.GoogleFont("Source Sans Pro"),
139
- # ).set(
140
- # background_fill_primary="*neutral_950", # 主背景色(深黑)
141
- # button_primary_background_fill="*primary_500", # 按鈕顏色(橘色)
142
- # body_text_color="*neutral_200", # 文字顏色(淺色)
143
- # )
144
- # # with gr.Blocks(css=css, theme=theme) as demo:
145
- # with gr.Blocks() as demo:
146
- # with gr.Column():
147
- # gr.HTML(
148
- # value=(
149
- # '<div class="main-title">Litton7景觀分類模型</div>'
150
- # '<div class="reference">引用資料:'
151
- # '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
152
- # "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
153
- # "</a>"
154
- # "</div>"
155
- # ),
156
- # )
157
 
158
- # with gr.Row(equal_height=True):
159
- # image_input = gr.Image(label="上傳影像", type="pil")
160
- # chart = gr.Image(label="分類結果")
161
 
162
- # start_button = gr.Button("開始分類", variant="primary")
163
- # gr.HTML(
164
- # '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
165
- # )
166
- # start_button.click(
167
- # fn=Classifier().predict,
168
- # inputs=image_input,
169
- # outputs=chart,
170
- # )
171
 
172
- # return demo
173
 
174
 
175
- # if __name__ == "__main__":
176
- # get_layout().launch()
 
1
+ import matplotlib.pyplot as plt
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from typing import Literal, Any
7
  import gradio as gr
8
+ import spaces
9
+ from io import BytesIO
10
+
11
+
12
+ class Classifier:
13
+ LABELS = [
14
+ "Panoramic",
15
+ "Feature",
16
+ "Detail",
17
+ "Enclosed",
18
+ "Focal",
19
+ "Ephemeral",
20
+ "Canopied",
21
+ ]
22
+
23
+ @spaces.GPU(duration=60)
24
+ def __init__(
25
+ self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
26
+ ):
27
+ self.device = device
28
+ self.model = torch.load(
29
+ model_path, map_location=self.device, weights_only=False
30
+ )
31
+ if hasattr(self.model, "module"):
32
+ self.model = self.model.module
33
+ self.model.eval()
34
+ self.preprocess = transforms.Compose(
35
+ [
36
+ transforms.Resize(256),
37
+ transforms.CenterCrop(224),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(
40
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
41
+ ),
42
+ ]
43
+ )
44
+
45
+ @spaces.GPU(duration=60)
46
+ def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
47
+ image = image.convert("RGB")
48
+ input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
49
+
50
+ with torch.no_grad():
51
+ logits = self.model(input_tensor)
52
+ probs = F.softmax(logits[:, :7], dim=1).cpu()
53
+
54
+ return draw_bar_chart(
55
+ {
56
+ "class": self.LABELS,
57
+ "probs": probs[0] * 100,
58
+ }
59
+ )
60
+
61
+
62
+ def draw_bar_chart(data: dict[str, list[str | float]]):
63
+ classes = data["class"]
64
+ probabilities = data["probs"]
65
+
66
+ plt.figure(figsize=(8, 6))
67
+ plt.bar(classes, probabilities, color="skyblue")
68
+
69
+ plt.xlabel("Class")
70
+ plt.ylabel("Probability (%)")
71
+ plt.title("Class Probabilities")
72
+
73
+ for i, prob in enumerate(probabilities):
74
+ plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
75
+
76
+ plt.tight_layout()
77
+
78
+ return plt
79
+
80
+
81
+ def get_layout():
82
+ demo = gr.Interface(fn=Classifier().predict, inputs="image", outputs="plot")
83
+ return demo
84
+ css = """
85
+ .main-title {
86
+ font-size: 24px;
87
+ font-weight: bold;
88
+ text-align: center;
89
+ margin-bottom: 20px;
90
+ }
91
+ .reference {
92
+ text-align: center;
93
+ font-size: 1.2em;
94
+ color: #d1d5db;
95
+ margin-bottom: 20px;
96
+ }
97
+ .reference a {
98
+ color: #FB923C;
99
+ text-decoration: none;
100
+ }
101
+ .reference a:hover {
102
+ text-decoration: underline;
103
+ color: #FB923C;
104
+ }
105
+ .title {
106
+ border-bottom: 1px solid;
107
+ }
108
+ .footer {
109
+ text-align: center;
110
+ margin-top: 30px;
111
+ padding-top: 20px;
112
+ border-top: 1px solid #ddd;
113
+ color: #d1d5db;
114
+ font-size: 14px;
115
+ }
116
+ """
117
+ theme = gr.themes.Base(
118
+ primary_hue="orange",
119
+ secondary_hue="orange",
120
+ neutral_hue="gray",
121
+ font=gr.themes.GoogleFont("Source Sans Pro"),
122
+ ).set(
123
+ background_fill_primary="*neutral_950", # 主背景色(深黑)
124
+ button_primary_background_fill="*primary_500", # 按鈕顏色(橘色)
125
+ body_text_color="*neutral_200", # 文字顏色(淺色)
126
+ )
127
+ # with gr.Blocks(css=css, theme=theme) as demo:
128
+ with gr.Blocks() as demo:
129
+ with gr.Column():
130
+ gr.HTML(
131
+ value=(
132
+ '<div class="main-title">Litton7景觀分類模型</div>'
133
+ '<div class="reference">引用資料:'
134
+ '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
135
+ "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
136
+ "</a>"
137
+ "</div>"
138
+ ),
139
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ with gr.Row(equal_height=True):
142
+ image_input = gr.Image(label="上傳影像", type="pil")
143
+ chart = gr.Image(label="分類結果")
144
 
145
+ start_button = gr.Button("開始分類", variant="primary")
146
+ gr.HTML(
147
+ '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
148
+ )
149
+ start_button.click(
150
+ fn=Classifier().predict,
151
+ inputs=image_input,
152
+ outputs=chart,
153
+ )
154
 
155
+ return demo
156
 
157
 
158
+ if __name__ == "__main__":
159
+ get_layout().launch()