import functools
import os
import re
import string

import emoji
import gradio as gr
import nltk
import pandas as pd
import pandas as pd
import textstat
import torch
import torch.nn as nn
from dotenv import load_dotenv
from nltk import pos_tag, word_tokenize
from nltk import pos_tag
from nltk.corpus import stopwords
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from textblob import TextBlob
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification,AutoModel
from transformers import pipeline
# Populate os.environ from a .env file so settings such as PERSONALITY_TOKEN
# (read via os.environ in personality_detection) can be supplied there.
load_dotenv()

    





# Author-style feature extractors: per-tweet statistics that feed the classifier
def average_word_length(tweet):
    """Return the mean length of the whitespace-separated words in *tweet*.

    Returns 0.0 when *tweet* has no words (empty or whitespace-only) instead
    of raising ZeroDivisionError.
    """
    words = tweet.split()
    if not words:
        return 0.0
    return sum(len(word) for word in words) / len(words)


def lexical_diversity(tweet):
    """Return the ratio of distinct to total whitespace-separated words.

    Comparison is case-sensitive. Returns 0.0 when *tweet* has no words
    (empty or whitespace-only) instead of raising ZeroDivisionError.
    """
    words = tweet.split()
    if not words:
        return 0.0
    return len(set(words)) / len(words)

def count_capital_letters(tweet):
    """Return how many characters of *tweet* are upper-case (per ``str.isupper``)."""
    uppercase_flags = map(str.isupper, tweet)
    return sum(uppercase_flags)

def count_words_surrounded_by_colons(tweet):
    """Count non-overlapping ``:word:`` spans in *tweet*.

    A span is one or more word characters (letters, digits, underscore) with a
    colon on each side. Each colon closes at most one span, so ":a::b:" counts
    as 2 while ":a:b:" counts as 1.
    """
    colon_wrapped_spans = re.finditer(r':(\w+):', tweet)
    return sum(1 for _ in colon_wrapped_spans)

def count_emojis(tweet):
    """Estimate the number of emoji in *tweet*.

    Each emoji is first rewritten to its ``:name:`` alias by ``emoji.demojize``.
    Then every ``:word:`` span is counted with count_words_surrounded_by_colons.

    NOTE(review): text that already looks like ``:word:`` (e.g. the ":30:" in
    "10:30:") is counted too. Aliases containing non-word characters (if any,
    e.g. '-') may be missed. Confirm the trained model expects exactly this
    behaviour before changing it.
    """
    return count_words_surrounded_by_colons(emoji.demojize(tweet))

def hashtag_frequency(tweet):
    """Return the number of ``#word`` hashtags (``#`` followed by word characters) in *tweet*."""
    hashtag_matches = re.finditer(r'#\w+', tweet)
    return sum(1 for _ in hashtag_matches)

def mention_frequency(tweet):
    """Return the number of ``@word`` mentions (``@`` followed by word characters) in *tweet*."""
    mention_matches = re.finditer(r'@\w+', tweet)
    return sum(1 for _ in mention_matches)

def count_special_characters(tweet):
    """Return how many characters of *tweet* are ASCII punctuation (``string.punctuation``)."""
    punctuation = string.punctuation
    return sum(1 for ch in tweet if ch in punctuation)


def stop_word_frequency(tweet):
    """Count whitespace-separated tokens of *tweet* that are English stop words.

    Matching is case-insensitive against NLTK's ``stopwords.words('english')``.
    Tokens are not stripped of punctuation, so e.g. "the," is compared as-is.
    """
    english_stop_words = set(stopwords.words('english'))
    return sum(1 for token in tweet.split() if token.lower() in english_stop_words)

# Fetch the NLTK data used by word_tokenize, pos_tag and stopwords below.
# Recent NLTK releases look up "punkt_tab" and "averaged_perceptron_tagger_eng"
# instead of the older "punkt" / "averaged_perceptron_tagger" packages, so
# without them word_tokenize / pos_tag raise LookupError. Requesting both
# generations keeps the app working across NLTK versions; whichever package
# a given version does not use is simply left unused.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('stopwords')

