import os

import gradio as gr
import torch
from aksharamukha import transliterate
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# Load environment variables so the Hugging Face token is available
load_dotenv()
access_token = os.getenv('token')

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

chat_language = 'sin_Sinh'  # NLLB (FLORES-200) code for Sinhala in Sinhala script

trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
eng_trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

translator = pipeline(
    'translation',
    model=trans_model,
    tokenizer=eng_trans_tokenizer,
    src_lang="eng_Latn",
    tgt_lang=chat_language,
    max_length=400,
    device=device,
)

# Sinhala -> English translation model and tokenizer (used via generate below)

sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained("thilina/mt5-sinhalese-english")
si_trans_tokenizer = AutoTokenizer.from_pretrained("thilina/mt5-sinhalese-english")

# Singlish (romanized Sinhala) -> Sinhala translation pipeline
singlish_pipe = pipeline("text2text-generation", model="Dhahlan2000/Simple_Translation-model-for-GPT-v15")
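
# Optional smoke test (illustrative only): confirms the pipelines load and
# produce text. Uncomment after the model weights have downloaded.
# print(translator("Hello, how are you?")[0]["translation_text"])
# print(singlish_pipe("translate Singlish to Sinhala: kohomada")[0]["generated_text"])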

# Translation functions
def translate_Singlish_to_sinhala(text):
    """Translate romanized Singlish input into Sinhala script."""
    translated_text = singlish_pipe(
        f"translate Singlish to Sinhala: {text}",
        clean_up_tokenization_spaces=False,
    )[0]['generated_text']
    # Strip zero-width joiners the model sometimes emits
    return translated_text.replace('\u200d', '')

def translate_english_to_sinhala(text):
    # Translate line by line to stay within the model's max_length
    parts = text.split("\n")
    translated_parts = []
    for part in parts:
        translated_part = translator(part, clean_up_tokenization_spaces=False)[0]['translation_text']
        translated_parts.append(translated_part)
    # Join the translated parts, then strip a recurring mistranslation artifact
    # and zero-width joiners
    translated_text = "\n".join(translated_parts)
    return translated_text.replace("ප් රභූවරුන්", "").replace('\u200d', '')

def translate_sinhala_to_english(text):
    # Split the text into sentences or paragraphs
    parts = text.split("\n")  # Split by new lines for paragraphs, adjust as needed
    translated_parts = []
    for part in parts:
        # Tokenize each part
        inputs = si_trans_tokenizer(part.strip(), return_tensors="pt", padding=True, truncation=True, max_length=512)
        # Generate translation
        outputs = sin_trans_model.generate(**inputs)
        # Decode translated text while preserving formatting
        translated_part = si_trans_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        translated_parts.append(translated_part)
    # Join the translated parts back together
    translated_text = "\n".join(translated_parts)
    return translated_text
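
# Note (assumption): sin_trans_model runs on the CPU as written; on a GPU it
# could be moved with sin_trans_model.to(device), with the tokenized inputs
# moved alongside it (inputs = inputs.to(device)) before generate().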

def transliterate_from_sinhala(text):
    # Define the source and target scripts
    source_script = 'Sinhala'
    target_script = 'Velthuis'

    # Transliterate to a romanized (Velthuis) representation
    latin_text = transliterate.process(source_script, target_script, text)

    # Remove periods, asterisks, and double quotes introduced by the
    # transliteration scheme
    latin_text = latin_text.translate(str.maketrans('', '', '.*"'))

    return latin_text.lower()

def transliterate_to_sinhala(text):
    # Define the source and target scripts
    source_script = 'Velthuis'
    target_script = 'Sinhala'

    # Transliterate back to Sinhala script
    sinhala_text = transliterate.process(source_script, target_script, text)
    return sinhala_text
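
# Illustrative round trip (the exact romanization depends on the aksharamukha
# version, so treat the intermediate string as opaque):
# romanized = transliterate_from_sinhala("ආයුබෝවන්")
# sinhala_again = transliterate_to_sinhala(romanized)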

# Load the Gemma chat model (a gated repo on the Hub, hence the access token)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,
).to(device)

def conversation_predict(input_text):
    # tokenizer(...) returns a dict of tensors; move them to the model's device
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Without max_new_tokens, generate() defaults to ~20 new tokens, which
    # truncates chat replies; 256 is an arbitrary but workable cap
    outputs = model.generate(**inputs, max_new_tokens=256)

    # Decode only the newly generated tokens, skipping the echoed prompt and
    # special tokens
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
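
# Sketch (assumption, not wired into the app): gemma-2b-it is an
# instruction-tuned chat model, so formatting the prompt with the tokenizer's
# chat template usually yields better replies than raw text:
def conversation_predict_with_template(input_text):
    messages = [{"role": "user", "content": input_text}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)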

def ai_predicted(user_input):
    # Singlish -> Sinhala script
    user_input = translate_Singlish_to_sinhala(user_input)
    user_input = transliterate_to_sinhala(user_input)
    print("You(Sinhala): ", user_input, "\n")

    # Sinhala -> English for the chat model
    user_input = translate_sinhala_to_english(user_input)
    print("You(English): ", user_input, "\n")

    # Get the AI response in English
    ai_response = conversation_predict(user_input)
    print("AI(English): ", ai_response, "\n")

    # English -> Sinhala -> romanized Sinhala for the reply
    response = translate_english_to_sinhala(ai_response)
    print("AI(Sinhala): ", response, "\n")
    response = transliterate_from_sinhala(response)
    print(response)

    return response
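
# Example usage (illustrative): a Singlish greeting passes through the full
# chain and returns a romanized Sinhala reply.
# print(ai_predicted("kohomada"))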

# Gradio Interface
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # History and generation settings are collected for the UI, but the current
    # pipeline only answers the latest message; system_message, max_tokens,
    # temperature, and top_p are not yet wired into conversation_predict()
    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ai_predicted(message)

    yield response

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)