import gradio as gr
from gradio_calendar import Calendar
import datetime
import requests
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import calendar

# Brevo endpoint listing transactional email events.
url = "https://api.brevo.com/v3/smtp/statistics/events"

# Headers sent with every request. The API key is read from the BREVO_API_KEY
# environment variable and is None when that variable is not set.
headers = {
    "accept": "application/json",
    "api-key": os.environ.get("BREVO_API_KEY"),
}

# Column names of interest for the joined users/events data
# (not referenced by the code visible in this file).
events_cols = [
    "collectivite", "email", "date", "subject",
    "event", "g_men", "g_conso", "ressource",
]

# Page through the event log until every matching event has been fetched
def get_events(limit=5000, offset=0, **kwargs):
    """Fetch all transactional email events matching the given query.

    Pages of ``limit`` events are requested from ``url`` starting at
    ``offset`` until a page comes back with fewer than ``limit`` events.

    Args:
        limit: Page size, sent to the API as the ``limit`` parameter.
        offset: Index of the first event to fetch.
        **kwargs: Extra query parameters forwarded unchanged (the caller in
            this file passes ``startDate``, ``endDate`` and ``sort``).

    Returns:
        A list holding the events of every page fetched. When a request
        fails, the error is printed and the events collected so far are
        returned (best effort, no exception is propagated).
    """
    all_events = []
    while True:
        params = {'limit': limit, 'offset': offset}
        params.update(kwargs)
        try:
            # Get all your transactional email activity (unaggregated events)
            api_response = requests.get(url=url, params=params, headers=headers, timeout=60)
            # Surface HTTP errors (bad key, rate limit...) with a readable
            # message instead of a bare KeyError on the missing 'events' key.
            api_response.raise_for_status()
            events = api_response.json().get('events') or []
        except (requests.RequestException, ValueError, AttributeError) as e:
            print("Exception when calling TransactionalEmailsApi->get_email_event_report: %s\n" % e)
            break

        print(f"- Found {len(events)} events (limit={limit} - offset={offset}).")
        all_events.extend(events)

        # A short page means the log is exhausted.
        if len(events) < limit:
            break
        offset += limit

    return all_events

def get_all_events(year, months):
    """Collect the email events of several months into one DataFrame.

    Each month is queried separately through ``get_events``. Months whose
    first day is today or later are skipped, and the end of the queried range
    is capped at today's date.

    Args:
        year: Calendar year of the months to fetch.
        months: Iterable of month numbers (1-12).

    Returns:
        A DataFrame with one row per event and a fresh 0..n-1 index, or an
        empty DataFrame when no event was found.
    """
    today = datetime.date.today()
    frames = []
    for month in months:
        first_day = datetime.date(year, month, 1)
        last_day = datetime.date(year, month, calendar.monthrange(year, month)[1])
        if first_day >= today:
            continue

        print(f'{year}.{month:02d}:')
        events = get_events(
            startDate=first_day.isoformat(),
            endDate=min(last_day, today).isoformat(),
            sort='asc',
        )
        if events:
            frames.append(pd.DataFrame(events))
        print(f'=> Found {len(events)} month events.\n')

    # Concatenate once: avoids re-copying the accumulated frame every month
    # and gives the result unique index labels.
    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    print(f'=> Found {len(df)} total events.')
    return df

# Attach user stratification columns to the events
def get_stratification_datas(events_df, users_df):
    """Join users with their email events.

    Users are left-joined to events (``users_df.mail`` == ``events_df.email``)
    and only the rows carrying a ``messageId`` are kept, i.e. users without
    any matching event are dropped.

    Returns:
        The joined DataFrame, user columns first, then event columns.
    """
    merged = pd.merge(
        users_df,
        events_df,
        how='left',
        left_on='mail',
        right_on='email',
    )
    has_event = merged['messageId'].notna()
    return merged[has_event]

