import gradio as gr

from functools import partial

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sentence_transformers import SentenceTransformer
import torch
import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import penman
from collections import Counter, defaultdict
import networkx as nx
from networkx.drawing.nx_agraph import pygraphviz_layout

class FramingLabels:
    """Zero-shot, multi-label frame classifier built on a Hugging Face pipeline.

    Every candidate label is a string of the form ``"Name: description"``.
    The full string is given to the classifier; the first space-delimited
    token of the part before the colon is used as the short display name.
    """

    def __init__(self, base_model, candidate_labels, batch_size=16):
        """Create the zero-shot pipeline and bind the label set to it.

        Args:
            base_model: Model identifier handed to ``pipeline(model=...)``.
            candidate_labels: Label strings every input is scored against.
            batch_size: Batch size forwarded to the pipeline on each call.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.base_pipeline = pipeline("zero-shot-classification", model=base_model, device=device)
        self.candidate_labels = candidate_labels
        self.classifier = partial(
            self.base_pipeline,
            candidate_labels=candidate_labels,
            multi_label=True,
            batch_size=batch_size,
        )

    def order_scores(self, dic):
        """Return the scores of one classifier result in ``candidate_labels`` order.

        Args:
            dic: Mapping with parallel ``"labels"`` and ``"scores"`` lists.

        Returns:
            A list of scores, one per entry of ``self.candidate_labels``.

        Raises:
            ValueError: If a candidate label is missing from ``dic["labels"]``.
        """
        positions = [dic["labels"].index(label) for label in self.candidate_labels]
        return np.array(dic["scores"])[positions].tolist()

    def get_ordered_scores(self, sequence_to_classify):
        """Classify the input and return scores aligned with the candidate labels.

        Args:
            sequence_to_classify: A single text or a list of texts.

        Returns:
            When the classifier returns a single result: a flat list with one
            score per label. When it returns a list of results: one list per
            label, each holding that label's score for every text.
        """
        if isinstance(sequence_to_classify, list):
            # tqdm wraps the iteration over whatever the classifier returned.
            res = list(tqdm.tqdm(self.classifier(sequence_to_classify)))
        else:
            res = self.classifier(sequence_to_classify)

        if isinstance(res, list):
            per_text = [self.order_scores(out) for out in res]
            # Transpose from "one row per text" to "one row per label".
            return [list(column) for column in zip(*per_text)]
        return self.order_scores(res)

    def get_label_names(self):
        """Return the short display name of every candidate label."""
        return [label.split(":")[0].split(" ")[0] for label in self.candidate_labels]

    def __call__(self, sequence_to_classify):
        """Return a dict mapping each short label name to its score(s)."""
        scores = self.get_ordered_scores(sequence_to_classify)
        return dict(zip(self.get_label_names(), scores))

    def visualize(self, name_to_score_dict, threshold=0.5, **kwargs):
        """Draw a horizontal bar chart of label scores.

        Args:
            name_to_score_dict: Mapping of label name to a single score.
            threshold: Bars with a score strictly above this value get the
                first palette color, all others the second.
            **kwargs: Forwarded to ``Axes.barh`` (for example ``xerr``).
                Per-bar values must be in the same order as the dict.

        Returns:
            The ``(figure, axes)`` pair that was drawn on.
        """
        fig, ax = plt.subplots()
        palette = sns.color_palette()

        label_names = list(name_to_score_dict.keys())
        scores = list(name_to_score_dict.values())
        colors = [palette[0] if score > threshold else palette[1] for score in scores]

        # Plot in the given order and flip the axis so the first label ends up
        # on top. Reversing the inputs instead would leave per-bar kwargs such
        # as ``xerr`` in the original order and attach them to the wrong bars.
        ax.barh(label_names, scores, color=colors, **kwargs)
        ax.invert_yaxis()
        ax.set_xlim(left=0)
        fig.tight_layout()
        return fig, ax

class FramingDimensions:
    """Scores texts along axes spanned by pairs of embedded dimension descriptions.

    Each dimension is a string of the form ``"Name: description"``. Every
    ``(pole1, pole2)`` pair in ``pole_names`` defines an axis vector equal to
    ``embedding(pole1) - embedding(pole2)``.
    """

    def __init__(self, base_model, dimensions, pole_names):
        """Embed the dimension descriptions and derive one vector per axis.

        Args:
            base_model: Model identifier handed to ``SentenceTransformer``.
            dimensions: Dimension description strings to embed.
            pole_names: Pairs of dimension names; each must match a name
                produced by ``get_dimension_names``.

        Raises:
            ValueError: If a pole name is not among the dimension names.
        """
        self.encoder = SentenceTransformer(base_model)
        self.dimensions = dimensions
        self.dim_embs = self.encoder.encode(dimensions)
        self.pole_names = pole_names
        self.axis_names = [first + "/" + second for first, second in pole_names]

        # Parse the names once instead of twice per axis.
        names = self.get_dimension_names()
        self.axis_embs = np.stack([
            self.dim_embs[names.index(pole1)] - self.dim_embs[names.index(pole2)]
            for pole1, pole2 in pole_names
        ])

    def get_dimension_names(self):
        """Return the short name (first token before the colon) of every dimension."""
        return [dim.split(":")[0].split(" ")[0] for dim in self.dimensions]

    def __call__(self, sequence_to_align):
        """Project the encoded input onto every axis.

        Args:
            sequence_to_align: Text(s) handed to the encoder.

        Returns:
            Dict keyed by the ``(pole1, pole2)`` tuples of ``pole_names``;
            each value is the dot product of the embedding(s) with that axis.
        """
        embs = self.encoder.encode(sequence_to_align)
        scores = embs @ self.axis_embs.T
        return dict(zip(self.pole_names, scores.T))

    def visualize(self, align_scores_df, **kwargs):
        """Plot the mean alignment per axis as a bubble chart.

        The x position is the column mean, the marker size grows with the
        column variance, and the color is ``"b"`` for a positive mean and
        ``"r"`` otherwise. The left tick labels show the second pole of each
        axis and the right tick labels the first pole.

        Args:
            align_scores_df: DataFrame whose column labels are
                ``(pole1, pole2)`` tuples and whose rows are texts.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            The matplotlib figure.
        """
        name_left = align_scores_df.columns.map(lambda pair: pair[1])
        name_right = align_scores_df.columns.map(lambda pair: pair[0])
        bias = align_scores_df.mean()
        color = ["b" if value > 0 else "r" for value in bias]
        # A single row has no variance (NaN); treat it as zero spread. The
        # small offset keeps the marker visible in that case.
        inten = (align_scores_df.var().fillna(0) + 0.001) * 50_000
        bounds = bias.abs().max() * 1.1

        # Draw on the explicit figure/axes rather than pyplot's implicit
        # "current" ones.
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(x=bias, y=name_left, s=inten, c=color)
        ax.axvline(0)
        ax.set_xlim(-bounds, bounds)
        ax.invert_yaxis()

        # Mirror the y-axis on the right-hand side to label the opposite pole.
        axi = ax.twinx()
        axi.set_ylim(ax.get_ylim())
        axi.set_yticks(ax.get_yticks(), labels=name_right)
        fig.tight_layout()
        return fig

class FramingStructure:
    """Generates PENMAN-encoded graphs from text and draws them as one merged graph."""

    def __init__(self, base_model, roles=None):
        """Create the text-to-text generation pipeline.

        Args:
            base_model: Model identifier handed to the pipeline.
            roles: Accepted for interface compatibility; currently unused.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.translator = pipeline("text2text-generation", base_model, device=device, max_length=300)

    @staticmethod
    def _try_decode(item):
        """Decode one generation result with penman; return ``None`` on failure."""
        try:
            return penman.decode(item["generated_text"])
        except Exception:
            # Best effort: a generation that cannot be decoded is skipped.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return None

    @staticmethod
    def _trim_distinction_end(name):
        """Keep only the part of ``name`` before its first underscore."""
        return name.split("_")[0]

    def __call__(self, sequence_to_translate):
        """Generate and decode graphs for the input.

        Args:
            sequence_to_translate: Text(s) handed to the generation pipeline.

        Returns:
            List of decoded graphs; outputs that fail to decode are dropped.
        """
        res = self.translator(sequence_to_translate)
        decoded = (self._try_decode(item) for item in res)
        return [graph for graph in decoded if graph is not None]

    def visualize(self, decoded_graphs, min_node_threshold=1, **kwargs):
        """Merge the triples of all graphs into one weighted digraph and draw it.

        Identical ``(source, role, target)`` triples are counted. A node's
        weight is the summed count of all triples it takes part in, and an
        edge's weight is the count of its triple.

        Args:
            decoded_graphs: Iterable of objects exposing ``.triples``.
            min_node_threshold: Nodes with a weight below this are hidden.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            The matplotlib figure.
        """
        trim = self._trim_distinction_end

        cnt = Counter()
        for graph in decoded_graphs:
            # Skip triples without a target, drop the ":" from the role and
            # trim both endpoints before counting.
            cnt.update(
                (trim(source), role.replace(":", ""), trim(target))
                for source, role, target in graph.triples
                if target is not None
            )

        G = nx.DiGraph()

        # Edge color per role; any other role is drawn in "k".
        color_map = defaultdict(lambda: "k", {
            "ARG0": "y",
            "ARG1": "r",
            "ARG2": "g",
            "ARG3": "b",
        })

        for (source, role, target), num in cnt.items():
            for node in (source, target):
                if not G.has_node(node):
                    G.add_node(node, weight=0)
                G.nodes[node]["weight"] += num
            G.add_edge(source, target, role=role, weight=num, color=color_map[role])

        G_sub = nx.subgraph_view(G, filter_node=lambda n: G.nodes[n]["weight"] >= min_node_threshold)

        node_sizes = [weight * 100 for weight in nx.get_node_attributes(G_sub, 'weight').values()]
        # Materialize the view so the colors are a plain sequence.
        edge_colors = list(nx.get_edge_attributes(G_sub, 'color').values())

        fig = plt.figure()

        pos = pygraphviz_layout(G_sub, prog="dot")
        nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
        nx.draw_networkx_labels(G_sub, pos)
        nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
        plt.tight_layout()
        return fig

