import base64
import mimetypes
import os
import re

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from openai import OpenAI
from PIL import Image

from utils import load_json_file, str2time

def get_smallest_timestamp(timestamps):
    """Return the earliest timestamp string in ``timestamps``.

    Timestamps are compared by their ``str2time`` value; the original string
    (not the converted value) is returned. On ties the first occurrence wins.

    Args:
        timestamps: Non-empty sequence of timestamp strings accepted by
            ``str2time``.

    Returns:
        The element of ``timestamps`` with the smallest ``str2time`` value.

    Raises:
        ValueError: If ``timestamps`` is empty.
    """
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if not timestamps:
        raise ValueError("timestamps must be a non-empty sequence")
    # min() returns the first minimal element on ties, like the old manual loop.
    return min(timestamps, key=str2time)

def generate(query, context, relevant_timestamps=None):
    """Answer ``query`` about the video from the retrieved ``context``.

    If ``relevant_timestamps`` is non-empty, the earliest one (as chosen by
    ``get_smallest_timestamp``) is appended to the model's reply after a
    single space, so the user knows where the answer first appears.
    """
    template = (
        "You're a helpful LLM assistant in answering questions regarding a video. "
        "Given contexts are segments relevant to the question, please answer the question. "
        "Do not refer to segments. Context: {context}, question: {question} \nA:"
    )
    prompt = PromptTemplate(input_variables=["question", "context"], template=template)
    chain = LLMChain(llm=ChatOpenAI(model_name="gpt-4o-mini", temperature=0), prompt=prompt)
    answer = chain.run(question=query, context=context)

    # Guard clause: nothing to cite, return the bare answer.
    if relevant_timestamps is None or len(relevant_timestamps) == 0:
        return answer
    return f"{answer} {get_smallest_timestamp(relevant_timestamps)}"


def check_relevance(query, relevant_metadatas):
    """Ask the LLM which retrieved segments are relevant to ``query``.

    Each segment is rendered as ``Segment {i}: transcript=... caption=...``.
    The model is asked to reply with a comma-separated list of ``yes``/``no``
    grades, one per segment, in order.

    Args:
        query: The user's question.
        relevant_metadatas: Segment metadata dicts. Each one must provide
            ``transcript`` and ``caption``. Segments graded relevant must also
            provide ``start_time``.

    Returns:
        A ``(context, timestamps)`` tuple:
        - ``context`` is the rendered text of the segments graded ``yes``,
          keeping their original ``Segment {i}`` numbering.
        - ``timestamps`` holds the ``start_time`` of each of those segments,
          in segment order.
    """
    if not relevant_metadatas:
        # Nothing to grade; skip the LLM round-trip.
        return "", []

    segment_lines = [
        f"Segment {i}: transcript={frame['transcript']} caption={frame['caption']}\n"
        for i, frame in enumerate(relevant_metadatas)
    ]
    context = "".join(segment_lines)

    prompt = PromptTemplate(input_variables=["question", "context"], template="""
    You are a grader assessing relevance of a retrieved video segment to a user question. \n 
    If the video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' score to indicate whether the video segment is relevant to the question. \n
    Answer in a string, separated by commas. For example: if there are segments provided, answer: yes,no,no,yes. \n
    Question: {question} Context: {context}\n A:""")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)

    # zip() stops at the shorter sequence, so a reply with more grades than
    # segments no longer raises IndexError. Each grade is normalised so that
    # "Yes", " yes." or "'yes'" are all accepted.
    relevant = [
        (line, frame['start_time'])
        for line, frame, grade in zip(segment_lines, relevant_metadatas, response.split(','))
        if grade.strip().strip(".'\"").strip().lower() == 'yes'
    ]
    actual_relevant_context = "".join(line for line, _ in relevant)
    relevant_timestamps = [start for _, start in relevant]
    return actual_relevant_context, relevant_timestamps