# Apply a chain of "column value is in list" filters
def filter_datas(df, filters):
    """Keep the rows of ``df`` that satisfy every filter.

    Args:
        df: DataFrame to filter.
        filters: Sequence of ``(column, allowed_values)`` pairs; a falsy value
            (empty list, None) returns ``df`` itself untouched.

    Returns:
        The rows whose value in each listed column belongs to the matching
        ``allowed_values``.
    """
    if not filters:
        return df
    remaining = df
    for column, allowed in filters:
        remaining = remaining[remaining[column].isin(allowed)]
    return remaining

# Filter on date
def filter_dates(df, start_date, end_date):
    """Keep the rows of ``df`` dated within ``[start_date, end_date)``.

    Args:
        df: DataFrame with a datetime ``date`` column.
        start_date: Inclusive lower bound; falsy means no lower bound.
        end_date: Exclusive upper bound; falsy means no upper bound.

    Returns:
        The selected rows without the ``date`` column, de-duplicated and
        re-indexed from 0. Rows with a missing date are always dropped.
    """
    dates = df['date']
    # Missing bounds leave that side open. The bounds are derived from the
    # frame being filtered (not from a module-level frame), and an absent end
    # date no longer excludes the most recent events.
    mask = dates.notna()
    if start_date:
        mask = mask & (dates >= np.datetime64(start_date))
    if end_date:
        mask = mask & (dates < np.datetime64(end_date))
    return df[mask].drop(columns='date').drop_duplicates().reset_index(drop=True)

# Display names of the stratification columns (only used by a commented-out
# rename in this file).
strate_dict = dict(
    g_men='Ménage',
    g_conso='Consommation',
)

# Event type -> label shown in the dashboard. Events whose type is not a key
# of this mapping are filtered out when the dataset is built.
events_dict = dict(
    delivered='Envoyé',
    opened='Ouvert',
    loadedByProxy='Ouvert (Apple)',
    clicks='Cliqué',
    unsubscribed='Désinscrit',
)

# Load users
filename = "dataset.csv"
users_df = pd.read_csv(filename)
print(f'- Loaded users datas: {users_df.shape}.')

# Load consumptions
filename = "consos_mois.csv"
consos_df = pd.read_csv(filename)
# consos_df['periode'] = consos_df['periode'].astype(str).replace('2023.1', '2023.10')
print(f'- Loaded consos datas: {consos_df.shape}.')

# Load log events
print('Load events dataframe...')
events_df = get_all_events(year=2024,
                           months=list(range(5, 11)))
print(events_df.columns)

# Remove test users.
# The first two addresses used to be fused into a single string literal, so
# neither of them was ever matched.
# NOTE(review): 'anne@thegoodtrack' has no top-level domain - confirm the
# exact address.
test_users = [
    'huynhdoo@gmail.com',
    'anne@thegoodtrack',
    'tarteret.astrid@gmail.com',
    'ma.morel.pro@gmail.com',
    'laura.stabile@grandchambery.fr',
    'cyrille.girel@grandchambery.fr',
]
# drop=True: do not keep the old index as an extra 'index' column.
events_df = events_df[~events_df['email'].isin(test_users)].reset_index(drop=True)

# Remove impact mails
impact_template_ids = [39, 40, 42, 43]
events_df = events_df[~events_df['templateId'].isin(impact_template_ids)].reset_index(drop=True)

# Cast datetime column (timezone information is dropped)
# NOTE(review): if the API returns mixed UTC offsets (e.g. across a DST
# change) this parse may not yield a datetime column - confirm.
events_df['date'] = pd.to_datetime(events_df['date']).dt.tz_localize(None)
print('Build dataset...')
dataset = get_stratification_datas(events_df, users_df)

# Event col: keep only the known event types and replace them by their label.
# The explicit copy makes the assignment below act on an independent frame
# rather than on a slice of the joined data.
dataset = dataset[dataset['event'].isin(events_dict.keys())].copy()
dataset['event'] = dataset['event'].map(events_dict)

# Strates col
# dataset.rename(columns=strate_dict, inplace=True)
print(f'Dataset ready: {dataset.shape}.')