def get_linguistic_features(tweet):
    """Return eight part-of-speech counts for *tweet*.

    Processing steps:
    1. Tokenize with ``word_tokenize``.
    2. Keep alphanumeric tokens that are not in NLTK's English stop-word list,
       lower-cased.
    3. Tag the surviving tokens with ``pos_tag`` and count them by tag prefix.

    NOTE(review): stop words are removed *before* tagging. Presumably many
    pronouns, prepositions and conjunctions are in that list (TODO confirm),
    so those three counts are likely depressed. "Participle" means any verb tag
    whose word merely contains "ing" or "ed". Both quirks are kept as-is
    because these values are model inputs.

    The key order of the returned dict is significant: extract_features merges
    it positionally into the model's feature vector.
    """
    english_stop_words = set(stopwords.words('english'))
    content_words = [
        token.lower()
        for token in word_tokenize(tweet)
        if token.isalnum() and token.lower() not in english_stop_words
    ]
    tagged = pos_tag(content_words)

    def _count(predicate):
        # Number of (word, tag) pairs for which predicate(word, tag) holds.
        return sum(1 for word, tag in tagged if predicate(word, tag))

    return {
        'Noun_Count': _count(lambda w, t: t.startswith('N')),
        'Verb_Count': _count(lambda w, t: t.startswith('V')),
        'Participle_Count': _count(lambda w, t: t.startswith('V') and ('ing' in w or 'ed' in w)),
        'Interjection_Count': _count(lambda w, t: t == 'UH'),
        'Pronoun_Count': _count(lambda w, t: t.startswith('PRP')),
        'Preposition_Count': _count(lambda w, t: t.startswith('IN')),
        'Adverb_Count': _count(lambda w, t: t.startswith('RB')),
        'Conjunction_Count': _count(lambda w, t: t.startswith('CC')),
    }

def readability_score(tweet):
    """Return textstat's Flesch reading-ease score for *tweet*."""
    score = textstat.flesch_reading_ease(tweet)
    return score

# Compiled once at import. Note that "$-_" inside the class is a character
# *range* (0x24-0x5F). It covers digits, upper-case letters and most of
# "/:;=?@[]^" etc. rather than just three literal characters.
_URL_PATTERN = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')


def get_url_frequency(tweet):
    """Return the number of http:// or https:// URLs found in *tweet*."""
    return len(_URL_PATTERN.findall(tweet))


def extract_features(tweet):
    """Compute the 18 hand-crafted author/style features for one tweet.

    The dict's *insertion order* is part of the contract: load_model appends
    ``features.values()`` positionally into the model's side-feature vector,
    so keys must stay in exactly this order.
    """
    features = {}
    features['Average_Word_Length'] = average_word_length(tweet)
    features['Lexical_Diversity'] = lexical_diversity(tweet)
    features['Capital_Letters_Count'] = count_capital_letters(tweet)
    features['Hashtag_Frequency'] = hashtag_frequency(tweet)
    features['Mention_Frequency'] = mention_frequency(tweet)
    features['count_emojis'] = count_emojis(tweet)
    features['special_chars_count'] = count_special_characters(tweet)
    features['Stop_Word_Frequency'] = stop_word_frequency(tweet)
    # Eight part-of-speech counts (Noun_Count ... Conjunction_Count).
    features.update(get_linguistic_features(tweet))
    features['Readability_Score'] = readability_score(tweet)
    features['URL_Frequency'] = get_url_frequency(tweet)
    return features

# # Extract features for all tweets
# features_list = [extract_features(tweet) for tweet in X['text']]

# # Create a Pandas DataFrame
# X_new = pd.DataFrame(features_list)



# Loading personality model

@functools.lru_cache(maxsize=1)
def _load_personality_model(token):
    """Load the fine-tuned personality tokenizer/model once per access token."""
    tokenizer = AutoTokenizer.from_pretrained("Nasserelsaman/microsoft-finetuned-personality", token=token)
    model = AutoModelForSequenceClassification.from_pretrained("Nasserelsaman/microsoft-finetuned-personality", token=token)
    model.eval()
    return tokenizer, model


def personality_detection(text, threshold=0.05, endpoint= 1.0):
    """Return Big-Five personality probabilities for *text*.

    The model is downloaded and built on first use only and cached
    afterwards. The hub access token comes from the PERSONALITY_TOKEN
    environment variable and is deliberately never printed.

    Args:
        text: input sentence.
        threshold, endpoint: unused; kept for backward compatibility.

    Returns:
        A list of five 0-d NumPy arrays: the sigmoid of the first five logits
        of the (single) input. The label order is assumed to be Agreeableness,
        Conscientiousness, Extraversion, Neuroticism, Openness, as labelled in
        greet (TODO confirm against the model config).
    """
    token = os.environ.get('PERSONALITY_TOKEN', None)
    tokenizer, model = _load_personality_model(token)

    with torch.no_grad():
        inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
        # A single forward pass; sigmoid squashes each logit into [0, 1].
        probabilities = torch.sigmoid(model(**inputs).logits)

    return [score.numpy() for score in probabilities[0][:5]]


