Cascade AI commited on
Commit
f63ca0d
·
1 Parent(s): 33747bb

Обновление модели и логики обучения

Browse files
Files changed (5) hide show
  1. README.md +21 -11
  2. app.py +38 -91
  3. hf_app.py +16 -1
  4. requirements.txt +9 -9
  5. train.py +2 -1
README.md CHANGED
@@ -9,23 +9,33 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # Классификатор болезней томатов
13
 
14
- Приложение для распознавания болезней томатов с использованием машинного обучения.
 
15
 
16
  ## Возможности
17
- - Классификация 10 различных состояний листьев томатов
18
- - Использование SVM-модели с RBF-ядром
19
- - Высокая точность распознавания
 
 
 
 
 
 
20
 
21
  ## Классы болезней
22
- 1. Бактериальное пятно
23
- 2. Ранняя пятнистость
24
- 3. Поздняя пятнистость
25
  4. Листовая плесень
26
- 5. Септориоз
27
  6. Паутинный клещ
28
  7. Целевая пятнистость
29
- 8. Скручивание листьев
30
  9. Мозаичный вирус
31
- 10. Здоровые листья
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # 🍅 Классификатор болезней томатов
13
 
14
+ ## Описание проекта
15
+ Интерактивное приложение для диагностики болезней томатов с использованием машинного обучения.
16
 
17
  ## Возможности
18
+ - Определение болезней по изображению листа томата
19
+ - Поддержка 10 различных классов заболеваний
20
+ - Автоматическое обучение модели при первом запуске
21
+
22
+ ## Технологии
23
+ - Gradio
24
+ - scikit-learn
25
+ - OpenCV
26
+ - PyTorch
27
 
28
  ## Классы болезней
29
+ 1. Бактериальная пятнистость
30
+ 2. Ранняя фитофтора
31
+ 3. Поздняя фитофтора
32
  4. Листовая плесень
33
+ 5. Пятнистость листьев Септориоз
34
  6. Паутинный клещ
35
  7. Целевая пятнистость
36
+ 8. Вирус желтой курчавости листьев
37
  9. Мозаичный вирус
38
+ 10. Здоровый лист
39
+
40
+ ## Использование
41
+ Загрузите изображение листа томата и получите диагноз с вероятностями.
app.py CHANGED
@@ -1,29 +1,14 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import torch
5
- import pkg_resources
6
- import subprocess
7
- import sys
8
- import threading
9
- import time
10
- import queue
11
-
12
- # Проверка и обновление версии Gradio
13
- required_gradio = '4.44.1'
14
- try:
15
- current_gradio = pkg_resources.get_distribution('gradio').version
16
- if current_gradio != required_gradio:
17
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', f'gradio=={required_gradio}'])
18
- print(f'Gradio обновлен до версии {required_gradio}')
19
- except:
20
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', f'gradio=={required_gradio}'])
21
- print(f'Установлен Gradio версии {required_gradio}')
22
-
23
  import gradio as gr
 
 
 
 
 
 
 
24
 
25
  # Список классов болезней