# Specify the models
# base_model_1: zero-shot classification model used by FramingLabels.
base_model_1 = "facebook/bart-large-mnli"
# base_model_2: sentence encoder used by FramingDimensions.
base_model_2 = 'all-mpnet-base-v2'
# base_model_3: text-to-text model used by FramingStructure; its output is
# decoded with penman.
base_model_3 = "Iseratho/model_parse_xfm_bart_base-v0_1_0"
# Frame labels in "Name: description" form. The whole string is given to the
# classifier; the first word before the colon becomes the display name.
# Reference: https://homes.cs.washington.edu/~nasmith/papers/card+boydstun+gross+resnik+smith.acl15.pdf
candidate_labels = [
    "Economic: costs, benefits, or other financial implications",
    "Capacity and resources: availability of physical, human or financial resources, and capacity of current systems",
    "Morality: religious or ethical implications",
    "Fairness and equality: balance or distribution of rights, responsibilities, and resources",
    "Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government",
    "Policy prescription and evaluation: discussion of specific policies aimed at addressing problems",
    "Crime and punishment: effectiveness and implications of laws and their enforcement",
    "Security and defense: threats to welfare of the individual, community, or nation",
    "Health and safety: health care, sanitation, public safety",
    "Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being",
    "Cultural identity: traditions, customs, or values of a social group in relation to a policy issue",
    "Public opinion: attitudes and opinions of the general public, including polling and demographics",
    "Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters",
    "External regulation and reputation: international reputation or foreign policy of the U.S.",
    "Other: any coherent group of frames not covered by the above categories",
]