# tokenizer = AutoTokenizer.from_pretrained("Nasserelsaman/microsoft-finetuned-personality")
# model = AutoModelForSequenceClassification.from_pretrained("Nasserelsaman/microsoft-finetuned-personality")

#Loading emotion model

# tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-emotion-multilabel-latest")
# model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-emotion-multilabel-latest")

##use this for gpu
# pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-emotion-multilabel-latest", return_all_scores=True,device=device )

##use this for cpu
@functools.lru_cache(maxsize=1)
def _get_emotion_pipeline():
    """Build the multi-label emotion pipeline once and reuse it across calls.

    ``return_all_scores=True`` keeps the scores in the model's label order.
    Switching to ``top_k=None`` would sort them by score, which would break
    the positional label mapping in greet.
    """
    return pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-emotion-multilabel-latest", return_all_scores=True )


def calc_emotion_score(tweet):
    """Return the 11 emotion scores for *tweet* in model label order.

    The order matches greet's labels: Anger, Anticipation, Disgust, Fear, Joy,
    Love, Optimism, Pessimism, Sadness, Surprise, Trust. Input longer than the
    model's maximum length is truncated instead of failing.
    """
    emotions = _get_emotion_pipeline()(tweet, truncation=True)[0]
    return [entry['score'] for entry in emotions[:11]]
    





# DCL hate-speech classifier: model loading and single-tweet prediction

# Artefacts / hyper-parameters of the hate-speech classifier.
_DCL_BERT_MODEL_NAME = "vinai/bertweet-base"
_DCL_PADDING_MAX_LENGTH = 45
_DCL_CHECKPOINT_PATH = "./model.pt"
_DCL_DROPOUT = 0.5


class EmotionAuthorGuidedDCLModel(nn.Module):
    """BERTweet encoder plus a linear head over [pooled output, side features].

    The side-feature tensor ``batch_tokenized['emotion_author_vector']`` is
    concatenated to the encoder's pooled output (``bert_output[1]``) and fed
    to a single-logit linear layer. ``self.dim`` (802) is the width of that
    concatenation: 768 encoder dims + 11 emotion + 18 author + 5 personality.
    """

    def __init__(self, dcl_model: nn.Module, dropout: float = 0.5):
        super().__init__()
        # Attribute names (dcl_model, linear) are the state_dict key prefixes
        # stored in the checkpoint -- do not rename them.
        self.dcl_model = dcl_model
        self.dim = 802
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.dim, 1)
        # The encoder is frozen; only the linear head was trainable.
        for param in self.dcl_model.parameters():
            param.requires_grad = False

    def forward(self, batch_tokenized):
        input_ids = batch_tokenized['input_ids']
        attention_mask = batch_tokenized['attention_mask']
        emotion_vector = batch_tokenized['emotion_author_vector']
        bert_output = self.dcl_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        bert_cls_hidden_state = bert_output[1]
        combined_vector = torch.cat((bert_cls_hidden_state, emotion_vector), 1)
        linear_output = self.linear(self.dropout(combined_vector))
        return linear_output.squeeze(1)