def retrieve_segments_from_timestamp(metadatas, timestamps, buffer_ms=5000):
    """Return the segments that overlap a window around any of ``timestamps``.

    A segment matches a timestamp ``t`` when its closed interval
    ``[start_time, end_time]`` intersects ``[t - buffer_ms, t + buffer_ms]``.

    Args:
        metadatas: Segment dicts with ``start_time`` and ``end_time`` strings
            accepted by ``str2time``.
        timestamps: Timestamp strings accepted by ``str2time``.
        buffer_ms: Half-width of the search window, in the units returned by
            ``str2time``. The default of 5000 was previously documented as
            5 seconds, i.e. milliseconds.

    Returns:
        Matching segments in their order within ``metadatas``. A segment is
        listed once even when several timestamps hit it.
    """
    if not timestamps:
        return []

    # Convert each requested time once, outside the segment loop.
    targets = [str2time(timestamp) for timestamp in timestamps]

    relevant_segments = []
    for segment in metadatas:
        start = str2time(segment['start_time'])
        end = str2time(segment['end_time'])
        # any() also deduplicates: a segment is appended at most once.
        if any(start <= t + buffer_ms and end >= t - buffer_ms for t in targets):
            relevant_segments.append(segment)

    return relevant_segments


# "MM:SS" or "H:MM:SS"/"HH:MM:SS", optionally with fractional seconds.
_TIMESTAMP_RE = re.compile(r"^(?:(\d{1,2}):)?(\d{1,2}):(\d{2})(?:\.\d+)?$")


def _split_question_and_timestamps(response):
    """Split an LLM reply of the form ``question,ts1,ts2`` into its parts.

    Fields that look like timestamps are normalised to ``HH:MM:SS.00``.
    Every other field is re-joined with commas as the question, so a
    question that itself contains commas is not truncated.
    """
    question_parts = []
    timestamps = []
    for field in response.split(','):
        # Tolerate a trailing sentence period after the last timestamp.
        match = _TIMESTAMP_RE.match(field.strip().rstrip('.'))
        if match is None:
            question_parts.append(field)
            continue
        hours, minutes, seconds = match.groups()
        timestamps.append(f"{int(hours or 0):02d}:{int(minutes):02d}:{seconds}.00")
    return ",".join(question_parts).strip(), timestamps


def check_timestamps(query):
    """Use the LLM to separate timestamps mentioned in ``query`` from the question.

    Args:
        query: The user's question, possibly mentioning times such as "4:20".

    Returns:
        A ``(question, timestamps)`` tuple.
        - If at least one timestamp is detected, ``question`` is the LLM's
          restated question (or ``query`` when that is empty) and
          ``timestamps`` holds ``HH:MM:SS.00`` strings.
        - Otherwise the original ``query`` is returned with an empty list.
    """
    prompt = PromptTemplate(input_variables=["question"], template="You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. Please detect the question and timestamp in the the following question and separated them by commas such as question,timestamp1,timestamp2 if timestamps are provided else just question. Question: {question} \nA:")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query)

    detected_question, timestamps = _split_question_and_timestamps(response)
    if not timestamps:
        # No real timestamp found: keep the user's own wording untouched.
        return query, []
    return detected_question or query, timestamps

def retrieve_by_embedding(index, video_path, query, text_model, top_k=5):
    """Return metadata of the ``top_k`` segments of ``video_path`` closest to ``query``.

    Args:
        index: Vector index exposing ``query`` and ``fetch`` (Pinecone-style).
        video_path: Only vectors whose ``video_path`` metadata equals this
            value are searched.
        query: Natural-language question to embed.
        text_model: Encoder whose ``encode(query)`` result has ``tolist()``.
        top_k: Number of nearest segments to request from the index.

    Returns:
        Metadata of the matched vectors, in the order of ``res['matches']``.
        Ids the fetch does not return are skipped.
    """
    query_embedding = text_model.encode(query)

    res = index.query(vector=query_embedding.tolist(), top_k=top_k, filter={"video_path": {"$eq": video_path}})

    match_ids = [match_['id'] for match_ in res['matches']]
    if not match_ids:
        return []

    # One batched fetch instead of one round-trip per match.
    vectors = index.fetch(ids=match_ids).vectors

    metadatas = []
    for match_id in match_ids:
        vector_data = vectors.get(match_id)
        # Previously a missing id fell back to {} and crashed on `.metadata`.
        if vector_data is not None:
            metadatas.append(vector_data.metadata)
    return metadatas

