import os import numpy as np import torch import cv2 import logging import time from tqdm import tqdm import requests import zipfile import io from sklearn.model_selection import train_test_split, cross_val_score from sklearn.svm import SVC from sklearn.preprocessing import StandardScaler from sklearn.pipeline import make_pipeline from sklearn.metrics import classification_report, confusion_matrix import shutil from huggingface_hub import HfApi, HfFolder # Настройка логирования logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', handlers=[ logging.FileHandler('/tmp/train_model.log', mode='w'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) def download_dataset(): """Распаковка локального архива с датасетом""" logger.info("Начало подготовки датасета") # Путь к локальному архиву dataset_path = 'tomato_dataset.zip' tmp_dataset_path = '/tmp/tomato_dataset.zip' extract_path = '/tmp/tomato' try: # Проверяем существование архива if not os.path.exists(dataset_path): logger.error(f"Архив не найден: {dataset_path}") raise FileNotFoundError(f"Архив не найден: {dataset_path}") # Создаем временную директорию для распаковки os.makedirs(extract_path, exist_ok=True) # Копируем архив во временную директорию shutil.copy(dataset_path, tmp_dataset_path) # Распаковываем архив with zipfile.ZipFile(tmp_dataset_path, 'r') as zip_ref: zip_ref.extractall(extract_path) logger.info(f"Архив распакован в {extract_path}") return tmp_dataset_path except Exception as e: logger.error(f"Ошибка при подготовке датасета: {e}") raise def load_images_from_folder(folder, max_images=None): """Загрузка изображений из папки""" images = [] labels = [] logger.info(f"Загрузка изображений из папки: {folder}") file_list = os.listdir(folder) if max_images: file_list = file_list[:max_images] for filename in tqdm(file_list, desc=f"Обработка {os.path.basename(folder)}"): img_path = os.path.join(folder, filename) if os.path.isfile(img_path): try: img = cv2.imread(img_path) if img is not None: # Resize и flatten изображения img_resized = cv2.resize(img, (64, 64)) images.append(img_resized.flatten()) labels.append(os.path.basename(folder)) except Exception as e: logger.warning(f"Ошибка при обработке {img_path}: {e}") logger.info(f"Загружено изображений: {len(images)}") return images, labels def prepare_dataset(dataset_archive_path): """Подготовка датасета из архива""" logger.info(f"Подготовка датасета из архива: {dataset_archive_path}") X = [] y = [] tomato_diseases = [ 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy' ] # Путь к распакованным данным extract_path = '/tmp/tomato' # Диагностика структуры датасета debug_dataset_structure(extract_path) # Сканируем распакованные файлы logger.info(f"Сканирование директории: {extract_path}") for disease in tqdm(tomato_diseases, desc="Обработка категорий"): folder_path = os.path.join(extract_path, disease) # Проверяем наличие директории if not os.path.exists(folder_path): logger.warning(f"Директория не найдена: {folder_path}") continue files = os.listdir(folder_path) logger.info(f"Категория {disease}: найдено файлов {len(files)}") for filename in tqdm(files, desc=f"Обработка {disease}"): if not filename.lower().endswith(('.jpg', '.jpeg', '.png')): continue img_path = os.path.join(folder_path, filename) try: img = cv2.imread(img_path) if img is not None: # Resize и flatten изображения img_resized = cv2.resize(img, (64, 64)) X.append(img_resized.flatten()) y.append(disease) else: logger.warning(f"Не удалось прочитать изображение: {img_path}") except Exception as e: logger.error(f"Ошибка при обработке {img_path}: {e}") logger.info(f"Всего изображений в датасете: {len(X)}") if len(X) == 0: logger.critical("КРИТИЧЕСКАЯ ОШИБКА: Не загружено ни одного изображения!") raise ValueError("Пустой датасет. Проверьте источник данных.") return np.array(X), np.array(y) def debug_dataset_info(dataset_path): """Расширенная диагностика структуры датасета""" logger.critical("НАЧАЛО ПОЛНОЙ ДИАГНОСТИКИ ДАТАСЕТА") # Проверка существования корневой директории if not os.path.exists(dataset_path): logger.critical(f"ОШИБКА: Корневая директория не существует: {dataset_path}") return logger.info(f"Корневая директория: {dataset_path}") logger.info(f"Содержимое корневой директории: {os.listdir(dataset_path)}") # Пути для проверки paths_to_check = [ dataset_path, os.path.join(dataset_path, 'raw'), os.path.join(dataset_path, 'raw', 'color') ] for check_path in paths_to_check: if os.path.exists(check_path): logger.info(f"Директория существует: {check_path}") logger.info(f"Содержимое {check_path}: {os.listdir(check_path)}") else: logger.critical(f"ОШИБКА: Директория не существует: {check_path}") # Проверка изображений color_path = os.path.join(dataset_path, 'raw', 'color') if os.path.exists(color_path): for disease in os.listdir(color_path): disease_path = os.path.join(color_path, disease) images = [f for f in os.listdir(disease_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))] logger.info(f"Класс {disease}: найдено изображений {len(images)}") logger.critical("ЗАВЕРШЕНИЕ ДИАГНОСТИКИ ДАТАСЕТА") def debug_dataset_structure(extract_path): """Отладочная функция для вывода структуры распакованного датасета""" logger.critical("НАЧАЛО ДИАГНОСТИКИ СТРУКТУРЫ ДАТАСЕТА") if not os.path.exists(extract_path): logger.critical(f"ОШИБКА: Директория не существует: {extract_path}") return logger.info(f"Корневая директория: {extract_path}") logger.info(f"Содержимое корневой директории: {os.listdir(extract_path)}") for root, dirs, files in os.walk(extract_path): level = root.replace(extract_path, '').count(os.sep) indent = ' ' * 4 * level logger.info(f"{indent}{os.path.basename(root)}/") subindent = ' ' * 4 * (level + 1) for file in files: logger.info(f"{subindent}{file}") logger.critical("ЗАВЕРШЕНИЕ ДИАГНОСТИКИ СТРУКТУРЫ ДАТАСЕТА") def train_model(X, y): """Обучение модели с кросс-валидацией""" logger.info("Начало подготовки и обучения модели") # Разделение на train и test X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) logger.info(f"Размер тренировочной выборки: {len(X_train)}") logger.info(f"Размер тестовой выборки: {len(X_test)}") # Создание pipeline с масштабированием и SVM model = make_pipeline( StandardScaler(), SVC(kernel='rbf', probability=True) ) # Кросс-валидация cv_scores = cross_val_score(model, X_train, y_train, cv=5) logger.info(f"Кросс-валидация: {cv_scores}") logger.info(f"Средняя точность кросс-валидации: {cv_scores.mean() * 100:.2f}%") # Обучение модели logger.info("Начало обучения модели...") start_time = time.time() model.fit(X_train, y_train) # Оценка точности accuracy = model.score(X_test, y_test) logger.info(f"Точность модели: {accuracy * 100:.2f}%") # Подробный отчет о классификации y_pred = model.predict(X_test) logger.info("\nОтчет о классификации:") logger.info(classification_report(y_test, y_pred)) # Матрица ошибок logger.info("\nМатрица ошибок:") logger.info(str(confusion_matrix(y_test, y_pred))) end_time = time.time() logger.info(f"Время обучения: {end_time - start_time:.2f} секунд") return model def save_model(model, path): """Сохранение модели с расширенной обработкой ошибок""" try: # Проверка прав доступа к директории dir_path = os.path.dirname(path) if not os.access(dir_path, os.W_OK): logger.error(f"Нет прав записи в директорию: {dir_path}") raise PermissionError(f"Невозможно сохранить модель. Отсутствуют права записи в {dir_path}") # Создание директории с корректными правами os.makedirs(dir_path, exist_ok=True) os.chmod(dir_path, 0o755) # Установка прав чтения/записи logger.info(f"Подготовка к сохранению модели в {path}") # Проверка корректности модели перед сохранением if not hasattr(model, 'named_steps'): logger.error("Некорректная структура модели. Невозможно извлечь параметры.") raise ValueError("Модель не содержит необходимых компонентов") # Сохранение с дополнительными проверками torch.save({ 'classifier': model, 'mean': model.named_steps['standardscaler'].mean_, 'std': model.named_steps['standardscaler'].scale_ }, path) # Проверка успешности сохранения if not os.path.exists(path): logger.error(f"Модель не сохранена по пути: {path}") raise IOError("Не удалось сохранить модель") # Установка прав доступа к файлу os.chmod(path, 0o644) logger.info("Модель успешно сохранена") # Дополнительная проверка размера файла file_size = os.path.getsize(path) logger.info(f"Размер сохраненной модели: {file_size} байт") if file_size == 0: logger.warning("Размер файла модели равен 0 байт") except Exception as e: logger.error(f"Критическая ошибка при сохранении модели: {e}") # Расширенное логирование для диагностики import traceback logger.error(traceback.format_exc()) raise def upload_to_huggingface(model_path, model_name='tomato-disease-classifier'): """Загрузка обученной модели на Hugging Face""" logger.info(f"Начало загрузки модели {model_name} на Hugging Face") try: # Проверка токена token = HfFolder.get_token() if not token: logger.critical(""" ОШИБКА: Токен Hugging Face не найден! Для загрузки модели выполните следующие шаги: 1. Зарегистрируйтесь на https://huggingface.co 2. Создайте токен доступа в настройках профиля 3. Выполните в терминале: huggingface-cli login 4. Введите свой токен при запросе """) return False # Инициализация API api = HfApi() # Загрузка модели api.upload_file( path_or_fileobj=model_path, path_in_repo='tomato_disease_classifier.pth', repo_id=f'list91/{model_name}', repo_type='model', token=token ) logger.info(f"Модель успешно загружена в репозиторий list91/{model_name}") return True except Exception as e: logger.error(f"Ошибка при загрузке модели: {e}") return False def main(): """Основная функция обучения и загрузки""" logger.info("Начало процесса обучения") try: # Загрузка датасета dataset_path = download_dataset() logger.info(f"Датасет загружен в {dataset_path}") # Подготовка данных X, y = prepare_dataset(dataset_path) logger.info(f"Данные подготовлены: {len(X)} изображений") # Обучение модели model = train_model(X, y) # Сохранение модели model_path = '/tmp/tomato_disease_classifier.pth' save_model(model, model_path) logger.info(f"Модель сохранена в {model_path}") # Загрузка на Hugging Face upload_to_huggingface(model_path) logger.info("Процесс обучения и загрузки завершен успешно!") except Exception as e: logger.error(f"Критическая ошибка в процессе обучения: {e}") import traceback logger.error(traceback.format_exc()) raise if __name__ == '__main__': main()