@functools.lru_cache(maxsize=1)
def _load_dcl_components():
    """Build tokenizer and classifier once; reuse them for every request."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(_DCL_BERT_MODEL_NAME)
    # Only the BERTweet encoder is needed at inference time. Its weights, like
    # the head's, are replaced by the checkpoint, which is loaded with
    # load_state_dict's default strict key matching.
    encoder = AutoModel.from_pretrained(_DCL_BERT_MODEL_NAME)
    model = EmotionAuthorGuidedDCLModel(dcl_model=encoder, dropout=_DCL_DROPOUT)
    model.load_state_dict(torch.load(_DCL_CHECKPOINT_PATH, map_location='cpu'))
    model.to(device)
    model.eval()
    return tokenizer, model, device


def load_model(tweet):
    """Score *tweet* for hate speech.

    Despite its name this runs a full prediction. The heavy model loading
    happens only on the first call and is cached afterwards.

    Returns:
        A tuple ``(probability, emotion_scores, personality_scores)``:
        - ``probability``: a 1-element tensor holding the sigmoid output of
          the classifier, on the model device;
        - ``emotion_scores``: the list from calc_emotion_score;
        - ``personality_scores``: the list from personality_detection.
    """
    tokenizer, model, device = _load_dcl_components()
    inputs = tokenizer(tweet, truncation=True, padding='max_length', max_length=_DCL_PADDING_MAX_LENGTH, add_special_tokens=True, return_tensors="pt")

    emotion_scores = calc_emotion_score(tweet)
    author_features = extract_features(tweet)
    personality_scores = personality_detection(tweet)

    # Side-feature order is 11 emotions, 18 author features, 5 personality
    # scores. This is presumably the order used when the checkpoint was
    # trained, so don't reorder. Values are a mix of Python numbers and 0-d
    # NumPy arrays, hence the explicit float() / float32 conversion.
    side_features = [*emotion_scores, *author_features.values(), *personality_scores]
    inputs['emotion_author_vector'] = torch.tensor([[float(value) for value in side_features]], dtype=torch.float32)

    batch = {key: tensor.to(device) for key, tensor in inputs.items()}
    with torch.no_grad():
        probability = torch.sigmoid(model(batch))
    return probability, emotion_scores, personality_scores




#Gradio interface
# Empty initial values for the two BarPlot outputs of the Interface below.
# greet() builds and returns its own DataFrames; it does not rebind these.
simple = personality_values = None
def greet(tweet):
    """Gradio callback: classify *tweet* and build the five UI outputs.

    Returns:
        A tuple of five values:
        - label: "Hate" or "Non Hate"; the probability is rounded, so 0.5 maps
          to "Non Hate";
        - the hate-speech percentage string, e.g. "87.34%";
        - the non-hate percentage string, e.g. "12.66%";
        - a DataFrame of the 11 emotion scores;
        - a DataFrame of the 5 personality scores.
    """
    prediction, preemotion_list, personality_list = load_model(tweet)

    # Plain floats so the BarPlots get a numeric column (personality scores
    # arrive as 0-d NumPy arrays).
    emotion_frame = pd.DataFrame(
        {
            "Emotions": ["Anger", "Anticipation", "Disgust", "Fear", "Joy", "Love", "Optimism", "Pessimism", "Sadness","Surprise","Trust"],
            "Values": [float(value) for value in preemotion_list],
        }
    )
    personality_frame = pd.DataFrame(
        {
            "Personality": ['Agreeableness', 'Conscientiousness', 'Extraversion', 'Neuroticism', 'Openness'],
            "Values": [float(value) for value in personality_list],
        }
    )

    predicted_class = torch.round(prediction).item()
    label = "Non Hate" if predicted_class == 0.0 else "Hate"

    # The model outputs a probability in [0, 1]; scale it so the "%" suffix
    # is truthful (previously 0.87 was shown as "0.87%").
    hate_percentage = round(prediction.item() * 100, 2)
    non_hate_percentage = round(100 - hate_percentage, 2)

    return label, str(hate_percentage) + "%", str(non_hate_percentage) + "%", emotion_frame, personality_frame

# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# Gradio UI: one text input, five outputs (matching greet's return tuple).
_interface_theme = gr.themes.Soft()

_input_sentence_box = gr.Textbox(placeholder="Enter an input sentence...", label="Input Sentence")

_output_components = [
    gr.Label(label="Label"),
    gr.Textbox(label="Hate Speech Percentage"),
    gr.Textbox(label="Non Hate Speech Percentage"),
    gr.BarPlot(
        simple,
        x="Emotions",
        y="Values",
        title="Emotion Analysis",
        tooltip=["Emotions", "Values"],
        y_lim=[0, 1],
        label="Emotion bar graph",
    ),
    gr.BarPlot(
        personality_values,
        x="Personality",
        y="Values",
        title="Personality Analysis",
        tooltip=["Personality", "Values"],
        y_lim=[0, 1],
        label="Personality bar graph",
    ),
]

_example_sentences = [
    "I like you",
    "I hate you",
    "I can't stand those asian always causing trouble. They need to go back to where they came from!",
    "Just saw a Sunni preaching on the street. Why don't they go worship in their own country instead of invading ours?",
    "I wish all bisexuals would just disappear. Sick of their agenda being shoved in our faces",
    "Had a great time celebrating diversity at the multicultural festival today!",
    "Congratulations to Sri Lankans for their cultural contributions to our society",
    "Love is love, no matter who you are or who you love",
]

demo = gr.Interface(
    fn=greet,
    inputs=_input_sentence_box,
    outputs=_output_components,
    examples=[[sentence] for sentence in _example_sentences],
    title="Unmasking Hate: An Integrated Approach to Detecting Hate Speech in Social Media",
    theme=_interface_theme,
    allow_flagging="never",
)

# Start the (blocking) web server only when this file is run as a script,
# so importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()