def self_reflection(query, answer, summary):
    """Ask the LLM whether ``answer`` satisfactorily addresses ``query``.

    The video ``summary`` is provided as background. The model is asked for
    a "yes"/"no" verdict; its raw reply string is returned unparsed.
    """
    reflection_template = (
        "You're a helpful LLM assistant. "
        "You're good at determining if the provided answer is satisfactory to a question relating to a video. "
        "You have access to the video summary as follows: {summary}. "
        "Given a pair of question and answer, give the answer's satisfactory score in either yes or no. "
        "Question: {question}, Answer: {answer} \nA:"
    )
    reflection_chain = LLMChain(
        prompt=PromptTemplate(input_variables=["summary", "question", "answer"], template=reflection_template),
        llm=ChatOpenAI(model_name="gpt-4o-mini", temperature=0),
    )
    return reflection_chain.run(summary=summary, question=query, answer=answer)


def get_full_transcript(metadatas):
    """Concatenate the ``transcript`` of every segment into one string.

    Whitespace is normalised:
    - leading and trailing whitespace is removed;
    - runs of spaces, tabs or newlines become a single space;
    - empty transcripts contribute nothing.

    Args:
        metadatas: Iterable of segment dicts with a ``transcript`` string.

    Returns:
        The full transcript text, or ``""`` when there is none.
    """
    # str.split() with no argument splits on any whitespace run. This also
    # fixes runs of 3+ spaces, which a single replace("  ", " ") left behind,
    # and avoids double spaces around empty transcripts.
    return " ".join(
        word
        for frame in metadatas
        for word in frame['transcript'].split()
    )

def summarize_video(metadatas_path: str):
    """Summarise a video with the LLM from the segments stored at ``metadatas_path``.

    The JSON file is read with ``load_json_file`` and flattened into a single
    transcript by ``get_full_transcript``. The model's raw reply is returned.
    """
    full_transcript = get_full_transcript(load_json_file(metadatas_path))

    summary_prompt = PromptTemplate(
        input_variables=["transcript"],
        template="You're a helpful LLM assistant. Please provide a summary for the video given its full transcript: {transcript} \nA:",
    )
    summary_chain = LLMChain(llm=ChatOpenAI(model_name="gpt-4o-mini", temperature=0), prompt=summary_prompt)
    return summary_chain.run(transcript=full_transcript)

def answer_wrt_timestamp(query, context):
    """Answer a time-anchored ``query`` using the segments in ``context``.

    The prompt includes a one-shot example showing the expected answer style
    ("At 4:20, ..."). The model's raw reply is returned.
    """
    timestamp_template = """
    You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
    For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
    A:"""
    timestamp_chain = LLMChain(
        prompt=PromptTemplate(input_variables=["question", "context"], template=timestamp_template),
        llm=ChatOpenAI(model_name="gpt-4o-mini", temperature=0),
    )
    return timestamp_chain.run(question=query, context=context)


def answer_question(index, model_stack, metadatas_path, video_summary:str, video_path:str, query:str, image_input_path:str=None):
    """Answer a natural-language ``query`` about the video at ``video_path``.

    Routing:

    * ``image_input_path`` given: delegate to ``answer_image_question``.
    * ``check_timestamps`` detects timestamps in the query: answer from the
      segments around them with ``answer_wrt_timestamp``.
    * otherwise:
      - retrieve segments by text embedding (``model_stack[0]`` is the text
        encoder);
      - keep the ones the LLM grades relevant and generate an answer;
      - if a self-reflection pass judges the answer unsatisfactory,
        regenerate with the video summary appended to the context.

    Args:
        index: Vector index passed through to the retrieval helpers.
        model_stack: Model tuple passed through to the retrieval helpers.
        metadatas_path: Path of the JSON file with per-segment metadata.
        video_summary: Pre-computed summary of the whole video.
        video_path: Identifies the video inside the index.
        query: The user's question.
        image_input_path: Optional path of a query image.

    Returns:
        The answer text.
    """
    metadatas = load_json_file(metadatas_path)
    if image_input_path is not None:
        return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)

    # Time-anchored questions are answered directly from the matching segments.
    query, timestamps = check_timestamps(query)
    if timestamps:
        segments = retrieve_segments_from_timestamp(metadatas, timestamps)
        context = "".join(
            f"Segment {i}: transcript={segment['transcript']} caption={segment['caption']}\n"
            for i, segment in enumerate(segments)
        )
        return answer_wrt_timestamp(query, context)

    # Retrieve by embedding, then filter with an LLM relevance check.
    relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])
    actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)

    answer = generate(query, actual_relevant_context, relevant_timestamps)

    # Self-reflection. The model may reply "No", "no." or "No, because ...",
    # so compare the first word with punctuation stripped rather than the
    # whole reply.
    reflect = self_reflection(query, answer, video_summary)
    first_word = next(iter(reflect.lower().split()), "")
    if first_word.strip(".,!:;\"'") == 'no':
        answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")

    return answer

