from typing import Dict, Any
from scipy.special import softmax
import numpy as np
import weakref
import re
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

from utils import clean_str, clean_str_nopunct
import torch
from utils import MultiHeadModel, BertInputBuilder, get_num_words, MATH_PREFIXES, MATH_WORDS

import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from transformers.utils import logging

transformers.logging.set_verbosity_debug()

UPTAKE_MODEL = 'ddemszky/uptake-model'
REASONING_MODEL = 'ddemszky/student-reasoning'
QUESTION_MODEL = 'ddemszky/question-detection'
FOCUSING_QUESTION_MODEL = 'ddemszky/focusing-questions'


class Utterance:
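    """A single turn of talk: speaker, text, timing, and the per-utterance model
    outputs (uptake, reasoning, question, focusing question, math terms)."""
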
    def __init__(self, speaker, text, uid=None,
                 transcript=None, starttime=None, endtime=None, **kwargs):
        self.speaker = speaker
        self.text = text
        self.uid = uid
        self.starttime = starttime
        self.endtime = endtime
        self.transcript = weakref.ref(transcript) if transcript else None
        self.props = kwargs
        self.role = None
        self.word_count = self.get_num_words()
        self.timestamp = [starttime, endtime]
        if starttime is not None and endtime is not None:                   
            self.unit_measure = endtime - starttime
        else:
            self.unit_measure = None
        self.aggregate_unit_measure = endtime  # running end time; replaced by a cumulative word count when timestamps are missing
        self.num_math_terms = None
        self.math_terms = None

        # moments
        self.uptake = None
        self.reasoning = None
        self.question = None
        self.focusing_question = None 

    def get_clean_text(self, remove_punct=False):
        if remove_punct:
            return clean_str_nopunct(self.text)
        return clean_str(self.text)

    def get_num_words(self):
        return get_num_words(self.text)

    def to_dict(self):
        return {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'starttime': self.starttime,
            'endtime': self.endtime,
            'uptake': self.uptake,
            'reasoning': self.reasoning,
            'question':  self.question,
            'focusingQuestion': self.focusing_question,
            'numMathTerms': self.num_math_terms,
            'mathTerms': self.math_terms,
            **self.props
        }

    def to_talk_timeline_dict(self):
        return {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'role': self.role,
            'timestamp': self.timestamp,
            'moments': {
                'reasoning': bool(self.reasoning),
                'questioning': bool(self.question),
                'uptake': bool(self.uptake),
                'focusingQuestion': bool(self.focusing_question),
            },
            'unitMeasure': self.unit_measure,
            'aggregateUnitMeasure': self.aggregate_unit_measure,
            'wordCount': self.word_count,
            'numMathTerms': self.num_math_terms,
            'mathTerms': self.math_terms,
        }

    def __repr__(self):
        return f"Utterance(speaker='{self.speaker}', " \
               f"text='{self.text}', uid={self.uid}, " \
               f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"


class Transcript:
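    """An ordered sequence of Utterance objects plus transcript-level parameters, with
    helpers for talk distribution, word clouds, and the talk timeline."""
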
    def __init__(self, **kwargs):
        self.utterances = []
        self.params = kwargs

    def add_utterance(self, utterance):
        utterance.transcript = weakref.ref(self)
        self.utterances.append(utterance)

    def get_idx(self, idx):
        if idx >= len(self.utterances):
            return None
        return self.utterances[idx]

    def get_uid(self, uid):
        for utt in self.utterances:
            if utt.uid == uid:
                return utt
        return None

    def length(self):
        return len(self.utterances)

    def update_utterance_roles(self, uptake_speaker):
        for utt in self.utterances:
            if utt.speaker == uptake_speaker:
                utt.role = 'teacher'
            else:
                utt.role = 'student'

    def get_talk_distribution_and_length(self, uptake_speaker):
        if uptake_speaker is None:
            return None
        teacher_words = 0
        teacher_utt_count = 0
        student_words = 0
        student_utt_count = 0
        for utt in self.utterances:
            if utt.speaker == uptake_speaker:
                utt.role = 'teacher'
                teacher_words += utt.get_num_words()
                teacher_utt_count += 1
            else:
                utt.role = 'student'
                student_words += utt.get_num_words()
                student_utt_count += 1
        if teacher_words + student_words > 0:         
            teacher_percentage = round(
                (teacher_words / (teacher_words + student_words)) * 100)
            student_percentage = 100 - teacher_percentage
        else:
            teacher_percentage = student_percentage = 0
        avg_teacher_length = teacher_words / teacher_utt_count if teacher_utt_count > 0 else 0
        avg_student_length = student_words / student_utt_count if student_utt_count > 0 else 0
        return {'teacher': teacher_percentage, 'student': student_percentage}, {'teacher': avg_teacher_length, 'student': avg_student_length}

    def get_word_clouds(self):
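        """Build word-frequency clouds, skipping stopwords and 'inaudible'/'crosstalk'
        everywhere and math terms in the general counts; return the top 50 entries for
        the combined, teacher-uptake, teacher, and student clouds."""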
        teacher_dict = {}
        student_dict = {}
        uptake_teacher_dict = {}
        stop_words = stopwords.words('english')
        for utt in self.utterances:
            words = utt.get_clean_text(remove_punct=True).split()
            for word in words:
                if word in stop_words or word in ['inaudible', 'crosstalk']:
                    continue
                # handle uptake case
                if utt.role == 'teacher':
                    if utt.uptake == 1:
                        if word not in uptake_teacher_dict:
                            uptake_teacher_dict[word] = 0
                        uptake_teacher_dict[word] += 1
                # skip math terms so they don't also get counted as general words
                if utt.math_terms and any(math_word in word for math_word in utt.math_terms):
                    continue
                if utt.role == 'teacher':
                    if word not in teacher_dict:
                        teacher_dict[word] = 0
                    teacher_dict[word] += 1

                else:
                    if word not in student_dict:
                        student_dict[word] = 0
                    student_dict[word] += 1
        dict_list = []
        uptake_dict_list = []
        teacher_dict_list = []
        student_dict_list = []
        for word, count in uptake_teacher_dict.items():
            uptake_dict_list.append({'text': word, 'value': count, 'category': 'teacher'})
        for word, count in teacher_dict.items():
            teacher_dict_list.append(
                {'text': word, 'value': count, 'category': 'general'})
            dict_list.append({'text': word, 'value': count, 'category': 'general'})
        for word, count in student_dict.items():
            student_dict_list.append(
                {'text': word, 'value': count, 'category': 'general'})
            dict_list.append({'text': word, 'value': count, 'category': 'general'})
        sorted_dict_list = sorted(dict_list, key=lambda x: x['value'], reverse=True)
        sorted_uptake_dict_list = sorted(uptake_dict_list, key=lambda x: x['value'], reverse=True)
        sorted_teacher_dict_list = sorted(teacher_dict_list, key=lambda x: x['value'], reverse=True)
        sorted_student_dict_list = sorted(student_dict_list, key=lambda x: x['value'], reverse=True)
        return sorted_dict_list[:50], sorted_uptake_dict_list[:50], sorted_teacher_dict_list[:50], sorted_student_dict_list[:50]

    def get_talk_timeline(self):
        return [utterance.to_talk_timeline_dict() for utterance in self.utterances]
    
    def calculate_aggregate_word_count(self):
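        """If any utterance lacks timing information, fall back to word counts: each
        utterance's unit_measure becomes its word count and aggregate_unit_measure
        becomes the running word-count total."""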
        unit_measures = [utt.unit_measure for utt in self.utterances]
        if None in unit_measures:
            aggregate_word_count = 0
            for utt in self.utterances: 
                aggregate_word_count += utt.get_num_words()
                utt.unit_measure = utt.get_num_words()
                utt.aggregate_unit_measure = aggregate_word_count


    def to_dict(self):
        return {
            'utterances': [utterance.to_dict() for utterance in self.utterances],
            **self.params
        }

    def __repr__(self):
        return f"Transcript(utterances={self.utterances}, custom_params={self.params})"


class QuestionModel:
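    """Marks utterances that are questions: anything containing '?' is flagged directly;
    otherwise the `is_question` head of a MultiHeadModel decides."""
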
    def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(
            path, head2size={"is_question": 2})
        self.model.to(self.device)

    def run_inference(self, transcript):
        self.model.eval()
        with torch.no_grad():
            for i, utt in enumerate(transcript.utterances):
                if "?" in utt.text:
                    utt.question = 1
                else:
                    text = utt.get_clean_text(remove_punct=True)
                    instance = self.input_builder.build_inputs([], text,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)
                    # print(output)
                    utt.question = np.argmax(
                        output["is_question_logits"][0].tolist())

    def get_prediction(self, instance):
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0)  # Batch size = 1
            instance[key].to(self.device)

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"],
                            return_pooler_output=False)
        return output


class ReasoningModel:
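    """Flags student reasoning with a BERT sequence classifier, applied to utterances of
    at least `min_num_words` words that are not spoken by the uptake (teacher) speaker."""
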
    def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)

    def run_inference(self, transcript, min_num_words=8, uptake_speaker=None):
        self.model.eval()
        with torch.no_grad():
            for i, utt in enumerate(transcript.utterances):
                if utt.get_num_words() >= min_num_words and utt.speaker != uptake_speaker:
                    instance = self.input_builder.build_inputs([], utt.text,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)
                    utt.reasoning = np.argmax(output["logits"][0].tolist())

    def get_prediction(self, instance):
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0)  # Batch size = 1
            instance[key].to(self.device)

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"])
        return output


class UptakeModel:
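    """Scores conversational uptake: the `nsp` head of a MultiHeadModel is run on each
    (previous utterance, current utterance) pair for the uptake speaker, and the positive
    probability is thresholded at 0.8."""
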
    def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)

    def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
        self.model.eval()
        prev_num_words = 0
        prev_utt = None
        with torch.no_grad():
            for i, utt in enumerate(transcript.utterances):
                if (uptake_speaker is None or utt.speaker == uptake_speaker) \
                        and prev_utt is not None and prev_num_words >= min_prev_words:
                    textA = prev_utt.get_clean_text(remove_punct=False)
                    textB = utt.get_clean_text(remove_punct=False)
                    instance = self.input_builder.build_inputs([textA], textB,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)

                    utt.uptake = int(
                        softmax(output["nsp_logits"][0].tolist())[1] > .8)
                prev_num_words = utt.get_num_words()
                prev_utt = utt

    def get_prediction(self, instance):
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0)  # Batch size = 1
            instance[key].to(self.device)

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"],
                            return_pooler_output=False)
        return output

class FocusingQuestionModel:
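    """Flags focusing questions in the uptake speaker's utterances with a BERT sequence
    classifier."""
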
    def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.max_length = max_length

    def run_inference(self, transcript, min_focusing_words=0, uptake_speaker=None):
        self.model.eval()
        with torch.no_grad():
            for i, utt in enumerate(transcript.utterances):
                if uptake_speaker is None or utt.speaker != uptake_speaker:
                    utt.focusing_question = None
                    continue
                if utt.get_num_words() < min_focusing_words:
                    utt.focusing_question = None
                    continue
                instance = self.input_builder.build_inputs([], utt.text, max_length=self.max_length, input_str=True)
                output = self.get_prediction(instance)
                utt.focusing_question = np.argmax(output["logits"][0].tolist())

    def get_prediction(self, instance):
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0)  # Batch size = 1
            instance[key].to(self.device)

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"])
        return output   

def load_math_terms():
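    """Build word-boundary regexes for MATH_WORDS (terms in MATH_PREFIXES also allow
    simple s/es/d/ed suffixes) and a mapping from each regex back to its base term."""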
    math_regexes = []
    math_terms_dict = {}
    for term in MATH_WORDS:
        if term in MATH_PREFIXES:
            math_terms_dict[rf"\b{term}(s|es|d|ed)?\b"] = term
            math_regexes.append(rf"\b{term}(s|es|d|ed)?\b")
        else:
            math_regexes.append(rf"\b{term}\b")
            math_terms_dict[rf"\b{term}\b"] = term
    return math_regexes, math_terms_dict

def run_math_density(transcript):
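    """Count math-term matches per utterance, letting longer terms take precedence at
    overlapping positions, and build teacher/student math word clouds; returns the top 50
    entries of the combined, teacher, and student clouds."""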
    math_regexes, math_terms_dict = load_math_terms()
    sorted_regexes = sorted(math_regexes, key=len, reverse=True)
    teacher_math_word_cloud = {}
    student_math_word_cloud = {}
    for i, utt in enumerate(transcript.utterances):
        text = utt.get_clean_text(remove_punct=True)
        num_matches = 0
        matched_positions = set()
        match_list = []
        for regex in sorted_regexes:
            matches = list(re.finditer(regex, text, re.IGNORECASE))
            # Filter out matches that share positions with longer terms
            matches = [match for match in matches if not any(match.start() in range(existing[0], existing[1]) for existing in matched_positions)]
            # matched_text = [match.group(0) for match in matches]
            if len(matches) > 0:
                if utt.role == "teacher":
                    if math_terms_dict[regex] not in teacher_math_word_cloud:
                        teacher_math_word_cloud[math_terms_dict[regex]] = 0
                    teacher_math_word_cloud[math_terms_dict[regex]] += len(matches)
                else:
                    if math_terms_dict[regex] not in student_math_word_cloud:
                        student_math_word_cloud[math_terms_dict[regex]] = 0
                    student_math_word_cloud[math_terms_dict[regex]] += len(matches)
                match_list.append(math_terms_dict[regex])                    
            # Update matched positions
            matched_positions.update((match.start(), match.end()) for match in matches)
            num_matches += len(matches)
            # print("match group list: ", [match.group(0) for match in matches])
        utt.num_math_terms = num_matches
        utt.math_terms = match_list
        # utt.math_match_positions = list(matched_positions)
        # utt.math_terms_raw = [text[start:end] for start, end in matched_positions]
    teacher_dict_list = []
    student_dict_list = []
    dict_list = []
    for word in teacher_math_word_cloud.keys():
        teacher_dict_list.append(
            {'text': word, 'value': teacher_math_word_cloud[word], 'category': "math"})
        dict_list.append({'text': word, 'value': teacher_math_word_cloud[word], 'category': "math"})
    for word in student_math_word_cloud.keys():
        student_dict_list.append(
            {'text': word, 'value': student_math_word_cloud[word], 'category': "math"}) 
        dict_list.append({'text': word, 'value': student_math_word_cloud[word], 'category': "math"})
    sorted_dict_list = sorted(dict_list, key=lambda x: x['value'], reverse=True)
    sorted_teacher_dict_list = sorted(teacher_dict_list, key=lambda x: x['value'], reverse=True)
    sorted_student_dict_list = sorted(student_dict_list, key=lambda x: x['value'], reverse=True)
    # return sorted_dict_list[:50]
    return sorted_dict_list[:50], sorted_teacher_dict_list[:50], sorted_student_dict_list[:50]

class EndpointHandler:
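    """Hugging Face inference endpoint handler: loads the shared tokenizer and input
    builder once, then instantiates and runs the uptake, reasoning, question, and
    focusing-question models over each submitted transcript."""
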
    def __init__(self, path="."):
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`list`):
                List of dicts, where each dict represents an utterance; each utterance must
                have a `speaker`, `text`, and `uid`, and can include custom properties.
            parameters (:obj:`dict`):
                Inference parameters (e.g. `uptake_speaker`, `uptake_min_num_words`, `filename`).
        Return:
            A :obj:`dict` that will be serialized and returned.
        """
        # get inputs
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None)

        transcript = Transcript(filename=params.pop("filename", None))
        for utt in utterances:
            transcript.add_utterance(Utterance(**utt))

        print("Running inference on %d examples..." % transcript.length())
        logging.set_verbosity_info()
        # Uptake
        uptake_model = UptakeModel(
            self.device, self.tokenizer, self.input_builder)
        uptake_speaker = params.pop("uptake_speaker", None)
        uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
                                   uptake_speaker=uptake_speaker)
        del uptake_model
        
        # Reasoning
        reasoning_model = ReasoningModel(
            self.device, self.tokenizer, self.input_builder)
        reasoning_model.run_inference(transcript, uptake_speaker=uptake_speaker)
        del reasoning_model
        
        # Question
        question_model = QuestionModel(
            self.device, self.tokenizer, self.input_builder)
        question_model.run_inference(transcript)
        del question_model
        
        # Focusing Question
        focusing_question_model = FocusingQuestionModel(
            self.device, self.tokenizer, self.input_builder)
        focusing_question_model.run_inference(transcript, uptake_speaker=uptake_speaker)
        del focusing_question_model
        
        transcript.update_utterance_roles(uptake_speaker)
        sorted_math_cloud, teacher_math_cloud, student_math_cloud = run_math_density(transcript)
        transcript.calculate_aggregate_word_count()
        return_dict = {'talkDistribution': None, 'talkLength': None, 'talkMoments': None, 'studentTopWords': None, 'teacherTopWords': None}
        talk_dist, talk_len = transcript.get_talk_distribution_and_length(uptake_speaker)
        return_dict['talkDistribution'] = talk_dist
        return_dict['talkLength'] = talk_len
        talk_moments = transcript.get_talk_timeline()
        return_dict['talkMoments'] = talk_moments
        word_cloud, uptake_word_cloud, teacher_general_cloud, student_general_cloud = transcript.get_word_clouds()
        teacher_cloud = teacher_math_cloud + teacher_general_cloud
        student_cloud = student_math_cloud + student_general_cloud
        return_dict['teacherTopWords'] = teacher_cloud
        return_dict['studentTopWords'] = student_cloud

        return return_dict
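

# Minimal local usage sketch, not part of the original handler: the payload shape follows
# the __call__ docstring above, and the utterances and parameter values below are purely
# illustrative (uptake_min_num_words in particular is an arbitrary choice here).
if __name__ == "__main__":
    handler = EndpointHandler()
    example_payload = {
        "inputs": [
            {"speaker": "teacher", "text": "What do you notice about the slope?", "uid": "u1"},
            {"speaker": "student", "text": "It gets steeper because the ratio keeps changing.", "uid": "u2"},
            {"speaker": "teacher", "text": "Say more about how the ratio changes.", "uid": "u3"},
        ],
        "parameters": {
            "uptake_speaker": "teacher",
            "uptake_min_num_words": 5,
            "filename": "example_transcript.json",
        },
    }
    print(handler(example_payload))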