26
- TOMATO_CLASSES = [
27
  'Tomato___Bacterial_spot',
28
  'Tomato___Early_blight',
29
  'Tomato___Late_blight',
@@ -36,28 +21,31 @@ TOMATO_CLASSES = [
36
  'Tomato___healthy'
37
  ]
38
 
39
- # Состояние обучения
40
- MODEL_TRAINING_STATUS = queue.Queue()
41
- MODEL_TRAINING_STATUS.put("Модель не обучена")
42
-
43
  def load_model():
44
  """Загрузка обученной модели"""
45
  try:
46
  model_path = 'tomato_disease_classifier.pth'
47
 
48
- if not os.path.exists(model_path):
49
- print(f"Модель не найдена по пути: {os.path.abspath(model_path)}")
50
- return None, None
 
 
51
 
 
52
  model_data = torch.load(model_path)
53
- if 'classifier' not in model_data or 'scaler' not in model_data:
54
- print("Некорректная структура модели: отсутствуют необходимые компоненты")
55
- return None, None
56
 
57
- print("Модель успешно загружена")
58
- return model_data['classifier'], model_data['scaler']
 
 
 
 
 
 
 
59
  except Exception as e:
60
- print(f"Ошибка загрузки модели: {str(e)}")
61
  return None, None
62
 
63
  def preprocess_image(image):
@@ -74,12 +62,12 @@ def preprocess_image(image):
74
  def predict_disease(image):
75
  """Предсказание болезни томата"""
76
  if image is None:
77
- return "Пожалуйста, загрузите изображение листа томата"
78
 
79
  # Загрузка модели
80
- classifier, scaler = load_model()
81
- if classifier is None or scaler is None:
82
- return "Модель не найдена. Пожалуйста, убедитесь, что файл 'tomato_disease_classifier.pth' находится в корневой папке проекта."
83
 
84
  # Предобработка изображения
85
  processed_image = preprocess_image(image)
@@ -96,61 +84,20 @@ def predict_disease(image):
96
  # Формирование результата
97
  result = f"Обнаружено: {prediction[0]}\n\n"
98
  result += "Вероятности:\n"
99
- for disease, prob in zip(TOMATO_CLASSES, probabilities):
100
  result += f"{disease}: {prob*100:.2f}%\n"
101
 
102
  return result
103
 
104
  # Создание Gradio интерфейса
105
- demo = gr.Blocks(title="Диагностика болезней томатов")
 
 
 
 
 
 
106
 
107
- with demo:
108
- gr.Markdown("""
109
- # Диагностика болезней томатов
110
- За��рузите изображение листа томата для определения заболевания
111
- """)
112
-
113
- status_text = gr.Textbox(label="Статус модели", interactive=False)
114
-
115
- with gr.Row():
116
- with gr.Column():
117
- input_image = gr.Image(type="numpy", label="Загрузите изображение листа томата")
118
- predict_btn = gr.Button("Определить болезнь")
119
-
120
- with gr.Column():
121
- output_text = gr.Textbox(label="Результат диагностики")
122
-
123
- predict_btn.click(
124
- fn=predict_disease,
125
- inputs=input_image,
126
- outputs=output_text
127
- )
128
-
129
- # Периодическое обновление статуса
130
- demo.load(fn=lambda: MODEL_TRAINING_STATUS.queue[0] if not MODEL_TRAINING_STATUS.empty() else "Модель не обучена", inputs=None, outputs=status_text, every=5)
131
-
132
- def train_model_in_background():
133
- """Запуск обучения модели в фоновом режиме"""
134
- try:
135
- MODEL_TRAINING_STATUS.queue.clear()
136
- MODEL_TRAINING_STATUS.put("Начало обучения модели...")
137
- print("Начинаем обучение модели...")
138
- from train import main as train_main
139
- train_main()
140
- MODEL_TRAINING_STATUS.queue.clear()
141
- MODEL_TRAINING_STATUS.put("Модель успешно обучена!")
142
- print("Модель успешно обучена!")
143
- except Exception as e:
144
- MODEL_TRAINING_STATUS.queue.clear()
145
- MODEL_TRAINING_STATUS.put(f"Ошибка при обучении: {str(e)}")
146
- print(f"Ошибка при обучении модели: {str(e)}")
147
-
148
- # Запуск интерфейса
149
  if __name__ == "__main__":
150
- # Проверка наличия модели
151
- if not os.path.exists('tomato_disease_classifier.pth'):
152
- print("Модель не найдена, запускаем обучение в фоновом режиме")
153
- training_thread = threading.Thread(target=train_model_in_background)
154
- training_thread.start()
155
-
156
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.svm import SVC
7
+ import logging
8
+ from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model
9
 
10
  # Список классов болезней
11
+ DISEASE_CLASSES = [
12
  'Tomato___Bacterial_spot',
13
  'Tomato___Early_blight',
14
  'Tomato___Late_blight',
 
21
  'Tomato___healthy'
22
  ]
23
 
 
 
 
 
24
  def load_model():
25
  """Загрузка обученной модели"""
26
  try:
27
  model_path = 'tomato_disease_classifier.pth'
28
 
29
+ # Если модель не существует, обучаем
30
+ if not torch.os.path.exists(model_path):
31
+ download_and_prepare_dataset()
32
+ X, y = load_images_and_labels()
33
+ train_and_evaluate_model(X, y)
34
 
35
+ # Загрузка модели
36
  model_data = torch.load(model_path)
 
 
 
37
 
38
+ # Создание scaler
39
+ scaler = StandardScaler()
40
+ scaler.mean_ = model_data['mean']
41
+ scaler.scale_ = model_data['std']
42
+
43
+ classifier = model_data['classifier']
44
+
45
+ return scaler, classifier
46
+
47
  except Exception as e:
48
+ print(f"Ошибка загрузки модели: {e}")
49
  return None, None
50
 
51
  def preprocess_image(image):
 
62
  def predict_disease(image):
63
  """Предсказание болезни томата"""
64
  if image is None:
65
+ return "Пожалуйста, загрузите изображение"
66
 
67
  # Загрузка модели
68
+ scaler, classifier = load_model()
69
+ if scaler is None or classifier is None:
70
+ return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель."
71
 
72
  # Предобработка изображения
73
  processed_image = preprocess_image(image)
 
84
  # Формирование результата
85
  result = f"Обнаружено: {prediction[0]}\n\n"
86
  result += "Вероятности:\n"
87
+ for disease, prob in zip(DISEASE_CLASSES, probabilities):
88
  result += f"{disease}: {prob*100:.2f}%\n"
89
 
90
  return result
91
 
92
  # Создание Gradio интерфейса
93
+ iface = gr.Interface(
94
+ fn=predict_disease,
95
+ inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"),
96
+ outputs=gr.Textbox(label="Результат диагностики"),
97
+ title="🍅 Диагностика болезней томатов",
98
+ description="Загрузите изображение листа томата для определения заболевания"
99
+ )
100
 
101
+ # Запуск приложения
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  if __name__ == "__main__":
103
+ iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
hf_app.py CHANGED
@@ -7,6 +7,8 @@ from fastapi import FastAPI, File, UploadFile
7
  from sklearn.preprocessing import StandardScaler
8
  from sklearn.svm import SVC
9
  import logging
 
 
10
 
11
  # Настройка логгера
12
  logger = logging.getLogger(__name__)
@@ -51,7 +53,8 @@ def load_model():
51
  model_path = next((path for path in model_paths if os.path.exists(path)), None)
52
 
53
  if model_path is None:
54
- logger.error("Модель не найдена ни в одном из путей")
 
55
  return None, None
56
 
57
  logger.info(f"Загрузка модели из: {model_path}")
@@ -122,5 +125,17 @@ def read_root():
122
  return {"status": "Tomato Disease Classifier is running"}
123
 
124
  # Запуск Gradio
 
 
 
 
 
 
 
 
 
 
 
 
125
  if __name__ == "__main__":
126
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
7
  from sklearn.preprocessing import StandardScaler
8
  from sklearn.svm import SVC
9
  import logging
10
+ from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model, cleanup
11
+ import threading
12
 
13
  # Настройка логгера
14
  logger = logging.getLogger(__name__)
 
53
  model_path = next((path for path in model_paths if os.path.exists(path)), None)
54
 
55
  if model_path is None:
56
+ logger.info("Модель не найдена, запускаем обучение в фоновом режиме")
57
+ threading.Thread(target=train_model).start()
58
  return None, None
59
 
60
  logger.info(f"Загрузка модели из: {model_path}")
 
125
  return {"status": "Tomato Disease Classifier is running"}
126
 
127
  # Запуск Gradio
128
+ def train_model():
129
+ try:
130
+ logger.info("Начинаем обучение модели...")
131
+ download_and_prepare_dataset()
132
+ X, y = load_images_and_labels()
133
+ train_and_evaluate_model(X, y)
134
+ logger.info("Модель успешно обучена!")
135
+ except Exception as e:
136
+ logger.error(f"Ошибка при обучении модели: {e}")
137
+ finally:
138
+ cleanup()
139
+
140
  if __name__ == "__main__":
141
  iface.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -1,18 +1,18 @@
1
  # Gradio и его зависимости
2
  gradio==4.44.1
3
- fastapi==0.104.1
4
- uvicorn[standard]==0.27.0
5
 
6
  # Основные библиотеки для ML
7
- torch==2.1.2
8
- numpy==1.24.3
9
- opencv-python-headless==4.8.1.78
10
- scikit-learn==1.3.2
11
- matplotlib==3.7.2
12
- seaborn==0.12.2
13
 
14
  # Вспомогательные библиотеки
15
- python-multipart==0.0.9
16
  requests==2.31.0
17
  scipy==1.12.0
18
  tqdm==4.66.2
 
1
  # Gradio и его зависимости
2
  gradio==4.44.1
3
+ fastapi
4
+ uvicorn[standard]
5
 
6
  # Основные библиотеки для ML
7
+ torch
8
+ numpy
9
+ opencv-python-headless
10
+ scikit-learn
11
+ matplotlib
12
+ seaborn
13
 
14
  # Вспомогательные библиотеки
15
+ python-multipart
16
  requests==2.31.0
17
  scipy==1.12.0
18
  tqdm==4.66.2
train.py CHANGED
@@ -152,7 +152,8 @@ def train_and_evaluate_model(X, y):
152
  model_save_path = os.path.join(os.getcwd(), 'tomato_disease_classifier.pth')
153
  torch.save({
154
  'classifier': classifier,
155
- 'scaler': scaler,
 
156
  'classes': TOMATO_CLASSES
157
  }, model_save_path)
158
 
 
152
  model_save_path = os.path.join(os.getcwd(), 'tomato_disease_classifier.pth')
153
  torch.save({
154
  'classifier': classifier,
155
+ 'mean': scaler.mean_,
156
+ 'std': scaler.scale_,
157
  'classes': TOMATO_CLASSES
158
  }, model_save_path)
159