import os
import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# Set cache directory and get token
os.environ['HF_HOME'] = '/app/cache'
hf_token = os.getenv('HF_TOKEN')

class StopWordCriteria(StoppingCriteria):
    """Stop generation once the stop word's token sequence appears at the end of the output."""

    def __init__(self, tokenizer, stop_word):
        super().__init__()
        self.stop_word_ids = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the running sequence against the stop word's token ids
        n = len(self.stop_word_ids)
        return len(input_ids[0]) >= n and input_ids[0][-n:].tolist() == self.stop_word_ids

def load_model():
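    """Load the fine-tuned tokenizer and PEFT model, caching downloads under /app/cache."""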
    try:
        # Ensure cache directory exists
        cache_dir = '/app/cache'
        os.makedirs(cache_dir, exist_ok=True)
        
        # Check for HF token
        if not hf_token:
            st.warning("HuggingFace token not found. Some models may not be accessible.")
        
        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model for generating scripts
        model_name = "Sidharthan/gemma2_scripter"
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                token=hf_token,
                cache_dir=cache_dir
            )
        except Exception as e:
            st.error(f"Error loading tokenizer: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            raise
        
        try:
            # Load model with appropriate device settings
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_name,
                device_map=None,  # We'll handle device placement manually
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                token=hf_token,
                cache_dir=cache_dir
            )
            
            # Move model to device
            model = model.to(device)
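            model.eval()  # from_pretrained generally returns an eval-mode model; being explicit costs nothing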
            
            return model, tokenizer

        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            elif "disk space" in str(e).lower():
                st.error("Insufficient disk space in cache directory.")
            raise

    except Exception as e:
        st.error(f"General error during model loading: {str(e)}")
        raise

def generate_script(tags, model, tokenizer, params):
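    """Generate a script for comma-separated `tags`; `params` holds the sampling settings from the sidebar."""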
    device = next(model.parameters()).device
    
    # Build the prompt in the model's fine-tuned turn format
    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"

    # Tokenize and move tensors to the model's device; the template already
    # contains a literal <bos>, so skip the tokenizer's automatic special
    # tokens to avoid a duplicated BOS token.
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # The turn header 'script' doubles as a stop signal: halt as soon as the
    # model starts to emit it again.
    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])
    
    try:
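        # Note: max_length counts the prompt tokens as well as the new ones;
        # switch to max_new_tokens if outputs come back shorter than expected.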
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )
        
        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Strip the echoed prompt (everything from 'keywords' through the 'script' header)
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        # Drop the trailing stop word and anything the model emitted after it
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
        
        return response
        
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

def main():
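    """Streamlit UI: collect tags and generation parameters, then display the generated script."""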
    st.title("🎥 YouTube Script Generator")
    
    # Sidebar for model parameters
    st.sidebar.title("Generation Parameters")
    params = {
        'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
        'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
        'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
        'top_k': st.sidebar.slider('Top K', 1, 100, 50),
        'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
    }
    
    # Cache the expensive model load so Streamlit reruns reuse the same weights
    @st.cache_resource
    def get_model():
        return load_model()
    
    try:
        model, tokenizer = get_model()
        
        # Tag input section
        st.markdown("### Add Tags")
        st.markdown("Enter tags separated by commas to generate a YouTube script")
        
        # Create columns for tag input and generate button
        col1, col2 = st.columns([3, 1])
        
        with col1:
            tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
        
        with col2:
            generate_button = st.button("Generate Script", type="primary")
        
        # Generated script section
        if generate_button and tags:
            st.markdown("### Generated Script")
            with st.spinner("Generating script..."):
                script = generate_script(tags, model, tokenizer, params)
                st.text_area("Your script:", value=script, height=400)
                
                # Add download button
                st.download_button(
                    label="Download Script",
                    data=script,
                    file_name="youtube_script.txt",
                    mime="text/plain"
                )
        
        elif generate_button and not tags:
            st.warning("Please enter some tags first!")
            
    except Exception as e:
        st.error("Failed to initialize the application. Please check the logs for details.")
        st.error(f"Error: {str(e)}")

if __name__ == "__main__":
    main()