import gradio as gr
import os
import requests
import torch
from transformers import (
    LEDTokenizer, LEDForConditionalGeneration,
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    AutoTokenizer, AutoModelForSeq2SeqLM
)

# OpenAI API Key
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Ensure this is set in your environment variables

# Summarization models to try, in priority order
# (LED first, since it accepts much longer inputs than the others)
MODELS = [
    {
        "name": "allenai/led-large-16384",
        "tokenizer_class": LEDTokenizer,
        "model_class": LEDForConditionalGeneration
    },
    {
        "name": "facebook/bart-large-cnn",
        "tokenizer_class": BartTokenizer,
        "model_class": BartForConditionalGeneration
    },
    {
        "name": "Falconsai/text_summarization",
        "tokenizer_class": AutoTokenizer,
        "model_class": AutoModelForSeq2SeqLM
    },
    {
        "name": "google/pegasus-xsum",
        "tokenizer_class": PegasusTokenizer,
        "model_class": PegasusForConditionalGeneration
    }
]

# Load all models up front; any that fail to load are skipped
loaded_models = []
for model_info in MODELS:
    try:
        tokenizer = model_info["tokenizer_class"].from_pretrained(model_info["name"])
        model = model_info["model_class"].from_pretrained(model_info["name"])
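        model.eval()  # inference only: disable dropout and other training-time behaviour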
        loaded_models.append({"name": model_info["name"], "tokenizer": tokenizer, "model": model})
        print(f"Loaded model: {model_info['name']}")
    except Exception as e:
        print(f"Failed to load {model_info['name']}: {e}")

def summarize_with_transformers(text):
    """
    Try summarizing with locally loaded Transformer models in order of priority.
    """
    for model_data in loaded_models:
        try:
            tokenizer = model_data["tokenizer"]
            model = model_data["model"]

            # Tokenize input, truncating to each model's own maximum input length
            # (16384 tokens for LED, far less for BART/Pegasus/T5)
            max_input = min(tokenizer.model_max_length, 16384)
            inputs = tokenizer([text], max_length=max_input, return_tensors="pt", truncation=True)

            # Generate summary (no gradients needed for inference)
            with torch.no_grad():
                summary_ids = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    num_beams=4,
                    max_length=512,
                    min_length=100,
                    early_stopping=True
                )

            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary  # Return the first successful response

        except Exception as e:
            print(f"Error using {model_data['name']}: {e}")

    return None  # Indicate failure

def summarize_with_chatgpt(text):
    """
    Fallback to OpenAI ChatGPT API if all other models fail.
    """
    if not OPENAI_API_KEY:
        return "Error: No OpenAI API key provided."

    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    
    # gpt-3.5-turbo has a limited context window, so very long articles are
    # crudely truncated (roughly 12,000 characters here) to keep the request valid.
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": f"Summarize this article: {text[:12000]}"}],
        "max_tokens": 512
    }

    try:
        response = requests.post("https://api.openai.com/v1/chat/completions",
                                 headers=headers, json=payload, timeout=60)
    except requests.RequestException as e:
        return f"Error: ChatGPT request failed ({e})"

    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    return f"Error: Failed to summarize with ChatGPT (status {response.status_code})"

def summarize_text(text):
    """
    Main function to summarize text, trying Transformer models first, then ChatGPT if needed.
    """
    summary = summarize_with_transformers(text)
    
    if summary:
        return summary  # Return successful summary from a Transformer model
    
    print("All Transformer models failed. Falling back to ChatGPT...")
    return summarize_with_chatgpt(text)  # Use ChatGPT as last resort

# Gradio Interface
iface = gr.Interface(
    fn=summarize_text,
    inputs="text",
    outputs="text",
    title="Multi-Model Summarizer with Fallback",
    description="Tries multiple models for summarization, falling back to ChatGPT if needed."
)

if __name__ == "__main__":
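    # launch() starts a local Gradio server (default: http://127.0.0.1:7860)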
    iface.launch()