Spaces:
Runtime error
Runtime error
Cascade AI
commited on
Commit
·
f63ca0d
1
Parent(s):
33747bb
Обновление модели и логики обучения
Browse files- README.md +21 -11
- app.py +38 -91
- hf_app.py +16 -1
- requirements.txt +9 -9
- train.py +2 -1
README.md
CHANGED
@@ -9,23 +9,33 @@ app_file: app.py
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
# Классификатор болезней томатов
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
## Возможности
|
17 |
-
-
|
18 |
-
-
|
19 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
## Классы болезней
|
22 |
-
1.
|
23 |
-
2. Ранняя
|
24 |
-
3. Поздняя
|
25 |
4. Листовая плесень
|
26 |
-
5. Септориоз
|
27 |
6. Паутинный клещ
|
28 |
7. Целевая пятнистость
|
29 |
-
8.
|
30 |
9. Мозаичный вирус
|
31 |
-
10.
|
|
|
|
|
|
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# 🍅 Классификатор болезней томатов
|
13 |
|
14 |
+
## Описание проекта
|
15 |
+
Интерактивное приложение для диагностики болезней томатов с использованием машинного обучения.
|
16 |
|
17 |
## Возможности
|
18 |
+
- Определение болезней по изображению листа томата
|
19 |
+
- Поддержка 10 различных классов заболеваний
|
20 |
+
- Автоматическое обучение модели при первом запуске
|
21 |
+
|
22 |
+
## Технологии
|
23 |
+
- Gradio
|
24 |
+
- scikit-learn
|
25 |
+
- OpenCV
|
26 |
+
- PyTorch
|
27 |
|
28 |
## Классы болезней
|
29 |
+
1. Бактериальная пятнистость
|
30 |
+
2. Ранняя фитофтора
|
31 |
+
3. Поздняя фитофтора
|
32 |
4. Листовая плесень
|
33 |
+
5. Пятнистость листьев Септориоз
|
34 |
6. Паутинный клещ
|
35 |
7. Целевая пятнистость
|
36 |
+
8. Вирус желтой курчавости листьев
|
37 |
9. Мозаичный вирус
|
38 |
+
10. Здоровый лист
|
39 |
+
|
40 |
+
## Использование
|
41 |
+
Загрузите изображение листа томата и получите диагноз с вероятностями.
|
app.py
CHANGED
@@ -1,29 +1,14 @@
|
|
1 |
-
import os
|
2 |
-
import cv2
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
import pkg_resources
|
6 |
-
import subprocess
|
7 |
-
import sys
|
8 |
-
import threading
|
9 |
-
import time
|
10 |
-
import queue
|
11 |
-
|
12 |
-
# Проверка и обновление версии Gradio
|
13 |
-
required_gradio = '4.44.1'
|
14 |
-
try:
|
15 |
-
current_gradio = pkg_resources.get_distribution('gradio').version
|
16 |
-
if current_gradio != required_gradio:
|
17 |
-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', f'gradio=={required_gradio}'])
|
18 |
-
print(f'Gradio обновлен до версии {required_gradio}')
|
19 |
-
except:
|
20 |
-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', f'gradio=={required_gradio}'])
|
21 |
-
print(f'Установлен Gradio версии {required_gradio}')
|
22 |
-
|
23 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Список классов болезней
|
26 |
-
|
27 |
'Tomato___Bacterial_spot',
|
28 |
'Tomato___Early_blight',
|
29 |
'Tomato___Late_blight',
|
@@ -36,28 +21,31 @@ TOMATO_CLASSES = [
|
|
36 |
'Tomato___healthy'
|
37 |
]
|
38 |
|
39 |
-
# Состояние обучения
|
40 |
-
MODEL_TRAINING_STATUS = queue.Queue()
|
41 |
-
MODEL_TRAINING_STATUS.put("Модель не обучена")
|
42 |
-
|
43 |
def load_model():
|
44 |
"""Загрузка обученной модели"""
|
45 |
try:
|
46 |
model_path = 'tomato_disease_classifier.pth'
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
|
|
52 |
model_data = torch.load(model_path)
|
53 |
-
if 'classifier' not in model_data or 'scaler' not in model_data:
|
54 |
-
print("Некорректная структура модели: отсутствуют необходимые компоненты")
|
55 |
-
return None, None
|
56 |
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
except Exception as e:
|
60 |
-
print(f"Ошибка загрузки модели: {
|
61 |
return None, None
|
62 |
|
63 |
def preprocess_image(image):
|
@@ -74,12 +62,12 @@ def preprocess_image(image):
|
|
74 |
def predict_disease(image):
|
75 |
"""Предсказание болезни томата"""
|
76 |
if image is None:
|
77 |
-
return "Пожалуйста, загрузите изображение
|
78 |
|
79 |
# Загрузка модели
|
80 |
-
|
81 |
-
if
|
82 |
-
return "
|
83 |
|
84 |
# Предобработка изображения
|
85 |
processed_image = preprocess_image(image)
|
@@ -96,61 +84,20 @@ def predict_disease(image):
|
|
96 |
# Формирование результата
|
97 |
result = f"Обнаружено: {prediction[0]}\n\n"
|
98 |
result += "Вероятности:\n"
|
99 |
-
for disease, prob in zip(
|
100 |
result += f"{disease}: {prob*100:.2f}%\n"
|
101 |
|
102 |
return result
|
103 |
|
104 |
# Создание Gradio интерфейса
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
|
108 |
-
gr.Markdown("""
|
109 |
-
# Диагностика болезней томатов
|
110 |
-
За��рузите изображение листа томата для определения заболевания
|
111 |
-
""")
|
112 |
-
|
113 |
-
status_text = gr.Textbox(label="Статус модели", interactive=False)
|
114 |
-
|
115 |
-
with gr.Row():
|
116 |
-
with gr.Column():
|
117 |
-
input_image = gr.Image(type="numpy", label="Загрузите изображение листа томата")
|
118 |
-
predict_btn = gr.Button("Определить болезнь")
|
119 |
-
|
120 |
-
with gr.Column():
|
121 |
-
output_text = gr.Textbox(label="Результат диагностики")
|
122 |
-
|
123 |
-
predict_btn.click(
|
124 |
-
fn=predict_disease,
|
125 |
-
inputs=input_image,
|
126 |
-
outputs=output_text
|
127 |
-
)
|
128 |
-
|
129 |
-
# Периодическое обновление статуса
|
130 |
-
demo.load(fn=lambda: MODEL_TRAINING_STATUS.queue[0] if not MODEL_TRAINING_STATUS.empty() else "Модель не обучена", inputs=None, outputs=status_text, every=5)
|
131 |
-
|
132 |
-
def train_model_in_background():
|
133 |
-
"""Запуск обучения модели в фоновом режиме"""
|
134 |
-
try:
|
135 |
-
MODEL_TRAINING_STATUS.queue.clear()
|
136 |
-
MODEL_TRAINING_STATUS.put("Начало обучения модели...")
|
137 |
-
print("Начинаем обучение модели...")
|
138 |
-
from train import main as train_main
|
139 |
-
train_main()
|
140 |
-
MODEL_TRAINING_STATUS.queue.clear()
|
141 |
-
MODEL_TRAINING_STATUS.put("Модель успешно обучена!")
|
142 |
-
print("Модель успешно обучена!")
|
143 |
-
except Exception as e:
|
144 |
-
MODEL_TRAINING_STATUS.queue.clear()
|
145 |
-
MODEL_TRAINING_STATUS.put(f"Ошибка при обучении: {str(e)}")
|
146 |
-
print(f"Ошибка при обучении модели: {str(e)}")
|
147 |
-
|
148 |
-
# Запуск интерфейса
|
149 |
if __name__ == "__main__":
|
150 |
-
|
151 |
-
if not os.path.exists('tomato_disease_classifier.pth'):
|
152 |
-
print("Модель не найдена, запускаем обучение в фоновом режиме")
|
153 |
-
training_thread = threading.Thread(target=train_model_in_background)
|
154 |
-
training_thread.start()
|
155 |
-
|
156 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from sklearn.preprocessing import StandardScaler
|
6 |
+
from sklearn.svm import SVC
|
7 |
+
import logging
|
8 |
+
from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model
|
9 |
|
10 |
# Список классов болезней
|
11 |
+
DISEASE_CLASSES = [
|
12 |
'Tomato___Bacterial_spot',
|
13 |
'Tomato___Early_blight',
|
14 |
'Tomato___Late_blight',
|
|
|
21 |
'Tomato___healthy'
|
22 |
]
|
23 |
|
|
|
|
|
|
|
|
|
24 |
def load_model():
|
25 |
"""Загрузка обученной модели"""
|
26 |
try:
|
27 |
model_path = 'tomato_disease_classifier.pth'
|
28 |
|
29 |
+
# Если модель не существует, обучаем
|
30 |
+
if not torch.os.path.exists(model_path):
|
31 |
+
download_and_prepare_dataset()
|
32 |
+
X, y = load_images_and_labels()
|
33 |
+
train_and_evaluate_model(X, y)
|
34 |
|
35 |
+
# Загрузка модели
|
36 |
model_data = torch.load(model_path)
|
|
|
|
|
|
|
37 |
|
38 |
+
# Создание scaler
|
39 |
+
scaler = StandardScaler()
|
40 |
+
scaler.mean_ = model_data['mean']
|
41 |
+
scaler.scale_ = model_data['std']
|
42 |
+
|
43 |
+
classifier = model_data['classifier']
|
44 |
+
|
45 |
+
return scaler, classifier
|
46 |
+
|
47 |
except Exception as e:
|
48 |
+
print(f"Ошибка загрузки модели: {e}")
|
49 |
return None, None
|
50 |
|
51 |
def preprocess_image(image):
|
|
|
62 |
def predict_disease(image):
|
63 |
"""Предсказание болезни томата"""
|
64 |
if image is None:
|
65 |
+
return "Пожалуйста, загрузите изображение"
|
66 |
|
67 |
# Загрузка модели
|
68 |
+
scaler, classifier = load_model()
|
69 |
+
if scaler is None or classifier is None:
|
70 |
+
return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель."
|
71 |
|
72 |
# Предобработка изображения
|
73 |
processed_image = preprocess_image(image)
|
|
|
84 |
# Формирование результата
|
85 |
result = f"Обнаружено: {prediction[0]}\n\n"
|
86 |
result += "Вероятности:\n"
|
87 |
+
for disease, prob in zip(DISEASE_CLASSES, probabilities):
|
88 |
result += f"{disease}: {prob*100:.2f}%\n"
|
89 |
|
90 |
return result
|
91 |
|
92 |
# Создание Gradio интерфейса
|
93 |
+
iface = gr.Interface(
|
94 |
+
fn=predict_disease,
|
95 |
+
inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"),
|
96 |
+
outputs=gr.Textbox(label="Результат диагностики"),
|
97 |
+
title="🍅 Диагностика болезней томатов",
|
98 |
+
description="Загрузите изображение листа томата для определения заболевания"
|
99 |
+
)
|
100 |
|
101 |
+
# Запуск приложения
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
if __name__ == "__main__":
|
103 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
hf_app.py
CHANGED
@@ -7,6 +7,8 @@ from fastapi import FastAPI, File, UploadFile
|
|
7 |
from sklearn.preprocessing import StandardScaler
|
8 |
from sklearn.svm import SVC
|
9 |
import logging
|
|
|
|
|
10 |
|
11 |
# Настройка логгера
|
12 |
logger = logging.getLogger(__name__)
|
@@ -51,7 +53,8 @@ def load_model():
|
|
51 |
model_path = next((path for path in model_paths if os.path.exists(path)), None)
|
52 |
|
53 |
if model_path is None:
|
54 |
-
logger.
|
|
|
55 |
return None, None
|
56 |
|
57 |
logger.info(f"Загрузка модели из: {model_path}")
|
@@ -122,5 +125,17 @@ def read_root():
|
|
122 |
return {"status": "Tomato Disease Classifier is running"}
|
123 |
|
124 |
# Запуск Gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if __name__ == "__main__":
|
126 |
iface.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
7 |
from sklearn.preprocessing import StandardScaler
|
8 |
from sklearn.svm import SVC
|
9 |
import logging
|
10 |
+
from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model, cleanup
|
11 |
+
import threading
|
12 |
|
13 |
# Настройка логгера
|
14 |
logger = logging.getLogger(__name__)
|
|
|
53 |
model_path = next((path for path in model_paths if os.path.exists(path)), None)
|
54 |
|
55 |
if model_path is None:
|
56 |
+
logger.info("Модель не найдена, запускаем обучение в фоновом режиме")
|
57 |
+
threading.Thread(target=train_model).start()
|
58 |
return None, None
|
59 |
|
60 |
logger.info(f"Загрузка модели из: {model_path}")
|
|
|
125 |
return {"status": "Tomato Disease Classifier is running"}
|
126 |
|
127 |
# Запуск Gradio
|
128 |
+
def train_model():
|
129 |
+
try:
|
130 |
+
logger.info("Начинаем обучение модели...")
|
131 |
+
download_and_prepare_dataset()
|
132 |
+
X, y = load_images_and_labels()
|
133 |
+
train_and_evaluate_model(X, y)
|
134 |
+
logger.info("Модель успешно обучена!")
|
135 |
+
except Exception as e:
|
136 |
+
logger.error(f"Ошибка при обучении модели: {e}")
|
137 |
+
finally:
|
138 |
+
cleanup()
|
139 |
+
|
140 |
if __name__ == "__main__":
|
141 |
iface.launch(server_name="0.0.0.0", server_port=7860)
|
requirements.txt
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
# Gradio и его зависимости
|
2 |
gradio==4.44.1
|
3 |
-
fastapi
|
4 |
-
uvicorn[standard]
|
5 |
|
6 |
# Основные библиотеки для ML
|
7 |
-
torch
|
8 |
-
numpy
|
9 |
-
opencv-python-headless
|
10 |
-
scikit-learn
|
11 |
-
matplotlib
|
12 |
-
seaborn
|
13 |
|
14 |
# Вспомогательные библиотеки
|
15 |
-
python-multipart
|
16 |
requests==2.31.0
|
17 |
scipy==1.12.0
|
18 |
tqdm==4.66.2
|
|
|
1 |
# Gradio и его зависимости
|
2 |
gradio==4.44.1
|
3 |
+
fastapi
|
4 |
+
uvicorn[standard]
|
5 |
|
6 |
# Основные библиотеки для ML
|
7 |
+
torch
|
8 |
+
numpy
|
9 |
+
opencv-python-headless
|
10 |
+
scikit-learn
|
11 |
+
matplotlib
|
12 |
+
seaborn
|
13 |
|
14 |
# Вспомогательные библиотеки
|
15 |
+
python-multipart
|
16 |
requests==2.31.0
|
17 |
scipy==1.12.0
|
18 |
tqdm==4.66.2
|
train.py
CHANGED
@@ -152,7 +152,8 @@ def train_and_evaluate_model(X, y):
|
|
152 |
model_save_path = os.path.join(os.getcwd(), 'tomato_disease_classifier.pth')
|
153 |
torch.save({
|
154 |
'classifier': classifier,
|
155 |
-
'
|
|
|
156 |
'classes': TOMATO_CLASSES
|
157 |
}, model_save_path)
|
158 |
|
|
|
152 |
model_save_path = os.path.join(os.getcwd(), 'tomato_disease_classifier.pth')
|
153 |
torch.save({
|
154 |
'classifier': classifier,
|
155 |
+
'mean': scaler.mean_,
|
156 |
+
'std': scaler.scale_,
|
157 |
'classes': TOMATO_CLASSES
|
158 |
}, model_save_path)
|
159 |
|