{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import NBAModel, NBAConfig\n",
    "from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Inference runs on CPU; change this to use another device (weights are mapped to it too).\n",
    "device = torch_device(\"cpu\")\n",
    "\n",
    "# Sizes of the real token vocabularies. The config below adds 2 to each, reserving\n",
    "# extra ids past the real tokens (id num_*_tokens + 1 is prepended to the home team's inputs).\n",
    "num_age_tokens = 32\n",
    "num_player_tokens = 5141\n",
    "num_net_score_tokens = 41\n",
    "players_per_team = 8\n",
    "\n",
    "model_config = NBAConfig(\n",
    "    players_per_team=players_per_team,\n",
    "    player_tokens=num_player_tokens + 2,\n",
    "    age_tokens=num_age_tokens + 2,\n",
    "    num_labels=num_net_score_tokens + 2,\n",
    "    n_layer=4,\n",
    "    n_head=4,\n",
    "    n_embd=1024,\n",
    "    dropout=0.0,\n",
    "    bias=False,\n",
    "    dtype=bfloat16,\n",
    "    seed=29,\n",
    ")\n",
    "\n",
    "model = NBAModel(model_config).to(device)\n",
    "# The checkpoint only has to hold a state_dict of tensors, so refuse arbitrary pickled\n",
    "# objects: a plain torch.load unpickles the file, which can execute code.\n",
    "state_dict = torch_load('weights.pt', map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Home team win probability: 0.66\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAliklEQVR4nO3df3BU9b3/8VcSSELQBCWQJRgJtpQfBRMJJIZaqTXD4qTVqMXIdCTNMDhaQOx6UwkXklrrDVcEYyFtSmfA/qKhTCulyM1tzDX0tlmlJKFerFJ0xCC4CbSXBIMkmP18//Dr2r0sJLuE7CfL8zFzRvbkfc6+P3Py4+Vnz48oY4wRAACAxaLD3QAAAEBfCCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsNC3cDA8Hr9er48eO6+uqrFRUVFe52AABAPxhjdPr0aaWmpio6+uJzKBERWI4fP660tLRwtwEAAEJw9OhRXXfddRetiYjAcvXVV0v6eMCJiYlh7gYAAPRHZ2en0tLSfH/HLyYiAssnHwMlJiYSWAAAGGL6czoHJ90CAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWG9YuBsAAADhlb7yxT5rjqzNH4ROLowZFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgvZACS1VVldLT0xUfH6+cnBzt27fvgrWvv/667r33XqWnpysqKkqVlZWXvE8AAHBlCTqwbN++XS6XS+Xl5WpublZGRoacTqfa29sD1p85c0Y33HCD1q5dK4fDMSD7BAAAV5agA8uGDRu0ZMkSFRcXa9q0aaqurlZCQoK2bNkSsH727Nlat26d7r//fsXFxQ3IPgEAwJUlqMDS09OjpqYm5eXlfbqD6Gjl5eXJ7XaH1EAo++zu7lZnZ6ffAgAAIldQgeXkyZPq7e1VSkqK3/qUlBR5PJ6QGghlnxUVFUpKSvItaWlpIb03AAAYGobkVUKlpaXq6OjwLUePHg13SwAA4DIaFkxxcnKyYmJi1NbW5re+ra3tgifUXo59xsXFXfB8GAAAEHmCmmGJjY1VVlaW6uvrfeu8Xq/q6+uVm5sbUgOXY58AACCyBDXDIkkul0tFRUWaNWuWsrOzVVlZqa6uLhUXF0uSFi1apPHjx6uiokLSxyfV/vWvf/X9+9ixYzpw4ICuuuoqffazn+3XPgEAwJUt6MBSWFioEydOqKysTB6PR5mZmaqtrfWdNNva2qro6E8nbo4fP66bbrrJ9/qZZ57RM888o7lz56qhoaFf+wQAAFe2KGOMCXcTl6qzs1NJSUnq6OhQYmJiuNsBAGBISV/5Yp81R9bmD/j7BvP3e0heJQQAAK4sBBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPVCCixVVVVKT09XfHy8cnJytG/fvovW79ixQ1OmTFF8fLxmzJihPXv2+H39gw8+0LJly3TddddpxIgRmjZtmqqrq0NpDQA
ARKCgA8v27dvlcrlUXl6u5uZmZWRkyOl0qr29PWB9Y2OjFi5cqMWLF6ulpUUFBQUqKCjQwYMHfTUul0u1tbX6+c9/rjfeeEOPPvqoli1bpl27doU+MgAAEDGijDEmmA1ycnI0e/Zsbdq0SZLk9XqVlpam5cuXa+XKlefVFxYWqqurS7t37/atu/nmm5WZmembRZk+fboKCwu1Zs0aX01WVpbuuOMOfe973+uzp87OTiUlJamjo0OJiYnBDAcAgCte+soX+6w5sjZ/wN83mL/fQc2w9PT0qKmpSXl5eZ/uIDpaeXl5crvdAbdxu91+9ZLkdDr96ufMmaNdu3bp2LFjMsbo5Zdf1t/+9jfNmzcvmPYAAECEGhZM8cmTJ9Xb26uUlBS/9SkpKXrzzTcDbuPxeALWezwe3+uNGzfqwQcf1HXXXadhw4YpOjpaP/7xj3XrrbcG3Gd3d7e6u7t9rzs7O4MZBgAAGGKsuEpo48aNeuWVV7Rr1y41NTVp/fr1Wrp0qV566aWA9RUVFUpKSvItaWlpg9wxAAAYTEHNsCQnJysmJkZtbW1+69va2uRwOAJu43A4Llr/4YcfatWqVXrhhReUn//x52M33nijDhw4oGeeeea8j5MkqbS0VC6Xy/e6s7OT0AIAQAQLaoYlNjZWWVlZqq+v963zer2qr69Xbm5uwG1yc3P96iWprq7OV3/u3DmdO3dO0dH+rcTExMjr9QbcZ1xcnBITE/0WAAAQuYKaYZE+vgS5qKhIs2bNUnZ2tiorK9XV1aXi4mJJ0qJFizR+/HhVVFRIklasWKG5c+dq/fr1ys/PV01Njfbv36/NmzdLkhITEzV37lyVlJRoxIgRmjBhgvbu3auf/vSn2rBhwwAOFQAADFVBB5bCwkKdOHFCZWVl8ng8yszMVG1tre/E2tbWVr/Zkjlz5mjbtm1avXq1Vq1apUmTJmnnzp2aPn26r6ampkalpaX6+te/rn/84x+aMGGCnnrqKT300EMDMEQAADDUBX0fFhtxHxYAAEIXcfdhAQAACAcCCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgvZACS1VVldLT0xUfH6+cnBzt27fvovU7duzQlClTFB8frxkzZmjPnj3n1bzxxhu68847lZSUpJEjR2r27NlqbW0NpT0AABBhgg4s27dvl8vlUnl5uZqbm5WRkSGn06n29vaA9Y2NjVq4cKEWL16slpYWFRQUqKCgQAcPHvTVvP3227rllls0ZcoUNTQ06LXXXtOaNWsUHx8f+sgAAEDEiDLGmGA2yMnJ0ezZs7Vp0yZJktfrVVpampYvX66VK1eeV19YWKiuri7t3r3bt+7mm29WZmamqqurJUn333+/hg8frp/97GchDaKzs1NJSUnq6OhQYmJiSPsAAOBKlb7yxT5rjqzNH/D3Debvd1AzLD09PWpqalJeXt6nO4iOVl5entxud8Bt3G63X70kOZ1OX73X69WLL76oz33uc3I6nRo7dqxycnK0c+fOC/bR3d2tzs5OvwUAAESuoALLyZMn1dvbq5SUFL/1KSkp8ng8AbfxeDwXrW9vb9cHH3ygtWvXav78+fr973+
vu+++W/fcc4/27t0bcJ8VFRVKSkryLWlpacEMAwAADDFhv0rI6/VKku666y5961vfUmZmplauXKmvfOUrvo+M/q/S0lJ1dHT4lqNHjw5mywAAYJANC6Y4OTlZMTExamtr81vf1tYmh8MRcBuHw3HR+uTkZA0bNkzTpk3zq5k6dar++Mc/BtxnXFyc4uLigmkdAAAMYUHNsMTGxiorK0v19fW+dV6vV/X19crNzQ24TW5url+9JNXV1fnqY2NjNXv2bB06dMiv5m9/+5smTJgQTHsAACBCBTXDIkkul0tFRUWaNWuWsrOzVVlZqa6uLhUXF0uSFi1apPHjx6uiokKStGLFCs2dO1fr169Xfn6+ampqtH//fm3evNm3z5KSEhUWFurWW2/VbbfdptraWv3ud79TQ0PDwIwSAAAMaUEHlsLCQp04cUJlZWXyeDzKzMxUbW2t78Ta1tZWRUd/OnEzZ84cbdu2TatXr9aqVas0adIk7dy5U9OnT/fV3H333aqurlZFRYUeeeQRTZ48Wb/+9a91yy23DMAQAQDAUBf0fVhsxH1YAAAIXcTdhwUAACAcCCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPWGhbsBAMClS1/5Yp81R9bmD0InwOXBDAsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHrchwUAAuC+JoBdmGEBAADWY4YFAHBRfc02MdOEwcAMCwAAsB6BBQAAWI/AAgAArEdgAQAA1uOkWwAYRFwuDYSGGRYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANbjKiEAsBRXFAGfIrAAuGIQAIChi4+EAACA9UIKLFVVVUpPT1d8fLxycnK0b9++i9bv2LFDU6ZMUXx8vGbMmKE9e/ZcsPahhx5SVFSUKisrQ2kNAABEoKADy/bt2+VyuVReXq7m5mZlZGTI6XSqvb09YH1jY6MWLlyoxYsXq6WlRQUFBSooKNDBgwfPq33hhRf0yiuvKDU1NfiRAACAiBV0YNmwYYOWLFmi4uJiTZs2TdXV1UpISNCWLVsC1j/33HOaP3++SkpKNHXqVD355JOaOXOmNm3a5Fd37NgxLV++XL/4xS80fPjw0EYDAAAiUlCBpaenR01NTcrLy/t0B9HRysvLk9vtDriN2+32q5ckp9PpV+/1evXAAw+opKREn//85/vso7u7W52dnX4LAACIXEEFlpMnT6q3t1cpKSl+61NSUuTxeAJu4/F4+qz/93//dw0bNkyPPPJIv/qoqKhQUlKSb0lLSwtmGAAAYIgJ+1VCTU1Neu655/T8888rKiqqX9uUlpaqo6PDtxw9evQydwkAAMIpqMCSnJysmJgYtbW1+a1va2uTw+EIuI3D4bho/X//93+rvb1d119/vYYNG6Zhw4bp3Xff1WOPPab09PSA+4yLi1NiYqLfAgAAIldQgSU2NlZZWVmqr6/3rfN6vaqvr1dubm7AbXJzc/3qJamurs5X/8ADD+i1117TgQMHfEtqaqpKSkr0n//5n8GOBwAARKCg73TrcrlUVFSkWbNmKTs7W5WVlerq6lJxcbEkadGiRRo/frwqKiokSStWrNDcuXO1fv165efnq6amRvv379fmzZslSaNHj9bo0aP93mP48OFyOByaPHnypY4PAABEgKADS2FhoU6cOKGysjJ5PB5lZmaqtrbWd2Jta2uroqM/nbiZM2eOtm3bptWrV2vVqlWaNGmSdu7cqenTpw/cKAAAQEQL6VlCy5Yt07JlywJ
+raGh4bx1CxYs0IIFC/q9/yNHjoTSFgAAiFA8/BAALhEPVQQuv7Bf1gwAANAXAgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHrcOA7AkMZN24ArAzMsAADAesywAAAGTF8zXsx2IVTMsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWGxbuBoBIl77yxT5rjqzNH4ROAGDoYoYFAABYjxkW4ArALA+AoY4ZFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9bgPCwAgLPq6PxD3BsI/Y4YFAABYj8ACAACsR2ABAADWI7AAAADrcdJtP/DgOAAAwovAAsAPAR024vsSBBYAuMLwxx9DEeewAAAA6xFYAACA9QgsAADAepzDAiBknAsBYLAwwwIAAKxHYAEAANbjIyEgBDxlFgAGF4EFGKI4fwTAlYSPhAAAgPUILAAAwHohBZaqqiqlp6crPj5eOTk52rdv30Xrd+zYoSlTpig+Pl4zZszQnj17fF87d+6cHn/8cc2YMUMjR45UamqqFi1apOPHj4fSGgAAiEBBB5bt27fL5XKpvLxczc3NysjIkNPpVHt7e8D6xsZGLVy4UIsXL1ZLS4sKCgpUUFCggwcPSpLOnDmj5uZmrVmzRs3NzfrNb36jQ4cO6c4777y0kQEAgIgRdGDZsGGDlixZouLiYk2bNk3V1dVKSEjQli1bAtY/99xzmj9/vkpKSjR16lQ9+eSTmjlzpjZt2iRJSkpKUl1dne677z5NnjxZN998szZt2qSmpia1trZe2ugAAEBECOoqoZ6eHjU1Nam0tNS3Ljo6Wnl5eXK73QG3cbvdcrlcfuucTqd27tx5wffp6OhQVFSURo0aFUx7AABwBV2ECiqwnDx5Ur29vUpJSfFbn5KSojfffDPgNh6PJ2C9x+MJWH/27Fk9/vjjWrhwoRITEwPWdHd3q7u72/e6s7MzmGEAAIAhxqr7sJw7d0733XefjDH64Q9/eMG6iooKPfHEE4PYGTA4+D9DAAgsqHNYkpOTFRMTo7a2Nr/1bW1tcjgcAbdxOBz9qv8krLz77ruqq6u74OyKJJWWlqqjo8O3HD16NJhhAACAISaowBIbG6usrCzV19f71nm9XtXX1ys3NzfgNrm5uX71klRXV+dX/0lYOXz4sF566SWNHj36on3ExcUpMTHRbwEAAJEr6I+EXC6XioqKNGvWLGVnZ6uyslJdXV0qLi6WJC1atEjjx49XRUWFJGnFihWaO3eu1q9fr/z8fNXU1Gj//v3avHmzpI/Dyte+9jU1Nzdr9+7d6u3t9Z3fcu211yo2NnagxgoAAIaooANLYWGhTpw4obKyMnk8HmVmZqq2ttZ3Ym1ra6uioz+duJkzZ462bdum1atXa9WqVZo0aZJ27typ6dOnS5KOHTumXbt2SZIyMzP93uvll1/Wl770pRCHBsAmnJ8DG/F9OXSEdNLtsmXLtGzZsoBfa2hoOG/dggULtGDBgoD16enpMsaE0gYAALhC8CwhAABgPQILAACwnlX3YQHCqa/PsvkcGwDChxkWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrcVkzhpRgb6PNpcoAEBkILAAAhBHPM+ofPhICAADWY4YFAIB+YCYkvJhhAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPa4SAgAgAkXaVU3MsAAAAOsxw4J+i7S0DgAYOggsuCwINwCAgURgQdgRbgCgf67k35cElivclfzNDwCXC79bBx4n3QIAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB5XCUUgzk4HAEQaZlgAAID1CCwAAMB6BBYAAGA9AgsAALAeJ92GESfHAgDQP8y
wAAAA6zHDMkQwGwMAuJIxwwIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6/EsoQHGM38AABh4zLAAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYLKbBUVVUpPT1d8fHxysnJ0b59+y5av2PHDk2ZMkXx8fGaMWOG9uzZ4/d1Y4zKyso0btw4jRgxQnl5eTp8+HAorQEAgAgUdGDZvn27XC6XysvL1dzcrIyMDDmdTrW3twesb2xs1MKFC7V48WK1tLSooKBABQUFOnjwoK/m6aef1ve//31VV1fr1Vdf1ciRI+V0OnX27NnQRwYAACJG0IFlw4YNWrJkiYqLizVt2jRVV1crISFBW7ZsCVj/3HPPaf78+SopKdHUqVP15JNPaubMmdq0aZOkj2dXKisrtXr1at1111268cYb9dOf/lTHjx/Xzp07L2lwAAAgMgT1LKGenh41NTWptLTUty46Olp5eXlyu90Bt3G73XK5XH7rnE6nL4y888478ng8ysvL8309KSlJOTk5crvduv/++8/bZ3d3t7q7u32vOzo6JEmdnZ3BDKffvN1n+qz55L2ptae2P/U21P5zPbXB/wzb0G8k1/an3obaf66n1p6fz/7u0xjTd7EJwrFjx4wk09jY6Le+pKTEZGdnB9xm+PDhZtu2bX7rqqqqzNixY40xxvzpT38ykszx48f9ahYsWGDuu+++gPssLy83klhYWFhYWFgiYDl69GifGWRIPq25tLTUb9bG6/XqH//4h0aPHq2oqKjL+t6dnZ1KS0vT0aNHlZiYeFnfKxwieXyMbWiK5LFJkT0+xjZ0Ddb4jDE6ffq0UlNT+6wNKrAkJycrJiZGbW1tfuvb2trkcDgCbuNwOC5a/8l/29raNG7cOL+azMzMgPuMi4tTXFyc37pRo0YFM5RLlpiYGJHfpJ+I5PExtqEpkscmRfb4GNvQNRjjS0pK6lddUCfdxsbGKisrS/X19b51Xq9X9fX1ys3NDbhNbm6uX70k1dXV+eonTpwoh8PhV9PZ2alXX331gvsEAABXlqA/EnK5XCoqKtKsWbOUnZ2tyspKdXV1qbi4WJK0aNEijR8/XhUVFZKkFStWaO7cuVq/fr3y8/NVU1Oj/fv3a/PmzZKkqKgoPfroo/re976nSZMmaeLEiVqzZo1SU1NVUFAwcCMFAABDVtCBpbCwUCdOnFBZWZk8Ho8yMzNVW1urlJQUSVJra6uioz+duJkzZ462bdum1atXa9WqVZo0aZJ27typ6dOn+2q+/e1vq6urSw8++KBOnTqlW265RbW1tYqPjx+AIQ6suLg4lZeXn/eRVKSI5PExtqEpkscmRfb4GNvQZeP4oozpz7VEAAAA4cOzhAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BpZ+OHDmixYsXa+LEiRoxYoQ+85nPqLy8XD09PX51r732mr74xS8qPj5eaWlpevrpp8PUcfCeeuopzZkzRwkJCRe8EV9UVNR5S01NzeA2GoL+jK21tVX5+flKSEjQ2LFjVVJSoo8++mhwGx0A6enp5x2jtWvXhrutkFVVVSk9PV3x8fHKycnRvn37wt3SJfvOd75z3jGaMmVKuNsK2R/+8Ad99atfVWpqqqKios57cK0xRmVlZRo3bpxGjBihvLw8HT58ODzNBqmvsX3jG98471jOnz8/PM0GqaKiQrNnz9bVV1+tsWPHqqCgQIcOHfKrOXv2rJYuXarRo0frqquu0r333nvezWAHC4Gln9588015vV796Ec/0uuvv65nn31W1dXVWrVqla+ms7NT8+bN04QJE9TU1KR169bpO9/5ju+eM7br6enRggUL9PDDD1+0buvWrXr//fd9y1C4X05fY+vt7VV+fr56enrU2Nion/zkJ3r++edVVlY2yJ0OjO9+97t+x2j58uXhbik
k27dvl8vlUnl5uZqbm5WRkSGn06n29vZwt3bJPv/5z/sdoz/+8Y/hbilkXV1dysjIUFVVVcCvP/300/r+97+v6upqvfrqqxo5cqScTqfOnj07yJ0Gr6+xSdL8+fP9juUvf/nLQewwdHv37tXSpUv1yiuvqK6uTufOndO8efPU1dXlq/nWt76l3/3ud9qxY4f27t2r48eP65577glPw30+bQgX9PTTT5uJEyf6Xv/gBz8w11xzjenu7vate/zxx83kyZPD0V7Itm7dapKSkgJ+TZJ54YUXBrWfgXShse3Zs8dER0cbj8fjW/fDH/7QJCYm+h3PoWDChAnm2WefDXcbAyI7O9ssXbrU97q3t9ekpqaaioqKMHZ16crLy01GRka427gs/u/vCK/XaxwOh1m3bp1v3alTp0xcXJz55S9/GYYOQxfo919RUZG56667wtLPQGtvbzeSzN69e40xHx+n4cOHmx07dvhq3njjDSPJuN3uQe+PGZZL0NHRoWuvvdb32u1269Zbb1VsbKxvndPp1KFDh/S///u/4Wjxsli6dKmSk5OVnZ2tLVu29O+x4JZzu92aMWOG7waI0sfHrrOzU6+//noYOwvN2rVrNXr0aN10001at27dkPxoq6enR01NTcrLy/Oti46OVl5entxudxg7GxiHDx9WamqqbrjhBn39619Xa2truFu6LN555x15PB6/45iUlKScnJyIOI6S1NDQoLFjx2ry5Ml6+OGH9fe//z3cLYWko6NDknx/15qamnTu3Dm/YzdlyhRdf/31YTl2Q/JpzTZ46623tHHjRj3zzDO+dR6PRxMnTvSr++QPoMfj0TXXXDOoPV4O3/3ud/XlL39ZCQkJ+v3vf69vfvOb+uCDD/TII4+Eu7VL4vF4/MKK5H/shpJHHnlEM2fO1LXXXqvGxkaVlpbq/fff14YNG8LdWlBOnjyp3t7egMflzTffDFNXAyMnJ0fPP/+8Jk+erPfff19PPPGEvvjFL+rgwYO6+uqrw93egPrk5yfQcRxqP1uBzJ8/X/fcc48mTpyot99+W6tWrdIdd9wht9utmJiYcLfXb16vV48++qi+8IUv+O5E7/F4FBsbe955f+E6dlf8DMvKlSsDnkj6z8v//eV47NgxzZ8/XwsWLNCSJUvC1Hn/hDK+i1mzZo2+8IUv6KabbtLjjz+ub3/721q3bt1lHMGFDfTYbBbMWF0ul770pS/pxhtv1EMPPaT169dr48aN6u7uDvMo8Ik77rhDCxYs0I033iin06k9e/bo1KlT+tWvfhXu1hCk+++/X3feeadmzJihgoIC7d69W3/+85/V0NAQ7taCsnTpUh08eNDqiyiu+BmWxx57TN/4xjcuWnPDDTf4/n38+HHddtttmjNnznkn0zocjvPOnv7ktcPhGJiGgxTs+IKVk5OjJ598Ut3d3YP+zImBHJvD4Tjv6pNwH7t/diljzcnJ0UcffaQjR45o8uTJl6G7yyM5OVkxMTEBf6ZsOCYDadSoUfrc5z6nt956K9ytDLhPjlVbW5vGjRvnW9/W1qbMzMwwdXX53HDDDUpOTtZbb72l22+/Pdzt9MuyZcu0e/du/eEPf9B1113nW+9wONTT06NTp075zbKE62fwig8sY8aM0ZgxY/pVe+zYMd12223KysrS1q1b/R7yKEm5ubn613/9V507d07Dhw+XJNXV1Wny5Mlh+zgomPGF4sCBA7rmmmvC8oCsgRxbbm6unnrqKbW3t2vs2LGSPj52iYmJmjZt2oC8x6W4lLEeOHBA0dHRvnENFbGxscrKylJ9fb3vSjSv16v6+notW7YsvM0NsA8++EBvv/22HnjggXC3MuAmTpwoh8Oh+vp6X0Dp7OzUq6++2ucViUPRe++9p7///e9+4cxWxhgtX75cL7zwghoaGs47pSErK0vDhw9XfX297r33XknSoUOH1Nraqtzc3LA0jH547733zGc/+1lz++23m/fee8+8//77vuUTp06
dMikpKeaBBx4wBw8eNDU1NSYhIcH86Ec/CmPn/ffuu++alpYW88QTT5irrrrKtLS0mJaWFnP69GljjDG7du0yP/7xj83//M//mMOHD5sf/OAHJiEhwZSVlYW58771NbaPPvrITJ8+3cybN88cOHDA1NbWmjFjxpjS0tIwdx6cxsZG8+yzz5oDBw6Yt99+2/z85z83Y8aMMYsWLQp3ayGpqakxcXFx5vnnnzd//etfzYMPPmhGjRrldzXXUPTYY4+ZhoYG884775g//elPJi8vzyQnJ5v29vZwtxaS06dP+36mJJkNGzaYlpYW8+677xpjjFm7dq0ZNWqU+e1vf2tee+01c9ddd5mJEyeaDz/8MMyd9+1iYzt9+rT5l3/5F+N2u80777xjXnrpJTNz5kwzadIkc/bs2XC33qeHH37YJCUlmYaGBr+/aWfOnPHVPPTQQ+b66683//Vf/2X2799vcnNzTW5ublj6JbD009atW42kgMs/+8tf/mJuueUWExcXZ8aPH2/Wrl0bpo6DV1RUFHB8L7/8sjHGmP/4j/8wmZmZ5qqrrjIjR440GRkZprq62vT29oa38X7oa2zGGHPkyBFzxx13mBEjRpjk5GTz2GOPmXPnzoWv6RA0NTWZnJwck5SUZOLj483UqVPNv/3bvw2JX54XsnHjRnP99deb2NhYk52dbV555ZVwt3TJCgsLzbhx40xsbKwZP368KSwsNG+99Va42wrZyy+/HPDnq6ioyBjz8aXNa9asMSkpKSYuLs7cfvvt5tChQ+Ftup8uNrYzZ86YefPmmTFjxpjhw4ebCRMmmCVLlgyZQH2hv2lbt2711Xz44Yfmm9/8prnmmmtMQkKCufvuu/3+R30wRf3/pgEAAKx1xV8lBAAA7EdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1/h9TiaZMnPIi3AAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
    "from torch import no_grad\n",
    "\n",
    "# Change player and age tokens here!\n",
    "# You can find these values in player_tokens.csv and age_tokens.csv\n",
    "# Each team needs exactly players_per_team (8) player tokens and 8 age tokens.\n",
    "\n",
    "# Denver Nuggets first game of 2023-24 season roster\n",
    "home_player_tokens = [5035, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "home_age_tokens = [14, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Uncomment to take Jokic off team, replace with Peyton Watson\n",
    "# home_player_tokens = [4331, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "# home_age_tokens = [6, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Boston Celtics final game of 2023-24 season roster\n",
    "away_player_tokens = [5042, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# Uncomment to take Tatum off team, replace with Pritchard\n",
    "# away_player_tokens = [4999, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "# away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# The model usually gives the home team a bump in win probability.\n",
    "# Change this to \"True\" to swap home and away teams.\n",
    "swap_home_away = False\n",
    "if swap_home_away:\n",
    "    home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
    "    home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens\n",
    "\n",
    "for name, tokens in [\n",
    "    ('home_player_tokens', home_player_tokens),\n",
    "    ('home_age_tokens', home_age_tokens),\n",
    "    ('away_player_tokens', away_player_tokens),\n",
    "    ('away_age_tokens', away_age_tokens),\n",
    "]:\n",
    "    assert len(tokens) == players_per_team, f\"{name} needs {players_per_team} entries, got {len(tokens)}\"\n",
    "\n",
    "\n",
    "def as_row(tokens):\n",
    "    \"\"\"Return `tokens` as a (1, len(tokens)) int32 tensor on `device`, i.e. a batch of one.\"\"\"\n",
    "    return Tensor(tokens).to(dtype=int32).unsqueeze(0).to(device)\n",
    "\n",
    "\n",
    "# Only the home sequences get the extra leading id (num_*_tokens + 1) reserved past the real vocabulary.\n",
    "batch = {\n",
    "    'home_player_tokens': as_row([num_player_tokens + 1] + home_player_tokens),\n",
    "    'home_age_tokens': as_row([num_age_tokens + 1] + home_age_tokens),\n",
    "    'away_player_tokens': as_row(away_player_tokens),\n",
    "    'away_age_tokens': as_row(away_age_tokens),\n",
    "}\n",
    "\n",
    "# Inference only: no_grad skips building the autograd graph for this forward pass.\n",
    "with no_grad():\n",
    "    logits, _ = model(**batch)\n",
    "# The model is configured with dtype=bfloat16; upcast so the softmax is computed in float32.\n",
    "output = logits.squeeze().float().softmax(dim=0)\n",
    "\n",
    "# Label token i in 1..num_net_score_tokens encodes a final net score of i - score_offset\n",
    "# (positive = home win). Token 0, the last token, and the net-score-0 token (no ties) are dropped.\n",
    "score_offset = num_net_score_tokens // 2 + 1  # 21 for 41 net-score tokens\n",
    "probs = {\n",
    "    token - score_offset: p\n",
    "    for token, p in enumerate(output.tolist())\n",
    "    if 1 <= token <= num_net_score_tokens and token != score_offset\n",
    "}\n",
    "win_prob = sum(p for margin, p in probs.items() if margin > 0)\n",
    "loss_prob = sum(p for margin, p in probs.items() if margin < 0)\n",
    "# NOTE: mass on the dropped tokens is excluded, so win_prob + loss_prob may be slightly below 1;\n",
    "# win_prob / (win_prob + loss_prob) conditions on the game having a winner.\n",
    "\n",
    "print(f\"Home team win probability: {win_prob:.2f}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(list(probs.keys()), list(probs.values()))\n",
    "ax.set(xlabel=\"Final net score (home - away)\", ylabel=\"Probability\", title=\"Predicted final margin\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}