# Dimension descriptions in "Name: description" form. The whole string is
# embedded; the name before the colon is what pole_names refers to.
# Reference: https://osf.io/xakyw
dimensions = [
    "Care: ...acted with kindness, compassion, or empathy, or nurtured another person.",
    "Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.",
    "Fairness: ...acted in a fair manner, promoting equality, justice, or rights.",
    "Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.",
    "Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.",
    "Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.",
    "Authority: ...obeyed, or acted with respect for authority or tradition.",
    "Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.",
    "Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.",
    "Degradation: ...was depraved, degrading, impure, or unnatural.",
]
# Each (pole1, pole2) pair defines one axis: embedding(pole1) - embedding(pole2).
# Both names must appear among the dimension names above.
pole_names = [
    ("Care", "Harm"),
    ("Fairness", "Cheating"),
    ("Loyalty", "Betrayal"),
    ("Authority", "Subversion"),
    ("Sanctity", "Degradation"),
]

# The three models are constructed once, when this module is executed, and
# shared by every handler below.
framing_label_model = FramingLabels(base_model_1, candidate_labels)
framing_dimen_model = FramingDimensions(base_model_2, dimensions, pole_names)
framing_struc_model = FramingStructure(base_model_3)

def framing_multi(texts, min_node_threshold=1):
    """Analyze several texts together and return the three summary figures.

    Args:
        texts: List of texts to analyze.
        min_node_threshold: Minimum node weight shown in the structure graph.

    Returns:
        Tuple ``(label_figure, dimension_figure, structure_figure)``.
    """
    # Labels: one row per text; plot the mean with the standard error as xerr.
    label_scores = pd.DataFrame(framing_label_model(texts))
    label_fig, _ = framing_label_model.visualize(
        label_scores.mean().to_dict(),
        xerr=label_scores.sem(),
    )

    # Dimensions: one row per text, one column per axis.
    dimension_scores = pd.DataFrame(framing_dimen_model(texts))
    dimension_fig = framing_dimen_model.visualize(dimension_scores)

    # Structure: all decoded graphs merged into one drawing.
    graphs = framing_struc_model(texts)
    structure_fig = framing_struc_model.visualize(
        graphs,
        min_node_threshold=min_node_threshold,
    )

    return label_fig, dimension_fig, structure_fig

