import streamlit as st
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure the page before any other Streamlit command runs.
st.set_page_config(page_title="Coding Multiple Choice Q&A", layout="wide")

# Hugging Face Hub id of the fine-tuned checkpoint used by this app.
MODEL_PATH = "tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"

from coding_examples import CODING_EXAMPLES_BY_CATEGORY

# Flat list of every example, in category order. Each example dict is also
# tagged in place with a "category" key naming the group it came from, so
# code working on the flat list can recover the category.
CODING_EXAMPLES = []
for _category_name, _category_examples in CODING_EXAMPLES_BY_CATEGORY.items():
    for _example in _category_examples:
        _example["category"] = _category_name
        CODING_EXAMPLES.append(_example)

class PromptCreator:
    """Build prompts asking the model to answer a multiple-choice question
    with a YAML-structured explanation.

    Args:
        prompt_type: Stored as ``self.prompt_type``. The methods below always
            produce the YAML-style prompt regardless of its value.
    """

    def __init__(self, prompt_type="yaml"):
        self.prompt_type = prompt_type

    def format_choices(self, choices):
        """Render ``choices`` as lettered lines ("A. ...", "B. ...", ...).

        A string is treated as already formatted and returned unchanged.
        A falsy value (empty list, empty string, None) yields "".
        """
        if not choices:
            return ""
        if isinstance(choices, str):
            return choices
        labelled = []
        for offset, text in enumerate(choices):
            labelled.append(f"{chr(ord('A') + offset)}. {text}")
        return "\n".join(labelled)

    def get_max_letter(self, choices):
        """Return the letter of the last choice, or "A" when there are none.

        For a string, every non-blank line counts as one choice. Otherwise
        the length of the sequence is used.
        NOTE: more than 26 choices maps past "Z" to non-letter characters.
        """
        if not choices:
            return "A"
        if not isinstance(choices, str):
            return chr(ord("A") - 1 + len(choices))
        non_blank = sum(1 for row in choices.split("\n") if row.strip())
        return chr(ord("A") - 1 + non_blank) if non_blank else "A"

    def create_inference_prompt(self, question, choices):
        """Return the full inference prompt for ``question`` and ``choices``.

        The prompt lists the lettered choices and instructs the model to reply
        in YAML with understanding/analysis/reasoning/conclusion/answer keys.
        Returns "" when ``question`` is empty.
        """
        if not question:
            return ""
        choice_block = self.format_choices(choices)
        last_letter = self.get_max_letter(choices)

        return f"""Question: {question}

Choices:
{choice_block}

Analyze this question step-by-step and provide a detailed explanation.
Your response MUST be in YAML format as follows:

understanding: |
  <your understanding of what the question is asking>
analysis: |
  <your analysis of each option>
reasoning: |
  <your step-by-step reasoning process>
conclusion: |
  <your final conclusion>
answer: <single letter A through {last_letter}>

The answer field MUST contain ONLY a single character letter."""