def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path, top_k=5):
    """Return metadata of the ``top_k`` segments of ``video_path`` closest to an image.

    Args:
        index: Vector index exposing ``query`` and ``fetch`` (Pinecone-style).
        video_path: Only vectors whose ``video_path`` metadata equals this
            value are searched.
        model_stack: 5-tuple. Only item 1 (vision model) and item 2 (its
            processor) are used.
        image_query_path: Path of the query image.
        top_k: Number of nearest segments to request from the index.

    Returns:
        Metadata of the matched vectors, in the order of ``res['matches']``.
        Ids the fetch does not return are skipped.
    """
    # torch is already required by return_tensors="pt"; imported only for no_grad().
    import torch

    _, vision_model, vision_model_processor, _, _ = model_stack

    # Close the image file once the processor has read the pixels.
    with Image.open(image_query_path) as image_query:
        inputs = vision_model_processor(images=image_query, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = vision_model(**inputs)

    # NOTE(review): pooler_output is presumably batched, (1, hidden), for a
    # single image. Flatten so the index receives a flat list of floats;
    # reshape(-1) is a no-op if it is already 1-D.
    image_query_embeds = outputs.pooler_output.reshape(-1)

    res = index.query(vector=image_query_embeds.tolist(), top_k=top_k, filter={"video_path": {"$eq": video_path}})

    match_ids = [match_['id'] for match_ in res['matches']]
    if not match_ids:
        return []

    # One batched fetch instead of one round-trip per match.
    vectors = index.fetch(ids=match_ids).vectors

    metadatas = []
    for match_id in match_ids:
        vector_data = vectors.get(match_id)
        if vector_data is not None:
            metadatas.append(vector_data.metadata)
    return metadatas


def answer_image_question(index, model_stack, metadatas, video_summary:str, video_path:str, query:str, image_query_path:str=None):
    """Answer ``query`` about the video, using an image as the retrieval key.

    Segments are retrieved by image-embedding similarity. They are then passed,
    together with the image itself, to ``generate_w_image``. ``metadatas`` and
    ``video_summary`` are accepted for interface compatibility but not used.
    """
    return generate_w_image(
        query,
        image_query_path,
        retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path),
    )


def encode_image(image_path):
    """Return the raw bytes of the file at ``image_path`` as a base64 ``str``."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def generate_w_image(query:str, image_query_path:str, relevant_metadatas, max_tokens:int=100):
    """Answer ``query`` with a vision-capable chat model, given an image and segment context.

    Args:
        query: The user's question.
        image_query_path: Path of the image. It is sent inline as a base64
            data URL.
        relevant_metadatas: Segment dicts with ``transcript`` and ``caption``
            keys, rendered into a text context.
        max_tokens: Upper bound on the reply length passed to the API.

    Returns:
        The model's reply text.
    """
    base64_image = encode_image(image_query_path)

    context = "".join(
        f"Segment {i}: transcript={frame['transcript']} caption={frame['caption']}\n"
        for i, frame in enumerate(relevant_metadatas)
    )

    # Declare the image's real type. Fall back to PNG, the previous
    # hard-coded value, when the extension is unknown or not an image.
    mime_type, _ = mimetypes.guess_type(image_query_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": f"Here is some context about the image: {context}"},
                # Must be an f-string: previously "{query}" was sent literally
                # and the model never saw the user's question.
                {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the context surrounding the frames most correlated with the image and image, please answer the question. Question: {query}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                    }
                }
            ]}
        ],
        temperature=0.0,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content