def mails(community, subjects, start_date, end_date, gmens, gconsos, ressources):
    """Summarise email events for the selected period and filters.

    The module-level ``dataset`` is restricted to the date range, then to the
    rows matching every non-empty selection. Events are counted once per
    (email, subject, event) combination.

    Args:
        community: Allowed 'collectivite' values, or a falsy value for all.
        subjects: Allowed 'subject' values, or a falsy value for all.
        start_date: Lower date bound passed to ``filter_dates``.
        end_date: Upper date bound passed to ``filter_dates``.
        gmens: Allowed 'g_men' values, or a falsy value for all.
        gconsos: Allowed 'g_conso' values, or a falsy value for all.
        ressources: Allowed 'ressource' values, or a falsy value for all.

    Returns:
        A dict keyed by ``"<event label> (<count>)"``. Values are the count
        divided by the number of 'Envoyé' events, or the raw count when there
        is no 'Envoyé' event.
    """
    # Only selections that are actually set become filters.
    selections = (
        ('collectivite', community),
        ('subject', subjects),
        ('g_men', gmens),
        ('g_conso', gconsos),
        ('ressource', ressources),
    )
    filters = [(column, allowed) for column, allowed in selections if allowed]

    in_period = filter_dates(dataset, start_date, end_date)
    selected = filter_datas(in_period, filters)
    unique_events = selected[['email', 'subject', 'event']].drop_duplicates()
    counts = unique_events['event'].value_counts().to_dict()

    # Sent mails are the reference for the rates.
    sent_total = int((unique_events['event'] == 'Envoyé').sum())
    if sent_total > 0:
        return {f'{label} ({count})': count / sent_total for label, count in counts.items()}
    return {f'{label} ({count})': count for label, count in counts.items()}

def _grouped_bar_figure(datas, stat, scale, ndigits, height, title, ylabel):
    """Build one test/control grouped bar chart of consumption per period.

    Args:
        datas: DataFrame with 'periode', 'group' and 'consommation' columns.
        stat: Aggregation applied per period and group ('sum', 'median' or
            'mean').
        scale: Divisor applied to every aggregated value before display.
        ndigits: Passed to ``round``; None rounds to an integer.
        height: Figure height in inches.
        title: Axes title.
        ylabel: Y-axis label.

    Returns:
        The matplotlib figure (already released from pyplot's registry).
    """
    width = 0.4  # the width of the bars
    group_names = ('test', 'control')

    # Sorted periods are the single reference for both the tick labels and
    # the bar values, so a label can never be attached to another period's bar.
    period_keys = datas['periode'].dropna().drop_duplicates().sort_values()
    labels = period_keys.astype(str).tolist()
    x = np.arange(len(labels))

    fig, ax = plt.subplots(layout='constrained', figsize=(max(len(x), 10), height))
    for position, name in enumerate(group_names):
        per_period = (
            datas.loc[datas['group'] == name]
            .groupby('periode')['consommation']
            .agg(stat)
            # Align on the shared periods; a period missing for this group
            # becomes NaN (no bar) instead of a length mismatch.
            .reindex(period_keys.to_numpy())
        )
        heights = [
            float('nan') if pd.isna(value) else round(value / scale, ndigits)
            for value in per_period.tolist()
        ]
        rects = ax.bar(x + width * position, heights, width, label=name)
        ax.bar_label(rects, padding=3)

    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Période')
    # Centre each tick under its pair of bars.
    ax.set_xticks(x + width * (len(group_names) - 1) / 2, labels)
    ax.legend(loc='upper right', ncols=2)

    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself stays usable by the caller.
    plt.close(fig)
    return fig


