import spaces

import os
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import gradio as gr
import keras_nlp
import keras
# import spaces
import torch

from typing import Iterator
import time

from chess_board import Game
from datasets import load_dataset
import google.generativeai as genai


# Startup diagnostics for the logs. The device name is only queried when CUDA
# is present, because torch.cuda.current_device() raises on CPU-only hosts and
# would otherwise crash the app at import time.
CUDA_AVAILABLE = torch.cuda.is_available()
print(f"Is CUDA available: {CUDA_AVAILABLE}")
if CUDA_AVAILABLE:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


# Markdown introduction rendered at the top of the page by gr.Markdown.
# The two trailing spaces after "Enjoy your game!" are a Markdown hard line
# break; keep them.
DESCRIPTION = """
# Chess Tutor AI
**Welcome to the Chess Chatbot!**

The goal of this project is to showcase the use of AI in learning chess. This app allows you to play a game against a custom fine-tuned model (Gemma 2B).\n
The challenge is that input must be in *algebraic notation*.

## Features

### For New & Beginner Players
- The chat interface uses the Gemini API. If you need help with chess rules or learning algebraic notation, just ask!

### For Advanced Users
- Pick an opening to play, and ask Gemini for more info.

Enjoy your game!  
**- Valentin**
"""

# --- Gemini client ---------------------------------------------------------
# The key is read from the environment and passed through unchanged; if the
# variable is unset, api_key is None.
api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=api_key)

model = genai.GenerativeModel(model_name="gemini-1.5-flash-latest")
chat = model.start_chat()

# --- Openings data ---------------------------------------------------------
# Lichess opening catalogue, converted to a DataFrame for name lookups.
ds = load_dataset("Lichess/chess-openings", split="train")
df = ds.to_pandas()

# Distinct opening names, in order of first appearance, for the dropdown.
opening_names = df["name"].unique().tolist()


def _to_gemini_history(chat_history: list[dict]) -> list[dict]:
    """Convert Gradio "messages"-format history into Gemini chat history.

    Gradio supplies turns as ``{"role": "user" | "assistant", "content": ...}``;
    Gemini takes ``{"role": "user" | "model", "parts": [...]}``. Turns whose
    content is not non-empty text (e.g. attachments) are skipped, and
    consecutive turns from the same role are merged into one, so user and
    model turns alternate.
    """
    history: list[dict] = []
    for turn in chat_history or []:
        content = turn.get("content")
        if not isinstance(content, str) or not content:
            continue
        role = "model" if turn.get("role") == "assistant" else "user"
        if history and history[-1]["role"] == role:
            history[-1]["parts"].append(content)
        else:
            history.append({"role": role, "parts": [content]})
    return history


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    ) -> Iterator[str]:
    """Answer *message* with Gemini and stream the reply into the chat UI.

    A Gemini chat session is rebuilt from the history Gradio passes in on
    every call. Each browser session therefore keeps its own context, and
    clearing the chat in the UI really clears it. A single module-level
    session would mix every visitor's conversation together.

    Args:
        message: The user's latest message.
        chat_history: Prior turns in Gradio "messages" format.
        max_new_tokens: Maximum number of tokens Gemini may generate.

    Yields:
        The reply revealed one character at a time (a typing effect); each
        yielded value is the full text so far.
    """
    session = model.start_chat(history=_to_gemini_history(chat_history))
    response = session.send_message(
        message,
        generation_config={"max_output_tokens": max_new_tokens},
    )

    outputs = ""
    for char in response.text:
        outputs += char
        yield outputs


def get_opening_details(opening_name):
    opening_data = df[df['name'] == opening_name].iloc[0]
    moves = opening_data['pgn']
    return f"Opening: {opening_data['name']}\nMoves: {moves}"

def get_move_list(opening_name):
    opening_data = df[df['name'] == opening_name].iloc[0]
    moves = opening_data['pgn']
    pgn_string = moves.split()
    return [move for idx,move in enumerate(pgn_string[1:],1) if idx%3!=0]
   

# Gemini-backed chat panel; it is rendered into the page layout further down.
_starter_questions = (
    "Hi Gemini, what is a good first move in chess?",
    "How does the Knight move?",
    "Explain algebraic notation for capturing a piece in chess?",
)
chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    examples=[[question] for question in _starter_questions],
    cache_examples=False,
    stop_btn=None,
)

    
# Page layout: board and Gemini chat side by side, then the game controls and
# the openings explorer. Event handlers are wired inline.
with gr.Blocks(css_paths="styles.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): a single Game instance is created here at build time and
    # referenced by every handler, so concurrent visitors appear to share one
    # board. Per-session state (e.g. gr.State) would need changes to Game.
    # TODO confirm.
    play_match = Game()

    with gr.Row():
        with gr.Column():
            board_image = gr.HTML(play_match.display_board())
        with gr.Column():
            chat_interface.render()

    game_logs = gr.Label(label="Game Logs", elem_classes=["big-text"])

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Play a Match vs Gemma")

            move_input = gr.Textbox(label="Enter your move in algebraic notation: (e.g., e4, Nf3, Bxc4)")
            submit_move = gr.Button("Submit Move")
            submit_move.click(play_match.generate_moves, inputs=move_input, outputs=[board_image, game_logs])
            # Clearing callbacks have no inputs, so Gradio calls them with zero
            # arguments. They must take no parameters, or they raise TypeError.
            submit_move.click(lambda: gr.update(value=""), outputs=[move_input])

            reset_board = gr.Button("Reset Game")
            reset_board.click(play_match.reset_board, outputs=board_image)
            reset_board.click(lambda: gr.update(value=""), outputs=[game_logs])

        with gr.Column():
            gr.Markdown("### Chess Openings Explorer")

            opening_choice = gr.Dropdown(label="Choose a Chess Opening", choices=opening_names)
            opening_output = gr.Textbox(label="Opening Details", lines=4)
            # Moves of the selected opening, filled by get_move_list below.
            opening_moves = gr.State()

            opening_choice.change(fn=get_opening_details, inputs=opening_choice, outputs=opening_output)
            opening_choice.change(fn=get_move_list, inputs=opening_choice, outputs=opening_moves)

            load_opening = gr.Button("Load Opening")
            # Reset first, then load. Two separate listeners on the same click
            # are not guaranteed to run in order, and a late reset would wipe
            # the opening that was just loaded.
            load_opening.click(play_match.reset_board, outputs=board_image).then(
                play_match.load_opening,
                inputs=[opening_choice, opening_moves],
                outputs=game_logs,
            )
    
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending requests), then serve the app.
    demo.queue(max_size=20)
    demo.launch()