Spaces:
Sleeping
Sleeping
import datetime as dt
import logging
import random
import re

import chess
import streamlit as st
import streamlit.components.v1 as components
import streamlit_scrollable_textbox as stx
from gradio_client import Client
from st_bridge import bridge
from streamlit.components.v1 import html

from modules.chess import Chess
from modules.states import init_states
from modules.utility import set_page
set_page(title='Chess vs LLaMA 3.1 405B', page_icon="♟️")
init_states()
st.session_state.board_width = 400


@st.cache_resource
def _get_llama_client():
    """Create the Gradio client for the LLaMA 3.1 405B Space once and reuse it.

    Streamlit re-executes this whole script on every interaction. Without
    caching, a new ``Client`` (which contacts the remote Space when it is
    constructed) would be built on each rerun, even when no AI move is needed.
    """
    return Client("xianbao/SambaNova-fast")


# Initialize the LLaMA 3.1 405B client (shared across reruns via st.cache_resource).
llama_client = _get_llama_client()

# Defaults for every session-state key this page reads. Only keys that are still
# missing are filled in, so values from init_states() or an earlier rerun survive.
_SESSION_DEFAULTS = {
    'player_color': 'white',
    'current_turn': 'white',
    'game_started': False,
    # Read by the move handler and the status panel; previously only reset_game() set it.
    'game_over': False,
    'curfen': chess.STARTING_FEN,
    'lastfen': None,
    'moves': {},
    'curside': 'white',
}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
# Matches a UCI move token such as "e2e4" or "e7e8q" inside free-form model output.
_UCI_MOVE_RE = re.compile(r"\b([a-h][1-8][a-h][1-8][qrbn]?)\b")


def get_ai_move(fen, attempts=3):
    """Return the AI's move for ``fen`` as a UCI string.

    Queries the LLaMA Space up to ``attempts`` times and returns the first
    suggested move that is legal in the position. If every attempt fails
    (remote error, or a reply containing no legal UCI move), a random legal
    move is returned so the game can always continue.

    Args:
        fen: Position to move from, in FEN notation.
        attempts: Maximum number of model queries before falling back to a
            random legal move (default 3).

    Returns:
        The move in UCI notation (e.g. ``"e2e4"``), or ``None`` if the side to
        move has no legal moves.
    """
    board = chess.Board(fen)
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        return None
    legal_uci = {move.uci() for move in legal_moves}
    prompt = f"You are a chess engine. Given the following chess position in FEN notation: {fen}, suggest a good move. Respond with only the move in UCI notation (e.g., e2e4)."
    for _ in range(attempts):
        try:
            response = llama_client.predict(
                message=prompt,
                system_message="You are a chess engine assistant.",
                max_tokens=10,
                temperature=0.7,  # Increased temperature for more varied moves
                top_p=0.9,
                top_k=50,
                api_name="/chat"
            )
        except Exception:
            # Remote boundary: a network/Space failure used to propagate and crash
            # the page even though a random-move fallback exists. Log and retry.
            logging.getLogger(__name__).warning("LLaMA move request failed", exc_info=True)
            continue
        # Models often wrap the move in extra text ("Move: e2e4.", "**e2e4**"), so
        # accept the first legal UCI token instead of requiring an exact reply.
        for candidate in _UCI_MOVE_RE.findall(str(response).lower()):
            if candidate in legal_uci:
                return candidate
    # If the AI fails to produce a valid move after all attempts, choose a random legal move
    return random.choice(legal_moves).uci()
def reset_game(player_color):
    """Start a new game with the human playing ``player_color``.

    Resets all game-related session state to the standard starting position
    and marks the game as started. If the human chose black, the AI (White)
    immediately plays the opening move. That move is applied to ``curfen``,
    recorded in the move history, and the turn passes to black.

    Args:
        player_color: ``'white'`` or ``'black'``, the side the human plays.
    """
    st.session_state.player_color = player_color
    st.session_state.curfen = chess.STARTING_FEN
    st.session_state.moves = {}
    st.session_state.current_turn = 'white'
    st.session_state.game_started = True
    st.session_state.lastfen = None
    st.session_state.game_over = False

    # Only when the player chose black does the AI have to make the first move.
    if player_color != 'black':
        return
    ai_move = get_ai_move(chess.STARTING_FEN)
    if not ai_move:
        return
    board = chess.Board(chess.STARTING_FEN)
    board.push(chess.Move.from_uci(ai_move))
    st.session_state.curfen = board.fen()
    # History entries are keyed by the FEN reached after the move.
    st.session_state.moves[st.session_state.curfen] = {
        'side': 'white',
        'curfen': st.session_state.curfen,
        'last_fen': chess.STARTING_FEN,
        'last_move': ai_move,
        'data': None,
        'timestamp': str(dt.datetime.now()),
    }
    st.session_state.current_turn = 'black'
def check_game_end(board):
    """Report whether ``board`` is finished and flag the session if so.

    Uses ``board.outcome()``. When an outcome exists, this sets
    ``st.session_state.game_over`` to True as a side effect.

    Returns:
        ``"Draw"`` for a drawn outcome, ``"White"`` or ``"Black"`` for the
        winner, or ``None`` while the game is still in progress.
    """
    result = board.outcome()
    if result is None:
        return None
    st.session_state.game_over = True
    winner = result.winner
    if winner is None:
        return "Draw"
    if winner:
        return "White"
    return "Black"
st.title("Chess vs LLaMA 3.1 405B")

# Game controls: colour picker | new-game button | turn and colour readout.
color_col, start_col, status_col = st.columns(3)

