import pandas as pd
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import os
from datetime import datetime
from typing import Dict, Tuple, List, Optional, Any
import xlsxwriter

class LDAAnalyzer:
    """
    Класс для выполнения линейного дискриминантного анализа (LDA)
    с расширенной функциональностью и форматированным выводом результатов
    """
    
    def __init__(self, input_file: str, target_column: int):
        """
        Инициализация анализатора LDA
        
        Args:
            input_file (str): Путь к входному файлу Excel
            target_column (int): Номер столбца для классификации
        """
        self.input_file = input_file
        self.target_column = target_column
        self.data = None
        self.X = None
        self.y = None
        self.X_transformed = None
        self.lda = None
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        self.feature_names = None
        
        # Настройка логирования
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('lda_analysis.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)
        
        # Цветовая схема для визуализации
        self.colors = ['lightblue', 'green', 'purple', 'yellow', 
                      'red', 'orange', 'cyan', 'brown', 'pink']
        
        self.logger.info(f"Инициализация LDA анализатора с файлом: {input_file}")

    def validate_data(self) -> None:
        """Валидация входных данных"""
        if self.data is None:
            raise ValueError("Данные не загружены")
            
        # Проверка размерности
        if self.data.shape[0] < 30:
            raise ValueError("Недостаточно наблюдений (минимум 30)")
            
        # Проверка пропущенных значений
        if self.data.isnull().any().any():
            raise ValueError("Обнаружены пропущенные значения")
            
        # Проверка типов данных
        numeric_cols = self.data.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) < self.data.shape[1] - 1:  # -1 для целевой переменной
            raise ValueError("Обнаружены нечисловые признаки")

    def load_data(self) -> None:
        """Загрузка данных из Excel файла"""
        try:
            self.logger.info("Загрузка данных...")

            # Загрузка данных
            self.data = pd.read_excel(self.input_file)

            # Преобразование имен колонок
            self.data.columns = [str(col) for col in self.data.columns]

            # Попытка преобразовать все колонки (кроме целевой) в числовой формат
            for col in self.data.columns:
                if self.data.columns.get_loc(col) != self.target_column:
                    try:
                        self.data[col] = pd.to_numeric(self.data[col], errors='coerce')
                    except Exception as e:
                        self.logger.warning(f"Не удалось преобразовать колонку {col} в числовой формат: {str(e)}")

            self.validate_data()
            self.logger.info(f"Данные загружены. Размерность: {self.data.shape}")

        except Exception as e:
            self.logger.error(f"Ошибка при загрузке данных: {str(e)}")
            raise

    
    
    def prepare_data(self) -> None:
        """Подготовка данных для анализа"""
        try:
            self.logger.info("Подготовка данных...")

            # Разделение на признаки и целевую переменную
            X = self.data.drop(self.data.columns[self.target_column], axis=1)
            y = self.data.iloc[:, self.target_column]

            # Преобразование имен колонок в строки
            X.columns = X.columns.astype(str)

            # Кодирование меток классов
            self.y = self.label_encoder.fit_transform(y) + 1

            # Преобразование в числовой формат
            X = X.apply(pd.to_numeric, errors='coerce')

            # Проверка на пропущенные значения после преобразования
            if X.isnull().any().any():
                raise ValueError("После преобразования в числовой формат появились пропущенные значения")

            # Стандартизация признаков
            self.X = self.scaler.fit_transform(X)

            # Проверка количества классов и наблюдений в каждом классе
            class_counts = pd.Series(self.y).value_counts()
            if (class_counts < 5).any():
                self.logger.warning("Некоторые классы имеют менее 5 наблюдений")

            self.logger.info(f"Данные подготовлены. X: {self.X.shape}, y: {self.y.shape}")
            self.logger.info(f"Количество классов: {len(np.unique(self.y))}")

        except Exception as e:
            self.logger.error(f"Ошибка при подготовке данных: {str(e)}")
            raise
        
    def perform_lda(self) -> None:
        """Выполнение LDA анализа"""
        try:
            self.logger.info("Выполнение LDA анализа...")
            
            # Инициализация и обучение LDA
            self.lda = LinearDiscriminantAnalysis(solver='svd')
            self.X_transformed = self.lda.fit_transform(self.X, self.y)
            
            # Оценка качества модели
            accuracy = self.lda.score(self.X, self.y)
            self.logger.info(f"Общая точность модели: {accuracy:.3f}")
            
        except Exception as e:
            self.logger.error(f"Ошибка при выполнении LDA: {str(e)}")
            raise

    def create_confusion_matrix(self) -> Tuple[pd.DataFrame, List[List[str]], float]:
        """
        Создание матрицы ошибок и расчет процентов классификации

        Returns:
            tuple: (матрица ошибок, проценты, общая точность)
        """
        try:
            self.logger.info("Создание матрицы ошибок...")

            # Получение предсказаний
            y_pred = self.lda.predict(self.X)

            # Создание матрицы ошибок
            classes = sorted(np.unique(self.y))
            n_classes = len(classes)
            confusion_matrix = np.zeros((n_classes, n_classes))

            for i in range(len(self.y)):
                confusion_matrix[self.y[i]-1][y_pred[i]-1] += 1

            # Создание DataFrame для матрицы ошибок
            columns = [f"{i+1}.00" for i in range(n_classes)]
            index = [f"{i+1}.00" for i in range(n_classes)]

            df_confusion = pd.DataFrame(confusion_matrix, columns=columns, index=index)

            # Добавление столбца "Всего"
            df_confusion['Всего'] = df_confusion.sum(axis=1)

            # Расчет процентов
            percentages = np.zeros((n_classes, n_classes + 1))  # +1 для столбца "Всего"
            for i in range(n_classes):
                row_sum = confusion_matrix[i].sum()
                if row_sum > 0:
                    percentages[i, :-1] = (confusion_matrix[i] / row_sum) * 100
                    percentages[i, -1] = 100.0

            # Форматирование процентов
            percentage_rows = []
            for row in percentages:
                formatted_row = [f"{x:.1f}" for x in row]
                percentage_rows.append(formatted_row)

            # Расчет общей точности
            accuracy = (np.sum(np.diag(confusion_matrix)) / np.sum(confusion_matrix)) * 100

            self.logger.info(f"Процент правильной классификации: {accuracy:.1f}%")

            return df_confusion, percentage_rows, accuracy

        except Exception as e:
            self.logger.error(f"Ошибка при создании матрицы ошибок: {str(e)}")
            raise

    def get_coefficients(self) -> pd.DataFrame:
        """
        Получение коэффициентов дискриминантных функций
        
        Returns:
            pd.DataFrame: таблица коэффициентов
        """
        try:
            self.logger.info("Получение коэффициентов...")
            
            # Получение коэффициентов и размерностей
            n_features = self.X.shape[1]
            n_classes = len(np.unique(self.y))
            n_components = min(n_classes - 1, n_features)
            
            # Создание списка имен переменных
            var_names = [f"VAR{str(i+1).zfill(5)}" for i in range(n_features)]
            
            # Создание DataFrame с коэффициентами
            coef_data = []
            for i in range(n_components):
                row_data = {}
                for j, var_name in enumerate(var_names):
                    row_data[var_name] = self.lda.coef_[i][j]
                coef_data.append(row_data)
            
            df_coef = pd.DataFrame(coef_data, index=[f"Функция {i+1}" for i in range(n_components)])
            
            # Добавление константы (intercept)
            const_data = {}
            for j, var_name in enumerate(var_names):
                const_data[var_name] = self.lda.intercept_[j] if j < len(self.lda.intercept_) else 0.0
            
            const_df = pd.DataFrame([const_data], index=['Константа'])
            
            # Объединение коэффициентов и константы
            df_coef = pd.concat([df_coef, const_df])
            
            # Округление значений
            df_coef = df_coef.round(3)
            
            self.logger.info("Коэффициенты получены")
            return df_coef
            
        except Exception as e:
            self.logger.error(f"Ошибка при получении коэффициентов: {str(e)}")
            raise

    def create_visualization(self) -> plt.Figure:
        """
        Создание визуализации результатов
        
        Returns:
            plt.Figure: объект графика
        """
        try:
            self.logger.info("Создание визуализации...")
            
            fig = plt.figure(figsize=(12, 8))
            
            # Построение точек для каждого класса
            for class_num in np.unique(self.y):
                mask = self.y == class_num
                plt.scatter(
                    self.X_transformed[mask, 0],
                    self.X_transformed[mask, 1] if self.X_transformed.shape[1] > 1 
                    else np.zeros_like(self.X_transformed[mask, 0]),
                    c=[self.colors[(class_num-1) % len(self.colors)]],
                    label=f'Группа {class_num}',
                    alpha=0.7
                )
                
                # Добавление центроидов
                centroid = np.mean(self.X_transformed[mask, :2], axis=0)
                plt.scatter(
                    centroid[0],
                    centroid[1] if self.X_transformed.shape[1] > 1 else 0,
                    c='black',
                    marker='s',
                    s=100
                )
                plt.annotate(
                    f'{class_num}',
                    (centroid[0], centroid[1]),
                    xytext=(5, 5),
                    textcoords='offset points',
                    fontsize=10,
                    bbox=dict(facecolor='white', edgecolor='none', alpha=0.7)
                )
            
            plt.xlabel('Первая каноническая функция')
            plt.ylabel('Вторая каноническая функция')
            plt.title('Канонические дискриминантные функции')
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            
            self.logger.info("Визуализация создана")
            return fig
            
        except Exception as e:
            self.logger.error(f"Ошибка при создании визуализации: {str(e)}")
            raise

    def save_results(self, output_dir: str) -> None:
        """
        Сохранение всех результатов анализа

        Args:
            output_dir (str): директория для сохранения результатов
        """
        try:
            self.logger.info(f"Сохранение результатов в {output_dir}...")

            # Создание директории если её нет
            os.makedirs(output_dir, exist_ok=True)

            # Получение результатов
            confusion_matrix, percentages, accuracy = self.create_confusion_matrix()
            coefficients = self.get_coefficients()

            # Сохранение в Excel
            excel_path = os.path.join(output_dir, 'lda_results.xlsx')
            with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
                workbook = writer.book

                # Форматы для Excel
                header_format = workbook.add_format({
                    'bold': True,
                    'align': 'center',
                    'valign': 'vcenter',
                    'bg_color': '#D9D9D9',
                    'border': 1
                })

                cell_format = workbook.add_format({
                    'align': 'center',
                    'border': 1
                })

                number_format = workbook.add_format({
                    'align': 'center',
                    'border': 1,
                    'num_format': '0.000'
                })

                # 1. Матрица классификации
                worksheet1 = workbook.add_worksheet('Матрица классификации')

                # Записываем заголовки
                headers = ['Исходный', 'Количество'] + \
                         [f'{i+1}.00' for i in range(len(confusion_matrix.columns)-1)] + \
                         ['Всего']
                for col, header in enumerate(headers):
                    worksheet1.write(0, col, header, header_format)
                    worksheet1.set_column(col, col, 15)

                # Записываем данные
                for i, (index, row) in enumerate(confusion_matrix.iterrows()):
                    worksheet1.write(i+1, 0, index, cell_format)
                    worksheet1.write(i+1, 1, row['Всего'], cell_format)
                    for j, val in enumerate(row):
                        worksheet1.write(i+1, j+2, val, cell_format)

                # 2. Проценты классификации
                worksheet2 = workbook.add_worksheet('Проценты')

                # Заголовки
                for col, header in enumerate(headers):
                    worksheet2.write(0, col, header, header_format)
                    worksheet2.set_column(col, col, 15)

                # Данные процентов
                for i, row in enumerate(percentages):
                    worksheet2.write(i+1, 0, f"{i+1}.00", cell_format)
                    worksheet2.write(i+1, 1, confusion_matrix.iloc[i]['Всего'], cell_format)
                    for j, val in enumerate(row):
                        worksheet2.write(i+1, j+2, float(val.replace(',', '.')), number_format)

                # Примечание
                note_row = len(percentages) + 2
                worksheet2.write(
                    note_row, 0,
                    f'* Примечание: {accuracy:.1f}% исходных сгруппированных наблюдений '
                    f'классифицированы правильно.',
                    workbook.add_format({'bold': True})
                )

                # 3. Коэффициенты функций
                worksheet3 = workbook.add_worksheet('Коэффициенты')

                # Записываем заголовки коэффициентов
                worksheet3.write(0, 0, 'Переменная', header_format)
                for i, col in enumerate(coefficients.columns):
                    worksheet3.write(0, i+1, col, header_format)
                    worksheet3.set_column(i+1, i+1, 15)

                # Записываем данные коэффициентов
                for i, (index, row) in enumerate(coefficients.iterrows()):
                    worksheet3.write(i+1, 0, index, cell_format)
                    for j, val in enumerate(row):
                        worksheet3.write(i+1, j+1, val, number_format)

                # Добавляем примечание к коэффициентам
                worksheet3.write(
                    len(coefficients)+1, 0,
                    '*Нестандартизованные коэффициенты',
                    workbook.add_format({'bold': True, 'italic': True})
                )

            # Сохранение визуализации
            fig = self.create_visualization()
            fig.savefig(
                os.path.join(output_dir, 'lda_visualization.png'),
                bbox_inches='tight',
                dpi=300
            )
            plt.close(fig)

            self.logger.info("Результаты успешно сохранены")

        except Exception as e:
            self.logger.error(f"Ошибка при сохранении результатов: {str(e)}")
            raise