import networkx as nx
from streamlit.components.v1 import html
import streamlit as st
import helpers
import logging

# Set up the basic Streamlit page configuration
st.set_page_config(layout='wide',
                   page_title='STriP: Semantic Similarity of Scientific Papers!',
                   page_icon='🕸️'
                   )


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

logger = logging.getLogger('main')


def load_data():
    """Loads the data from the uploaded file.
    """

    st.header('📂 Load Data')
    about_load_data = read_md('markdown/load_data.md')
    st.markdown(about_load_data)

    uploaded_file = st.file_uploader("Choose a CSV file",
                                     help='Upload a CSV file with the following columns: Title, Abstract')

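    # Fall back to the default 'data.csv' when no file is uploaded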
    if uploaded_file is not None:
        df = helpers.load_data(uploaded_file)
    else:
        df = helpers.load_data('data.csv')
    data = df.copy()

    # Column selection. By default, any columns named 'title' or 'abstract' are preselected
    st.subheader('Select columns to analyze')
    selected_cols = st.multiselect(label='Select one or more columns. All selected columns are concatenated before analysis',
                                   options=data.columns,
                                   default=[col for col in data.columns if col.lower() in ['title', 'abstract']])

    if not selected_cols:
        st.error('No columns selected! Please select at least one text column to analyze')
        st.stop()

    data = data[selected_cols]

    # Minor cleanup
    data = data.dropna()

    # Analyze at most 200 rows to keep the computation tractable
    st.write(f'Number of rows: {len(data)}')
    n_rows = 200
    if len(data) > n_rows:
        # Fixed random_state keeps the sample stable across Streamlit reruns
        data = data.sample(n_rows, random_state=0)
        st.write(f'Only a random sample of {n_rows} rows will be analyzed')

    data = data.reset_index(drop=True)

    # Preview the loaded data
    st.write('First 5 rows of loaded data:')
    st.write(data.head())

    # Combine all selected columns into a single 'Text' column, joined with
    # '[SEP]' (the separator token used by BERT-style models downstream).
    # Iterating over selected_cols avoids re-including the new 'Text' column.
    if (data is not None) and selected_cols:
        data['Text'] = data[selected_cols[0]].astype(str)
        for column in selected_cols[1:]:
            data['Text'] = data['Text'] + '[SEP]' + data[column].astype(str)

    return data, selected_cols


def topic_modeling(data):
    """Runs the topic modeling step.
    """

    st.header('🔥 Topic Modeling')
    about_topic_modeling = read_md('markdown/topic_modeling.md')
    st.markdown(about_topic_modeling)

    cols = st.columns(3)
    with cols[0]:
        min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=2,
                                   max_value=min(round(len(data)*0.25), 100), step=1,
                                   value=min(max(round(len(data)/25), 2), 10),
                                   help='The minimum size of a topic. Increasing this value leads to fewer clusters/topics.')
    with cols[1]:
        n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
                                 max_value=3, step=1, value=(1, 2),
                                 help='N-gram range for the topic model')
    with cols[2]:
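        # Empty text elements add vertical space so the button aligns with the sliders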
        st.text('')
        st.text('')
        st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
                  kwargs={'min_topic_size': min(max(round(len(data)/25), 2), 10), 'n_gram_range': (1, 2)})

    with st.spinner('Topic Modeling'):
        # Enable this to route terminal logs to the frontend:
        # with helpers.st_stdout("success"), helpers.st_stderr("code"):
        topic_data, topic_model, topics = helpers.topic_modeling(
            data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)

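        # Map display names to the topic model's built-in Plotly visualizations
        # (these appear to be BERTopic's visualize_* methods)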
        mapping = {
            'Topic Keywords': topic_model.visualize_barchart,
            'Topic Similarities': topic_model.visualize_heatmap,
            'Topic Hierarchies': topic_model.visualize_hierarchy,
            'Intertopic Distance': topic_model.visualize_topics
        }

        cols = st.columns(3)
        with cols[0]:
            topic_model_vis_option = st.selectbox(
                'Select Topic Modeling Visualization', mapping.keys())
        try:
            fig = mapping[topic_model_vis_option](top_n_topics=10)
            fig.update_layout(title='')
            st.plotly_chart(fig, use_container_width=True)
        except Exception:
            # Some visualizations need more topics than were found to render
            st.warning(
                'No visualization available. Try a lower Minimum topic size!')

    return topic_data, topics


def strip_network(data, topic_data, topics):
    """Generated the STriP network.
    """

    st.header('🚀 STriP Network')
    about_stripnet = read_md('markdown/stripnet.md')
    st.markdown(about_stripnet)

    with st.spinner('Cosine Similarity Calculation'):
        cosine_sim_matrix = helpers.cosine_sim(data)

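    # helpers.calc_optimal_threshold suggests a starting threshold ('value')
    # and the lower bound for the slider below ('min_value')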
    value, min_value = helpers.calc_optimal_threshold(
        cosine_sim_matrix,
        # Cap at 25% of the maximum possible connections, or 5,000, whichever is lower
        max_connections=min(
            helpers.calc_max_connections(len(data), 0.25), 5_000
        )
    )

    cols = st.columns(3)
    with cols[0]:
        threshold = st.slider('Cosine Similarity Threshold', key='threshold', min_value=min_value,
                              max_value=1.0, step=0.01, value=value,
                              help='The minimum cosine similarity between papers to draw a connection. Increasing this value leads to fewer connections.')

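        # Streamlit reruns the script on slider changes, so the connection
        # count below updates live with the chosen threshold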
        neighbors, num_connections = helpers.calc_neighbors(
            cosine_sim_matrix, threshold)
        st.write(f'Number of connections: {num_connections}')

    with cols[1]:
        st.text('')
        st.text('')
        st.button('Reset Defaults', on_click=helpers.reset_default_threshold_slider, key='reset_threshold',
                  kwargs={'threshold': value})

    with st.spinner('Network Generation'):
        nx_net, pyvis_net = helpers.network_plot(
            topic_data, topics, neighbors)

        # Save the graph as an HTML file (on Streamlit Sharing)
        try:
            path = '/tmp'
            pyvis_net.save_graph(f'{path}/pyvis_graph.html')

        # Save the graph as an HTML file (locally)
        except OSError:
            path = '/html_files'
            pyvis_net.save_graph(f'{path}/pyvis_graph.html')

        # Read the saved HTML with a context manager so the file handle is closed
        with open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8') as html_file:
            graph_html = html_file.read()

        # Load the HTML into an HTML component for display on the Streamlit page
        html(graph_html, height=800)

    return nx_net


def network_centrality(nx_net, topic_data):
    """Finds most important papers using network centrality measures.
    """

    st.header('🏅 Most Important Papers')
    about_centrality = read_md('markdown/centrality.md')
    st.markdown(about_centrality)

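    # NetworkX centrality measures. Note: nx.eigenvector_centrality can raise
    # PowerIterationFailedConvergence if the power iteration does not converge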
    centrality_mapping = {
        'Betweenness Centrality': nx.betweenness_centrality,
        'Closeness Centrality': nx.closeness_centrality,
        'Degree Centrality': nx.degree_centrality,
        'Eigenvector Centrality': nx.eigenvector_centrality,
    }

    cols = st.columns(3)
    with cols[0]:
        centrality_option = st.selectbox(
            'Select Centrality Measure', centrality_mapping.keys())

    # Calculate centrality
    centrality = centrality_mapping[centrality_option](nx_net)

    cols = st.columns([1, 10, 1])
    with cols[1]:
        with st.spinner('Network Centrality Calculation'):
            fig = helpers.network_centrality(
                topic_data, centrality, centrality_option)
            st.plotly_chart(fig, use_container_width=True)


def read_md(file_path):
    """Reads a markdown file and returns the contents.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    return content


def main():
    st.title('STriPNet: Semantic Similarity of Scientific Papers!')
    about_stripnet = read_md('markdown/about_stripnet.md')
    st.markdown(about_stripnet)

    logger.info('========== Step 1: Loading data ==========')
    data, selected_cols = load_data()

    if (data is not None) and selected_cols:
        logger.info('========== Step 2: Topic modeling ==========')
        topic_data, topics = topic_modeling(data)

        logger.info('========== Step 3: STriP Network ==========')
        nx_net = strip_network(data, topic_data, topics)

        logger.info('========== Step 4: Network Centrality ==========')
        network_centrality(nx_net, topic_data)

    about_me = read_md('markdown/about_me.md')
    st.markdown(about_me)


if __name__ == '__main__':
    main()