def framing_single(text, min_node_threshold=1):
    """Analyze one text and return the three figures.

    Args:
        text: The text to analyze.
        min_node_threshold: Minimum node weight shown in the structure graph.

    Returns:
        Tuple ``(label_figure, dimension_figure, structure_figure)``.
    """
    label_scores = framing_label_model(text)
    label_fig, _ = framing_label_model.visualize(label_scores)

    # Wrap every per-axis score in a list so the frame has exactly one row.
    one_row = {}
    for axis, score in framing_dimen_model(text).items():
        one_row[axis] = [score]
    dimension_fig = framing_dimen_model.visualize(pd.DataFrame(one_row))

    graphs = framing_struc_model(text)
    structure_fig = framing_struc_model.visualize(
        graphs,
        min_node_threshold=min_node_threshold,
    )

    return label_fig, dimension_fig, structure_fig

async def framing_textbox(text, split, min_node_threshold):
    """Handle the textbox tab: analyze the text as a whole or line by line.

    Args:
        text: Raw textbox content.
        split: When true, every non-blank line is treated as its own text.
        min_node_threshold: Minimum node weight shown in the structure graph.

    Returns:
        Tuple of three figures (labels, dimensions, structure).
    """
    if split:
        # Drop blank lines (e.g. from a trailing newline) so empty strings are
        # never sent to the models as separate texts.
        stripped_lines = (line.strip() for line in text.split("\n"))
        texts = [line for line in stripped_lines if line]
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
    return framing_single(text, min_node_threshold)

async def framing_file(file_obj, split, min_node_threshold):
    """Handle the file tab: analyze an uploaded text file.

    Args:
        file_obj: Uploaded file; its ``.name`` attribute is the path to read.
        split: When true, every non-blank line is treated as its own text.
        min_node_threshold: Minimum node weight shown in the structure graph.

    Returns:
        Tuple of three figures (labels, dimensions, structure).
    """
    # Read everything first so the file is closed before the models run.
    # The encoding is explicit so decoding does not depend on the platform
    # default.
    with open(file_obj.name, "r", encoding="utf-8") as f:
        content = f.read()

    text = content
    if split:
        # Strip line endings and skip blank lines so that newline-only entries
        # are not analyzed as texts.
        stripped_lines = (line.strip() for line in content.split("\n"))
        texts = [line for line in stripped_lines if line]
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
        if texts:
            text = texts[0]
        # A file without any non-blank line falls through with its raw
        # content instead of raising IndexError.
    return framing_single(text, min_node_threshold)