class QwenModelHandler:
    """Load a fine-tuned PEFT adapter on top of its base causal LM and generate text.

    Loading progress and load errors are surfaced in the Streamlit UI.
    """

    def __init__(self, model_path, base_model_path="Qwen/Qwen2.5-Coder-1.5B-Instruct"):
        """Load the tokenizer, the base model and the PEFT adapter.

        Args:
            model_path: Hub id or local path of the PEFT adapter. The tokenizer
                is loaded from the same location.
            base_model_path: Hub id or local path of the base model that the
                adapter is applied to.

        Raises:
            Exception: whatever the loaders raise. The message is also shown
                with ``st.error`` before re-raising.
        """
        with st.spinner("Loading model..."):
            try:
                # peft is only needed once a model is actually loaded, so it is
                # imported here rather than at module level.
                from peft import PeftModel

                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                )
                # No torch_dtype/device_map is passed, so transformers' defaults apply.
                base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
                # The adapter comes from ``model_path`` (it used to be hard-coded,
                # silently ignoring the argument for everything but the tokenizer).
                self.model = PeftModel.from_pretrained(base_model, model_path)
                # Inference only: make sure dropout layers are disabled.
                self.model.eval()

                # Ensure a pad token exists; fall back to EOS when none is defined.
                if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
            except Exception as e:
                st.error(f"Error: {str(e)}")
                raise

    @staticmethod
    def _decoding_kwargs(temperature, top_p, top_k, do_sample):
        """Return the sampling-related keyword arguments for ``generate``.

        transformers rejects a non-positive temperature when sampling, and the
        UI allows 0.0, so ``temperature <= 0`` is treated as a request for
        greedy decoding. Sampling knobs are only passed when sampling, which
        avoids transformers' warnings about sampling flags set while
        ``do_sample=False``.
        """
        if not do_sample or (temperature is not None and temperature <= 0):
            return {"do_sample": False}
        return {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        }

    def generate_response(self, prompt, max_tokens=512, temperature=0.7,
                          top_p=0.9, top_k=50, repetition_penalty=1.0,
                          do_sample=True):
        """Generate a completion for ``prompt``.

        Args:
            prompt: Raw prompt text fed to the tokenizer.
            max_tokens: Maximum number of newly generated tokens.
            temperature: Sampling temperature; <= 0 selects greedy decoding.
            top_p: Nucleus-sampling probability mass (sampling only).
            top_k: Top-k cutoff (sampling only).
            repetition_penalty: Passed through to ``generate``.
            do_sample: Whether to sample; False selects greedy decoding.

        Returns:
            The generated text without the prompt, stripped of surrounding
            whitespace. On failure, returns "Error during generation: <message>"
            instead of raising.
        """
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            prompt_len = inputs["input_ids"].shape[-1]
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **self._decoding_kwargs(temperature, top_p, top_k, do_sample),
                )
            # Decode only the new tokens. Decoding the whole sequence and then
            # string-matching the prompt fails whenever detokenisation does not
            # reproduce the prompt text exactly.
            new_tokens = outputs[0][prompt_len:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            return f"Error during generation: {str(e)}"

def create_prompt(question, choices):
    """Build the YAML-style inference prompt without needing a loaded model.

    Args:
        question: The question text; an empty value yields "".
        choices: A list of choice strings, or a pre-formatted string.

    Returns:
        The prompt produced by a fresh ``PromptCreator(prompt_type="yaml")``.
    """
    builder = PromptCreator(prompt_type="yaml")
    prompt_text = builder.create_inference_prompt(question, choices)
    return prompt_text

def _select_example():
    """Render the category/example pickers and return the chosen example's text.

    Returns:
        ``(question, choices)`` for the selected example, with the choices
        rendered as lettered lines, or ``("", "")`` when nothing is selected.
    """
    category_options = ["All Categories"] + list(CODING_EXAMPLES_BY_CATEGORY.keys())
    selected_category = st.selectbox("Select a category", category_options)

    # Indices into the flat CODING_EXAMPLES list. The displayed number is always
    # index + 1, so an example keeps its number regardless of the category filter.
    visible = [
        idx for idx, ex in enumerate(CODING_EXAMPLES)
        if selected_category == "All Categories" or ex["category"] == selected_category
    ]
    selected_idx = st.selectbox(
        "Select an example question",
        [None] + visible,
        format_func=lambda idx: (
            "" if idx is None else f"Example {idx + 1}: {CODING_EXAMPLES[idx]['question']}"
        ),
    )
    if selected_idx is None:
        return "", ""
    example = CODING_EXAMPLES[selected_idx]
    # Reuse the prompt builder's lettering so both stay in sync.
    return example["question"], PromptCreator().format_choices(example.get("choices"))


def _parse_yaml_output(text):
    """Parse model output as a YAML mapping.

    Returns:
        The parsed dict, or None if the text is not valid YAML or does not
        parse to a mapping. A surrounding Markdown code fence (```yaml ... ```)
        is removed first, in case the model wraps its answer in one.
    """
    body = text
    stripped = text.strip()
    if stripped.startswith("```"):
        lines = stripped.splitlines()[1:]  # drop the opening fence (``` or ```yaml)
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        body = "\n".join(lines)
    try:
        data = yaml.safe_load(body)
    except yaml.YAMLError:
        return None
    return data if isinstance(data, dict) else None


def _render_model_output(text):
    """Show the raw model output and, when it is a YAML mapping, its parsed form."""
    with st.expander("Raw Output"):
        st.code(text, language="yaml")

    parsed = _parse_yaml_output(text)
    if parsed is None:
        st.warning("Could not parse output as YAML")
    else:
        with st.expander("Parsed Output", expanded=True):
            st.json(parsed)


def main():
    """Render the Streamlit page.

    The page has an example picker, question/choices inputs, generation
    parameters, load/generate buttons, the prompt preview and the model output.
    Session state keys used: ``model_loaded``, ``model_output``,
    ``model_handler`` and ``prompt_creator``.
    """
    # Initialize session state on the first run of the script.
    if 'model_loaded' not in st.session_state:
        st.session_state.model_loaded = False
    if 'model_output' not in st.session_state:
        st.session_state.model_output = ""

    st.title("Coding Multiple Choice Q&A with YAML Reasoning")
    st.warning("⚠️ Running on CPU - model loading and inference will be slow")

    # Two-column layout: inputs/controls on the left, prompt/output on the right.
    col1, col2 = st.columns([4, 6])

    with col1:
        st.subheader("Examples")
        question, choices = _select_example()

        st.subheader("Your Question")
        question_input = st.text_area("Question", value=question, height=100,
                                      placeholder="Enter your coding question here...")
        choices_input = st.text_area("Choices", value=choices, height=150,
                                     placeholder="Enter each choice on a new line...")

        # Model parameters
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1)
        with st.expander("Advanced Parameters"):
            max_tokens = st.slider("Max Tokens", 128, 1024, 512, 128)
            top_p = st.slider("Top-p", 0.1, 1.0, 0.9, 0.1)
            top_k = st.slider("Top-k", 1, 100, 50, 10)
            repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1, 0.1)
            do_sample = st.checkbox("Enable Sampling", True)

        if not st.session_state.model_loaded and st.button("Load Model", type="primary"):
            try:
                handler = QwenModelHandler(MODEL_PATH)
            except Exception as e:
                st.error(f"Failed to load model: {str(e)}")
            else:
                st.session_state.model_handler = handler
                st.session_state.prompt_creator = PromptCreator("yaml")
                st.session_state.model_loaded = True
                # Rerun outside the try so the rerun signal is never swallowed.
                st.rerun()

        if st.session_state.model_loaded:
            generate_button = st.button("Generate Response", type="primary")
        else:
            st.info("Please load the model first")
            generate_button = False

    with col2:
        st.subheader("Model Input")
        if question_input and choices_input:
            prompt_preview = create_prompt(question_input, choices_input)
        else:
            prompt_preview = ""
        st.text_area("Prompt", value=prompt_preview, height=200, disabled=True)

        st.subheader("Model Response")
        st.text_area("Response", value=st.session_state.model_output, height=300)

        if st.session_state.model_output:
            _render_model_output(st.session_state.model_output)

    # Generation runs after both columns render, so the page is rerun afterwards
    # to display the new output in col2.
    if generate_button and st.session_state.model_loaded:
        if not question_input or not choices_input:
            st.error("Please provide both a question and choices.")
        else:
            try:
                prompt = st.session_state.prompt_creator.create_inference_prompt(question_input, choices_input)
                with st.spinner("Generating response..."):
                    response = st.session_state.model_handler.generate_response(
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        repetition_penalty=repetition_penalty,
                        do_sample=do_sample,
                    )
            except Exception as e:
                st.error(f"Error generating response: {e}")
            else:
                st.session_state.model_output = response
                # st.experimental_rerun is deprecated/removed; st.rerun matches the load path.
                st.rerun()


if __name__ == "__main__":
    main()