import os import cv2 import numpy as np import torch import gradio as gr from fastapi import FastAPI, File, UploadFile from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC import logging from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model, cleanup import threading # Настройка логгера logger = logging.getLogger(__name__) # Список классов болезней DISEASE_CLASSES = [ 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy' ] def preprocess_image(image): """Подготовка изображения для предсказания""" if image is None: return None # Resize и flatten img_resized = cv2.resize(image, (64, 64)) img_flattened = img_resized.flatten() return img_flattened def load_model(): """Загрузка обученной модели с поддержкой множественных путей""" try: # Список возможных путей для модели model_paths = [ '/home/user/app/tomato_disease_classifier.pth', # Основной путь '/tmp/data/state/SVC_comb_R.pth.pth', # Путь Hugging Face '/tmp/tomato_disease_classifier.pth', # Резервный путь 'tomato_disease_classifier.pth' # Локальный путь ] # Поиск первого существующего пути model_path = next((path for path in model_paths if os.path.exists(path)), None) if model_path is None: logger.info("Модель не найдена, запускаем обучение в фоновом режиме") threading.Thread(target=train_model).start() return None, None logger.info(f"Загрузка модели из: {model_path}") # Загрузка данных модели model_data = torch.load(model_path) # Создание pipeline с масштабированием scaler = StandardScaler() scaler.mean_ = model_data['mean'] scaler.scale_ = model_data['std'] classifier = model_data['classifier'] return scaler, classifier except Exception as e: logger.error(f"Ошибка загрузки модели: {e}") import traceback logger.error(traceback.format_exc()) return None, None def predict_disease(image): """Предсказание болезни томата""" if image is None: return "Пожалуйста, загрузите изображение" # Загрузка модели scaler, classifier = load_model() if scaler is None or classifier is None: return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель." # Предобработка изображения processed_image = preprocess_image(image) if processed_image is None: return "Не удалось обработать изображение" # Масштабирование processed_image = scaler.transform([processed_image]) # Предсказание prediction = classifier.predict(processed_image) probabilities = classifier.predict_proba(processed_image)[0] # Формирование результата result = f"Обнаружено: {prediction[0]}\n\n" result += "Вероятности:\n" for disease, prob in zip(DISEASE_CLASSES, probabilities): result += f"{disease}: {prob*100:.2f}%\n" return result # FastAPI приложение app = FastAPI() # Gradio интерфейс iface = gr.Interface( fn=predict_disease, inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"), outputs=gr.Textbox(label="Результат диагностики"), title="Диагностика болезней томатов", description="Загрузите изображение листа томата для определения заболевания" ) # Маршрут для Gradio @app.get("/") def read_root(): return {"status": "Tomato Disease Classifier is running"} # Запуск Gradio def train_model(): try: logger.info("Начинаем обучение модели...") download_and_prepare_dataset() X, y = load_images_and_labels() train_and_evaluate_model(X, y) logger.info("Модель успешно обучена!") except Exception as e: logger.error(f"Ошибка при обучении модели: {e}") finally: cleanup() if __name__ == "__main__": iface.launch(server_name="0.0.0.0", server_port=7860)