# Example rows for the textbox interface, in input order:
# [text, split-on-newlines flag, min node threshold]. Every example uses a
# threshold of 1.
example_list = [
    [sample_text, split_flag, 1]
    for sample_text, split_flag in (
        ("In 2010, CFCs were banned internationally due to their harmful effect on the ozone layer.", False),
        ("In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.", False),
        ("We must fight for our freedom.", False),
        ("The government prevents our freedom.", False),
        ("They prevent the spread.", False),
        ("We fight the virus.", False),
        ("I believe that we should act now.\nThere is no time to waste.", True),
    )
]

# Short description shown with both interfaces.
description = (
    "A simple tool that helps you find (discover and detect) frames in text.\n"
    "\n"
    "Note that due to the computation time required for underlying Transformer models, only short texts are recommended."
)
# Markdown/HTML article shown with both interfaces. The literal opens with
# exactly three quotes; a fourth one would become a stray leading '"' in the
# rendered text.
article = """Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.

<details>
<summary>Explanation of labels:</summary>
<ul>
<li>Economic: costs, benefits, or other financial implications</li>
<li>Capacity and resources: availability of physical, human or financial resources, and capacity of current systems</li>
<li>Morality: religious or ethical implications</li>
<li>Fairness and equality: balance or distribution of rights, responsibilities, and resources</li>
<li>Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government</li>
<li>Policy prescription and evaluation: discussion of specific policies aimed at addressing problems</li>
<li>Crime and punishment: effectiveness and implications of laws and their enforcement</li>
<li>Security and defense: threats to welfare of the individual, community, or nation</li>
<li>Health and safety: health care, sanitation, public safety</li>
<li>Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being</li>
<li>Cultural identity: traditions, customs, or values of a social group in relation to a policy issue</li>
<li>Public opinion: attitudes and opinions of the general public, including polling and demographics</li>
<li>Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters</li>
<li>External regulation and reputation: international reputation or foreign policy of the U.S.</li>
<li>Other: any coherent group of frames not covered by the above categories</li>
</ul>
</details>

<details>
<summary>Explanation of dimensions: </summary>
<ul>
<li>Care: ...acted with kindness, compassion, or empathy, or nurtured another person.</li>
<li>Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.</li>
<li>Fairness: ...acted in a fair manner, promoting equality, justice, or rights.</li>
<li>Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.</li>
<li>Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.</li>
<li>Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.</li>
<li>Authority: ...obeyed, or acted with respect for authority or tradition.</li>
<li>Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.</li>
<li>Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.</li>
<li>Degradation: ...was depraved, degrading, impure, or unnatural.</li>
</ul>
</details>

Document of structure (AMR) explanation: [AMR Specification](https://github.com/amrisi/amr-guidelines/blob/master/amr.md)
"""

def _plot_outputs():
    """Return a fresh list of the three output plots used by each interface."""
    return [
        gr.Plot(label="Label"),
        gr.Plot(label="Dimensions"),
        gr.Plot(label="Structure"),
    ]


# Tab 1: free-text input, with clickable examples.
textbox_inferface = gr.Interface(
    fn=framing_textbox,
    inputs=[
        gr.Textbox(label="Text to analyze."),
        gr.Checkbox(True, label="Split on newlines? (To enter newlines type shift+Enter)"),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    outputs=_plot_outputs(),
    examples=example_list,
    description=description,
    article=article,
)

# Tab 2: uploaded file input.
file_interface = gr.Interface(
    fn=framing_file,
    inputs=[
        gr.File(label="File of texts to analyze."),
        gr.Checkbox(True, label="Split on newlines?"),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    outputs=_plot_outputs(),
    description=description,
    article=article,
)

# Combine both interfaces into one tabbed app and start serving it.
demo = gr.TabbedInterface(
    [textbox_inferface, file_interface],
    tab_names=["Single Mode", "File Mode"],
    title="FrameFinder",
)

demo.launch()