{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NBA game outcome predictor\n",
    "\n",
    "Loads the trained `NBAModel` from `weights.pt`, scores a matchup between two 8-player rosters, and reports the home team's win probability together with the predicted distribution of the final net score.\n",
    "\n",
    "Run the cells top to bottom (Restart & Run All); edit the roster cell to try other matchups."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import NBAModel, NBAConfig\n",
    "from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16, no_grad\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "device = torch_device(\"cpu\")\n",
    "\n",
    "# Vocabulary sizes. Each is passed to NBAConfig with `num_special_tokens`\n",
    "# extra ids: the prediction cell prepends id `count + 1` to the home team's\n",
    "# sequences, and label ids 0 and `num_net_score_tokens + 1` are never read\n",
    "# as outcomes.\n",
    "num_age_tokens = 32\n",
    "num_player_tokens = 5141\n",
    "num_net_score_tokens = 41\n",
    "num_special_tokens = 2\n",
    "players_per_team = 8\n",
    "\n",
    "model_config = NBAConfig(\n",
    "    players_per_team=players_per_team,\n",
    "    player_tokens=num_player_tokens + num_special_tokens,\n",
    "    age_tokens=num_age_tokens + num_special_tokens,\n",
    "    num_labels=num_net_score_tokens + num_special_tokens,\n",
    "    n_layer=4,\n",
    "    n_head=4,\n",
    "    n_embd=1024,\n",
    "    dropout=0.0,\n",
    "    bias=False,\n",
    "    dtype=bfloat16,\n",
    "    seed=29,\n",
    ")\n",
    "\n",
    "model = NBAModel(model_config).to(device)\n",
    "# torch.load unpickles the checkpoint; weights_only=True restricts it to\n",
    "# tensors and plain containers so a tampered file cannot run arbitrary code.\n",
    "state_dict = torch_load(\"weights.pt\", map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the matchup\n",
    "\n",
    "Player and age token ids are listed in `player_tokens.csv` and `age_tokens.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change player and age tokens here!\n",
    "# You can find these values in player_tokens.csv and age_tokens.csv\n",
    "# You must provide exactly 8 player tokens and 8 age tokens for each team.\n",
    "\n",
    "# Denver Nuggets first game of 2023-24 season roster\n",
    "home_player_tokens = [5035, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "home_age_tokens = [14, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Uncomment to take Jokic off team, replace with Peyton Watson\n",
    "# home_player_tokens = [4331, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "# home_age_tokens = [6, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Boston Celtics final game of 2023-24 season roster\n",
    "away_player_tokens = [5042, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# Uncomment to take Tatum off team, replace with Pritchard\n",
    "# away_player_tokens = [4999, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "# away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# The model usually gives the home team a bump in win probability.\n",
    "# Change this to \"True\" to swap home and away teams.\n",
    "swap_home_away = False\n",
    "if swap_home_away:\n",
    "    home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
    "    home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens\n",
    "\n",
    "# Fail fast with a readable message instead of an error deep inside the model.\n",
    "# Ids above the count would hit the reserved start id (`count + 1`) or fall\n",
    "# outside the vocabulary size passed to NBAConfig.\n",
    "rosters = {\n",
    "    \"home_player_tokens\": (home_player_tokens, num_player_tokens),\n",
    "    \"home_age_tokens\": (home_age_tokens, num_age_tokens),\n",
    "    \"away_player_tokens\": (away_player_tokens, num_player_tokens),\n",
    "    \"away_age_tokens\": (away_age_tokens, num_age_tokens),\n",
    "}\n",
    "for name, (tokens, max_token) in rosters.items():\n",
    "    assert len(tokens) == players_per_team, f\"{name}: expected {players_per_team} tokens, got {len(tokens)}\"\n",
    "    assert all(0 <= t <= max_token for t in tokens), f\"{name}: token ids must be in [0, {max_token}]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict\n",
    "\n",
    "The model outputs one logit per net-score label. Positive net scores are home wins; the headline probability is renormalised over decisive (non-zero) outcomes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _team_tensor(tokens, start_token=None):\n",
    "    \"\"\"Return `tokens` as a (1, len) int32 tensor on `device`, optionally prefixed with `start_token`.\"\"\"\n",
    "    ids = list(tokens) if start_token is None else [start_token, *tokens]\n",
    "    return Tensor(ids).to(dtype=int32).unsqueeze(0).to(device)\n",
    "\n",
    "\n",
    "# Only the home sequences get the reserved start id (`count + 1`).\n",
    "batch = {\n",
    "    \"home_player_tokens\": _team_tensor(home_player_tokens, start_token=num_player_tokens + 1),\n",
    "    \"home_age_tokens\": _team_tensor(home_age_tokens, start_token=num_age_tokens + 1),\n",
    "    \"away_player_tokens\": _team_tensor(away_player_tokens),\n",
    "    \"away_age_tokens\": _team_tensor(away_age_tokens),\n",
    "}\n",
    "\n",
    "# Inference only: skip autograd bookkeeping.\n",
    "with no_grad():\n",
    "    logits, _ = model(**batch)\n",
    "\n",
    "# Upcast before the softmax: the model is configured with dtype=bfloat16,\n",
    "# which keeps only ~3 significant digits.\n",
    "label_probs = logits.squeeze().float().softmax(dim=-1).tolist()\n",
    "\n",
    "# Label ids 1..num_net_score_tokens encode the home team's net score, centred\n",
    "# on `zero_margin_id` (net score 0). Ids 0 and num_net_score_tokens + 1 are\n",
    "# reserved, not outcomes; net score 0 is dropped too (presumably because games\n",
    "# cannot end tied).\n",
    "zero_margin_id = num_net_score_tokens // 2 + 1\n",
    "margin_probs = {\n",
    "    label_id - zero_margin_id: label_probs[label_id]\n",
    "    for label_id in range(1, num_net_score_tokens + 1)\n",
    "    if label_id != zero_margin_id\n",
    "}\n",
    "\n",
    "home_win_mass = sum(p for margin, p in margin_probs.items() if margin > 0)\n",
    "home_loss_mass = sum(p for margin, p in margin_probs.items() if margin < 0)\n",
    "# Renormalise over decisive outcomes so home and away win probabilities sum\n",
    "# to 1 instead of silently losing any mass the model puts on excluded labels.\n",
    "home_win_prob = home_win_mass / (home_win_mass + home_loss_mass)\n",
    "print(f\"Home team win probability: {home_win_prob:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Net-score distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.bar(list(margin_probs.keys()), list(margin_probs.values()))\n",
    "ax.set(\n",
    "    title=\"Predicted final net score\",\n",
    "    xlabel=\"Home team net score (positive = home win)\",\n",
    "    ylabel=\"Probability\",\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}