with color_col:
    player_color = st.selectbox("Choose your color", ['white', 'black'], key='color_select')

with start_col:
    start_requested = st.button('Start New Game', key='start_game')
    if start_requested:
        # Reset state for the chosen colour, then rerun so the page redraws from scratch.
        reset_game(player_color)
        st.rerun()

with status_col:
    st.write(f"Current turn: {st.session_state.current_turn}")
    st.write(f"Your color: {st.session_state.player_color}")
def _move_record(side, curfen, last_fen, last_move):
    """Build one Move History entry (stored in ``st.session_state.moves`` under ``curfen``)."""
    return {
        'side': side,
        'curfen': curfen,
        'last_fen': last_fen,
        'last_move': last_move,
        'data': None,
        'timestamp': str(dt.datetime.now()),
    }


# Get the info from the current board after the user made a move.
data = bridge("my-bridge")

# A Streamlit component keeps returning its most recent value on every rerun
# (button clicks, selectbox changes, the st.rerun() after "Start New Game").
# Without this check the same user move would be re-applied on each such rerun,
# rewinding the board (or jumping a new game to an old position) and making the
# AI move again. Only act on a payload we have not seen before.
# NOTE(review): two genuinely identical consecutive payloads (e.g. the same first
# move in a new game right after a one-move game) cannot be told apart here; a
# per-move nonce sent by the board front end would remove that ambiguity.
is_new_move = data is not None and data != st.session_state.get('last_bridge_data')
if is_new_move:
    st.session_state.last_bridge_data = data

if is_new_move and st.session_state.game_started and not st.session_state.get('game_over', False):
    st.session_state.lastfen = st.session_state.curfen
    st.session_state.curfen = data['fen']
    color_code = data['move']['color']
    # The front end reports the mover as 'w'/'b'; anything else passes through unchanged.
    st.session_state.curside = {'w': 'white', 'b': 'black'}.get(color_code, color_code)
    st.session_state.moves[st.session_state.curfen] = _move_record(
        st.session_state.curside,
        st.session_state.curfen,
        st.session_state.lastfen,
        data['pgn'],
    )
    st.session_state.current_turn = 'white' if st.session_state.curside == 'black' else 'black'
    board = chess.Board(st.session_state.curfen)
    game_result = check_game_end(board)
    if game_result:
        st.success(f"Game Over! Winner: {game_result}")
    elif st.session_state.current_turn != st.session_state.player_color:
        # AI's turn
        ai_move = get_ai_move(st.session_state.curfen)
        if ai_move:
            # Position the AI moves from, i.e. the one right after the user's move.
            fen_before_ai = st.session_state.curfen
            board.push(chess.Move.from_uci(ai_move))
            st.session_state.lastfen = fen_before_ai
            st.session_state.curfen = board.fen()
            st.session_state.moves[st.session_state.curfen] = _move_record(
                st.session_state.current_turn,
                st.session_state.curfen,
                fen_before_ai,
                ai_move,
            )
            st.session_state.current_turn = st.session_state.player_color
            game_result = check_game_end(board)
            if game_result:
                st.success(f"Game Over! Winner: {game_result}")
        else:
            st.error("The AI couldn't make a move. The game may be over.")
def _render_board_panel():
    """Draw the interactive board, the status flags and, when finished, the result."""
    puzzle = Chess(st.session_state.board_width, st.session_state.curfen)
    components.html(
        puzzle.puzzle_board(),
        height=st.session_state.board_width + 75,
        scrolling=False,
    )
    position = chess.Board(st.session_state.curfen)

    # Game status, split over two side-by-side columns.
    left, right = st.columns(2)
    with left:
        st.write("Game Status:")
        st.write(f"Check: {'Yes' if position.is_check() else 'No'}")
        st.write(f"Checkmate: {'Yes' if position.is_checkmate() else 'No'}")
    with right:
        st.write("\u200B")  # invisible character keeps the two columns aligned
        st.write(f"Stalemate: {'Yes' if position.is_stalemate() else 'No'}")
        st.write(f"Insufficient material: {'Yes' if position.is_insufficient_material() else 'No'}")

    if st.session_state.game_over:
        st.success(f"Game Over! Winner: {check_game_end(position)}")


def _render_welcome():
    """Show the how-to-play instructions before any game has started."""
    st.info("Welcome to Chess vs LLaMA 3.1 405B!")
    for line in (
        "To start a new game:",
        "1. Choose your color (white or black)",
        "2. Click 'Start New Game'",
        "3. Make your moves on the chess board",
        "Enjoy playing against the AI!",
    ):
        st.write(line)


def _format_history_entry(entry):
    """Render one move record as a markdown heading (time) plus 'side - move' line."""
    return f"##### {entry['timestamp'].split('.')[0]} \n {entry['side']} - {entry.get('last_move','')}"


def _render_move_history():
    """List every recorded move, in insertion order, inside a scrollable text box."""
    st.subheader("Move History")
    lines = [_format_history_entry(entry) for entry in st.session_state['moves'].values()]
    stx.scrollableTextbox('\n\n'.join(lines), height=400, border=True)


# Main game display: board and status on the left, history or artwork on the right.
board_col, side_col = st.columns([3, 2])
with board_col:
    if st.session_state.game_started:
        _render_board_panel()
    else:
        _render_welcome()
with side_col:
    if st.session_state.game_started:
        _render_move_history()
    else:
        st.image("https://upload.wikimedia.org/wikipedia/commons/6/6f/ChessSet.jpg", caption="Chess pieces", use_column_width=True)