import os
import urllib.parse
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Define device and model name
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Load model and tokenizer
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error loading model and tokenizer: {e}")
    model = None
    tokenizer = None


# Function to fetch news from the Google News RSS feed
def fetch_news(term, num_results=2):
    logger.debug(f"Fetching news for term: {term}")
    encoded_term = urllib.parse.quote(term)
    url = f"https://news.google.com/rss/search?q={encoded_term}"
    feed = feedparser.parse(url)
    results = []
    for entry in feed.entries[:num_results]:
        results.append({"link": entry.link, "text": entry.title})
    logger.debug(f"Fetched news results: {results}")
    return results


# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    logger.debug(f"Starting search for term: {term}")
    escaped_term = urllib.parse.quote_plus(term)  # requests also URL-encodes params below
    start = 0
    all_results = []
    max_chars_per_page = 8000
    with requests.Session() as session:
        while start < num_results:
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    start += 1
                    continue
                for result in result_block:
                    link = result.find("a", href=True)
                    if link:
                        link = link["href"]
                        try:
                            webpage = session.get(
                                link,
                                headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
                            )
                            webpage.raise_for_status()
                            visible_text = extract_text_from_webpage(webpage.text)
                            if len(visible_text) > max_chars_per_page:
                                visible_text = visible_text[:max_chars_per_page] + "..."
all_results.append({"link": link, "text": visible_text}) except requests.exceptions.RequestException as e: logger.error(f"Error fetching or processing {link}: {e}") all_results.append({"link": link, "text": None}) else: all_results.append({"link": None, "text": None}) start += len(result_block) except Exception as e: logger.error(f"Error during search: {e}") break logger.debug(f"Search results: {all_results}") return all_results # Function to extract visible text from HTML content def extract_text_from_webpage(html_content): soup = BeautifulSoup(html_content, "html.parser") for tag in soup(["script", "style", "header", "footer", "nav"]): tag.extract() visible_text = soup.get_text(strip=True) return visible_text # Function to format the prompt for the language model def format_prompt(user_prompt, chat_history): logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}") prompt = "" for item in chat_history: prompt += f"User: {item[0]}\nAssistant: {item[1]}\n" prompt += f"User: {user_prompt}\nAssistant:" logger.debug(f"Formatted prompt: {prompt}") return prompt # Function for model inference def model_inference( user_prompt, chat_history, web_search, temperature, max_new_tokens, repetition_penalty, top_p, tokenizer # Pass tokenizer as an argument ): logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}") if not isinstance(user_prompt, dict): logger.error("Invalid input format. Expected a dictionary.") return "Invalid input format. Expected a dictionary." if "files" not in user_prompt: user_prompt["files"] = [] if not user_prompt["files"]: if web_search: logger.debug("Performing news search") news_results = fetch_news(user_prompt["text"]) news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results]) formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news2}", chat_history) inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE) if model: outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) else: response = "Model is not available. Please try again later." logger.debug(f"Model response: {response}") return response else: formatted_prompt = format_prompt(user_prompt["text"], chat_history) inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE) if model: outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) else: response = "Model is not available. Please try again later." logger.debug(f"Model response: {response}") return response else: return "Image input not supported in this implementation." 
# Define Gradio interface components
# NOTE: the standalone components below (and the Chatbot) are not passed to the
# gr.Interface at the bottom of the file, which builds its own input components.
max_new_tokens = gr.Slider(
    minimum=1,
    maximum=16000,
    value=2048,
    step=64,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    ["Greedy", "Top P Sampling"],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    info="Choose between greedy decoding and top-p (nucleus) sampling.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.5,
    step=0.05,
    visible=True,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.9,
    step=0.01,
    visible=True,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)

# Create a chatbot interface
chatbot = gr.Chatbot(
    label="OpenGPT-4o-Chatty",
    show_copy_button=True,
    likeable=True,
    layout="panel",
)


# Define the Gradio callback
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # Ensure the tokenizer is accessible within the function scope
    global tokenizer
    # decoding_strategy is accepted from the UI but not currently used by model_inference,
    # which always samples (do_sample=True).
    # Wrap the user input in a dictionary as expected by the model_inference function
    user_prompt = {"text": user_input, "files": []}
    # Perform model inference
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # Pass tokenizer to the model_inference function
    )
    # Update history with the user input and model response
    history.append((user_input, response))
    # Return the response and updated history
    return response, history


# Define the Gradio interface
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # Initialize the chat history as an empty list
        gr.Checkbox(label="Perform Web Search"),
        gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
        gr.Slider(minimum=1, maximum=16000, step=64, label="Maximum number of new tokens to generate", value=2048),
        gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1.0),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9),
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([]),  # Update the chat history
    ],
    live=True,
)

# Launch the Gradio interface
interface.launch()
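
# Quick sanity check without the UI (a sketch; the prompt and values below are arbitrary
# examples). Run it in place of interface.launch() when debugging the generation pipeline
# directly:
#
#   history = []
#   reply = model_inference(
#       user_prompt={"text": "What's new in open-source LLMs?", "files": []},
#       chat_history=history,
#       web_search=True,
#       temperature=0.5,
#       max_new_tokens=256,
#       repetition_penalty=1.0,
#       top_p=0.9,
#       tokenizer=tokenizer,
#   )
#   print(reply)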