import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX


def plot_graph(data, algorithm):
    df = pd.read_csv(data)

    columns = df.columns.values

    if len(columns) < 2:
        raise gr.Error('Неверная структура данных. Ожидается второй столбец value.')

    df['Date'] = pd.to_datetime(df[columns[0]])
    df = df.groupby(pd.Grouper(key='Date', freq='ME'))[columns[1]].sum().reset_index()
    df.set_index('Date', inplace=True)

    if algorithm == 'Exponential Smoothing':
        if len(df) < 24:
            raise gr.Error("Для Exponential Smoothing нужны данные за как минимум 24 месяца.")
        model = ExponentialSmoothing(df[columns[1]], seasonal_periods=12, trend="add", seasonal="add")
        model_fit = model.fit()
    elif algorithm == 'ARIMA':
        model = ARIMA(df[columns[1]], order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
        model_fit = model.fit()
    elif algorithm == 'SARIMA':
        model = SARIMAX(df[columns[1]], order=(1, 1, 1), seasonal_order=(1, 1, 1, 12))
        model_fit = model.fit(disp=False)

    last_date = df.index[-1]
    forecast_dates = pd.date_range(start=last_date, periods=101, freq='MS')[1:]
    prediction = model_fit.forecast(steps=100)

    plt.figure(figsize=(10, 5))
    plt.plot(df[columns[1]], label=columns[1])
    plt.plot(forecast_dates, prediction, label="Прогноз")
    plt.title(f'Прогноз  {columns[1]} на следующие 100 месяцев')
    plt.legend()

    return plt


if __name__ == "__main__":
    iface = gr.Interface(fn=plot_graph,
                         inputs=[gr.File(label="\'Date - Value\'. Example: 2010-01-01,100"),
                                 gr.Radio(["Exponential Smoothing", "ARIMA", "SARIMA"],
                                          label='Выберите алгоритм')],
                         outputs="plot"
                         )

    iface.launch()