def consos(community, periods, group):
    """Plot cumulated, median and mean consumption per period and group.

    The module-level ``consos_df`` is restricted to the rows matching every
    non-empty selection, then charted for the 'test' and 'control' groups.

    Args:
        community: Allowed 'collectivite' values, or a falsy value for all.
        periods: Allowed 'periode' values, or a falsy value for all.
        group: Allowed 'group' values, or a falsy value for all.

    Returns:
        A ``(fig_sum, fig_med, fig_mean)`` tuple of matplotlib figures.
    """
    # Filter consos
    filters = []
    if community:
        filters.append(('collectivite', community))
    if periods:
        filters.append(('periode', periods))
    if group:
        filters.append(('group', group))
    datas = filter_datas(consos_df, filters)

    # Sums are divided by 1,000,000 and shown as whole numbers; medians and
    # means are divided by 1,000 and shown with one decimal.
    fig_sum = _grouped_bar_figure(
        datas, 'sum', 1000 * 1000, None, 10,
        'Evolution de la consommation cumulée par groupe test/contrôle',
        'Consommation (dam3 = 1000 m3)',
    )
    fig_med = _grouped_bar_figure(
        datas, 'median', 1000, 1, 5,
        'Evolution de la consommation médiane par groupe test/contrôle',
        'Consommation (m3)',
    )
    fig_mean = _grouped_bar_figure(
        datas, 'mean', 1000, 1, 5,
        'Evolution de la consommation moyenne par groupe test/contrôle',
        'Consommation (m3)',
    )
    return fig_sum, fig_med, fig_mean

def results(subjects, start_date, end_date, gmens, gconsos, ressources):
    """Compute the dashboard outputs for the current selection.

    Args:
        subjects: Selected email subjects.
        start_date: Start of the period used for the email statistics.
        end_date: End of the period used for the email statistics.
        gmens: Selected 'g_men' groups.
        gconsos: Selected 'g_conso' groups.
        ressources: Selected 'ressource' values.

    Returns:
        A ``(campaign, consos_med, consos_sum)`` tuple: the email statistics
        from ``mails`` followed by the median and cumulated consumption
        figures from ``consos``.
    """
    # No community filter is exposed in the interface.
    community = None
    campaign = mails(community, subjects, start_date, end_date, gmens, gconsos, ressources)

    # Consumption charts are not restricted by period or group: the period
    # bounds that used to be derived here were discarded before use, so that
    # computation (which could fail on a non-datetime value) was removed.
    periods = None
    group = None
    consos_sum, consos_med, _consos_mean = consos(community, periods, group)
    return campaign, consos_med, consos_sum #, consos_sum, consos_mean

def _multiselect_dropdown(column, label, info):
    """Build a multi-select dropdown offering the distinct values of a dataset column."""
    return gr.Dropdown(
        choices=list(dataset[column].unique()),
        multiselect=True,
        allow_custom_value=False,
        label=label,
        info=info,
    )


# Main dashboard: inputs are passed positionally to `results`, whose three
# return values feed the three outputs in order.
main_interface = gr.Interface(
    fn=results,
    inputs=[
        # Disabled: _multiselect_dropdown('collectivite', "Collectivité", "Choisir une ou plusieurs collectivités")
        _multiselect_dropdown('subject', "Mél envoyé", "Choisir un ou plusieurs emails envoyés"),
        # Date pickers default to the first and last event dates.
        Calendar(
            type="datetime",
            value=events_df['date'].min(),
            label="Date de début",
            info="Choisir une date de début",
        ),
        Calendar(
            type="datetime",
            value=events_df['date'].max(),
            label="Date de fin",
            info="Choisir une date de fin",
        ),
        # Disabled: _multiselect_dropdown('event', "Evénement", "Choisir un ou plusieurs événements")
        _multiselect_dropdown('g_men', "Type de ménage", "Choisir un ou plusieurs groupes"),
        _multiselect_dropdown('g_conso', "Type de consommation", "Choisir un ou plusieurs groupes"),
        _multiselect_dropdown('ressource', "Ressource", "Choisir une ou plusieurs ressources"),
    ],
    outputs=[
        gr.Label(label="Méls"),
        gr.Plot(label="Médian"),
        gr.Plot(label="Cumul"),
        # gr.Plot(label="Moyenne"),
    ],
    examples=[],
    cache_examples=False,
)

#gradio_app = gr.TabbedInterface([main_interface, consos_interface], tab_names=['Méls', 'Consommations'], title="Préservons l'eau - Tableau de bord")
# Application object exposed by this module: the single main interface.
gradio_app = main_interface

if __name__ == "__main__":
    # NOTE(review): the login pair is hard-coded in source while share=True
    # is requested - consider moving it to configuration.
    gradio_app.launch(
        auth=("alerte", "renforcée"),
        share=True,
    )