File size: 20,165 Bytes
5f26252 1504ee5 5f26252 1504ee5 5f26252 1504ee5 5f26252 1504ee5 5f26252 1504ee5 5f26252 7a4fc48 21e77ce 7a4fc48 5f26252 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from torch import Tensor, bfloat16, device as torch_device, int32, load as torch_load, no_grad, tensor\n",
"\n",
"from model import NBAConfig, NBAModel\n",
"\n",
"device = torch_device(\"cpu\")\n",
"\n",
"# Vocabulary / label sizes of the trained checkpoint; they determine the model's\n",
"# parameter shapes, so they must match the state_dict loaded below.\n",
"num_age_tokens = 32\n",
"num_player_tokens = 5141\n",
"num_net_score_tokens = 41\n",
"players_per_team = 8\n",
"\n",
"# NOTE(review): each vocabulary is sized num_*_tokens + 2, presumably to reserve special ids\n",
"# (the prediction cell prepends id num_*_tokens + 1 to the home sequences) -- confirm in model.py.\n",
"model_config = NBAConfig(\n",
"    players_per_team=players_per_team,\n",
"    player_tokens=num_player_tokens + 2,\n",
"    age_tokens=num_age_tokens + 2,\n",
"    num_labels=num_net_score_tokens + 2,\n",
"    n_layer=4,\n",
"    n_head=4,\n",
"    n_embd=1024,\n",
"    dropout=0.0,\n",
"    bias=False,\n",
"    dtype=bfloat16,\n",
"    seed=29,\n",
")\n",
"\n",
"model = NBAModel(model_config).to(device)\n",
"# weights_only=True limits unpickling to tensors and plain containers, so a tampered\n",
"# checkpoint cannot execute code; map_location follows the configured device.\n",
"state_dict = torch_load('weights.pt', map_location=device, weights_only=True)\n",
"model.load_state_dict(state_dict)\n",
"model = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Home team win probability: 0.55\n"
]
},
{
"data": {
"text/plain": [
"<BarContainer object of 40 artists>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGhCAYAAABCse9yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAArFElEQVR4nO3df3RU9Z3/8VcSyIxBiUgkAxQNViQgMSmBxKG21HUOE0+6EnVjYHv4kcOhR2uUdtxUwmKiZd2gLBgqaVN6BOnuYticrVmLbLbp1FBtRliSsJSusOoRg8AkYEuCsSaY3O8ffh07ZSCZEJhPhufjnHs0n3nfm/fnXJO8/My9d2Isy7IEAABgsNhINwAAANAfAgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMN6gAktlZaVSUlJkt9uVnZ2tvXv3XrC+pqZGqampstvtSktL065du4Jej4mJCbmtW7duMO0BAIAoE3Zg2bFjhzwej8rKytTc3Kz09HS53W61t7eHrG9sbNTChQu1bNkytbS0KC8vT3l5eTp48GCg5sSJE0Hbli1bFBMTo/vvv3/wMwMAAFEjJtwPP8zOztbs2bO1adMmSVJfX58mTZqkRx55RCtXrjynvqCgQF1dXdq5c2dg7Pbbb1dGRoaqqqpCfo+8vDydOXNGXq93QD319fXp+PHjuuaaaxQTExPOdAAAQIRYlqUzZ85owoQJio298BrKiHAO3NPTo6amJpWUlATGYmNj5XK55PP5Qu7j8/nk8XiCxtxut2pra0PWt7W16dVXX9W2bdvO20d3d7e6u7sDXx87dkzTp08PYyYAAMAUR48e1Ze+9KUL1oQVWE6dOqXe3l4lJycHjScnJ+vQoUMh9/H7/SHr/X5/yPpt27bpmmuu0X333XfePsrLy/XUU0+dM3706FGNHj26v2kAAAADdHZ2atKkSbrmmmv6rQ0rsFwOW7Zs0be+9S3Z7fbz1pSUlASt2nw+4dGjRxNYAAAYZgZyOUdYgSUpKUlxcXFqa2sLGm9ra5PD4Qi5j8PhGHD966+/rsOHD2vHjh0X7MNms8lms4XTOgAAGMbCuksoPj5emZmZQRfD9vX1yev1yul0htzH6XSec/FsfX19yPoXXnhBmZmZSk9PD6ctAAAQ5cJ+S8jj8WjJkiWaNWuWsrKyVFFRoa6uLhUWFkqSFi9erIkTJ6q8vFyStGLFCs2dO1fr169Xbm6uqqurtW/fPm3evDnouJ2dnaqpqdH69euHYFoAACCahB1YCgoKdPLkSZWWlsrv9ysjI0N1dXWBC2tbW1uDbk2aM2eOtm/frtWrV2vVqlWaMmWKamtrNWPGjKDjVldXy7IsLVy48CKnBAAAok3Yz2ExUWdnpxITE9XR0cFFtwAADBPh/P3ms4QAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPHCfjQ/AACILikrX+235sja3MvQyfmxwgIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGG9QgaWyslIpKSmy2+3Kzs7W3r17L1hfU1Oj1NRU2e12paWladeuXefUvPXWW7rnnnuUmJioUaNGafbs2WptbR1MewAAIMqEHVh27Nghj8ejsrIyNTc3Kz09XW63W+3t7SHrGxsbtXDhQi1btkwtLS3Ky8tTXl6eDh48GKh59913dccddyg1NVUNDQ06cOCAnnjiCdnt9s
HPDAAARI0Yy7KscHbIzs7W7NmztWnTJklSX1+fJk2apEceeUQrV648p76goEBdXV3auXNnYOz2229XRkaGqqqqJEkLFizQyJEj9c///M+DmkRnZ6cSExPV0dGh0aNHD+oYAABcqVJWvtpvzZG1uUP+fcP5+x3WCktPT4+amprkcrm+OEBsrFwul3w+X8h9fD5fUL0kud3uQH1fX59effVV3XLLLXK73Ro3bpyys7NVW1t73j66u7vV2dkZtAEAgOgVVmA5deqUent7lZycHDSenJwsv98fch+/33/B+vb2dn300Udau3atcnJy9Mtf/lL33nuv7rvvPu3evTvkMcvLy5WYmBjYJk2aFM40AADAMBPxu4T6+vokSfPnz9f3vvc9ZWRkaOXKlfrmN78ZeMvoL5WUlKijoyOwHT169HK2DAAALrMR4RQnJSUpLi5ObW1tQeNtbW1yOBwh93E4HBesT0pK0ogRIzR9+vSgmmnTpumNN94IeUybzSabzRZO6wAAYBgLa4UlPj5emZmZ8nq9gbG+vj55vV45nc6Q+zidzqB6Saqvrw/Ux8fHa/bs2Tp8+HBQzf/93//pxhtvDKc9AAAQpcJaYZEkj8ejJUuWaNasWcrKylJFRYW6urpUWFgoSVq8eLEmTpyo8vJySdKKFSs0d+5crV+/Xrm5uaqurta+ffu0efPmwDGLi4tVUFCgr3/967rzzjtVV1enX/ziF2poaBiaWQIAgGEt7MBSUFCgkydPqrS0VH6/XxkZGaqrqwtcWNva2qrY2C8WbubMmaPt27dr9erVWrVqlaZMmaLa2lrNmDEjUHPvvfeqqqpK5eXlevTRRzV16lT9+7//u+64444hmCIAABjuwn4Oi4l4DgsAAIMXdc9hAQAAiAQCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGC8QQWWyspKpaSkyG63Kzs7W3v37r1gfU1NjVJTU2W325WWlqZdu3YFvb506VLFxMQEbTk5OYNpDQAARKGwA8uOHTvk8XhUVlam5uZmpaeny+12q729PWR9Y2OjFi5cqGXLlqmlpUV5eXnKy8vTwYMHg+pycnJ04sSJwPbSSy8NbkYAACDqhB1YNmzYoOXLl6uwsFDTp09XVVWVEhIStGXLlpD1GzduVE5OjoqLizVt2jStWbNGM2fO1KZNm4LqbDabHA5HYBszZszgZgQAAKJOWIGlp6dHTU1NcrlcXxwgNlYul0s+ny/kPj6fL6hektxu9zn1DQ0NGjdunKZOnaqHHnpIH3744Xn76O7uVmdnZ9AGAACiV1iB5dSpU+rt7VVycnLQeHJysvx+f8h9/H5/v/U5OTn62c9+Jq/Xq2eeeUa7d+/W3Xffrd7e3pDHLC8vV2JiYmCbNGlSONMAAADDzIhINyBJCxYsCPx7WlqabrvtNn35y19WQ0OD7rrrrnPqS0pK5PF4Al93dnYSWgAAiGJhrbAkJSUpLi5ObW1tQeNtbW1yOBwh93E4HGHVS9JNN92kpKQkvfPOOyFft9lsGj16dNAGAACiV1iBJT4+XpmZmfJ6vYGxvr4+eb1eOZ3OkPs4nc6gekmqr68/b70kffDBB/rwww81fvz4cNoDAABRKuy7hDwej376059q27Zteuutt/TQQw+pq6tLhYWFkqTFixerpKQkUL9ixQrV1dVp/fr1OnTokJ588knt27dPRUVFkqSPPvpIxcXFevPNN3
XkyBF5vV7Nnz9fN998s9xu9xBNEwAADGdhX8NSUFCgkydPqrS0VH6/XxkZGaqrqwtcWNva2qrY2C9y0Jw5c7R9+3atXr1aq1at0pQpU1RbW6sZM2ZIkuLi4nTgwAFt27ZNp0+f1oQJEzRv3jytWbNGNpttiKYJAACGsxjLsqxIN3GxOjs7lZiYqI6ODq5nAQAgTCkrX+235sja3CH/vuH8/eazhAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGC8EZFuYDhIWflqvzVH1uZehk4AALgyscICAACMR2ABAADG4y0hAFcM3t4Fhi9WWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMNyLSDQDAlSRl5av91hxZm3sZOgGGF1ZYAACA8VhhAYAQWAkBzMIKCwAAMB6BBQAAGG9QgaWyslIpKSmy2+3Kzs7W3r17L1hfU1Oj1NRU2e12paWladeuXeetffDBBxUTE6OKiorBtAYAAKJQ2IFlx44d8ng8KisrU3Nzs9LT0+V2u9Xe3h6yvrGxUQsXLtSyZcvU0tKivLw85eXl6eDBg+fUvvzyy3rzzTc1YcKE8GcCAACiVtiBZcOGDVq+fLkKCws1ffp0VVVVKSEhQVu2bAlZv3HjRuXk5Ki4uFjTpk3TmjVrNHPmTG3atCmo7tixY3rkkUf0r//6rxo5cuTgZgMAAKJSWIGlp6dHTU1NcrlcXxwgNlYul0s+ny/kPj6fL6hektxud1B9X1+fFi1apOLiYt1666399tHd3a3Ozs6gDQAARK+wAsupU6fU29ur5OTkoPHk5GT5/f6Q+/j9/n7rn3nmGY0YMUKPPvrogPooLy9XYmJiYJs0aVI40wAAAMNMxO8Sampq0saNG/Xiiy8qJiZmQPuUlJSoo6MjsB09evQSdwkAACIprAfHJSUlKS4uTm1tbUHjbW1tcjgcIfdxOBwXrH/99dfV3t6uG264IfB6b2+vHnvsMVVUVOjIkSPnHNNms8lms4XTOgBENR50h2gX1gpLfHy8MjMz5fV6A2N9fX3yer1yOp0h93E6nUH1klRfXx+oX7RokQ4cOKD9+/cHtgkTJqi4uFj/9V//Fe58AABAFAr70fwej0dLlizRrFmzlJWVpYqKCnV1damwsFCStHjxYk2cOFHl5eWSpBUrVmju3Llav369cnNzVV1drX379mnz5s2SpLFjx2rs2LFB32PkyJFyOByaOnXqxc4PAABEgbADS0FBgU6ePKnS0lL5/X5lZGSorq4ucGFta2urYmO/WLiZM2eOtm/frtWrV2vVqlWaMmWKamtrNWPGjKGbBQAAiGqD+vDDoqIiFRUVhXytoaHhnLH8/Hzl5+cP+PihrlsBAABXrojfJQQAANAfAgsAADAegQUAABiPwAIAAIxHYAEAAMYb1F1CAAaOJ5ACwMVjhQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxuOzhADgIkX750X1N7/hPDcMHwQWAMNatIcFAJ/hLSEAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHg8OA64AvBwNQDDHSssAADAeKywAIChWBkDvsAKCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMN6jAUllZqZSUFNntdmVnZ2vv3r0XrK+pqVFqaqrsdrvS0tK0a9euoNeffPJJpaamatSoURozZoxcLp
f27NkzmNYAAEAUCjuw7NixQx6PR2VlZWpublZ6errcbrfa29tD1jc2NmrhwoVatmyZWlpalJeXp7y8PB08eDBQc8stt2jTpk363e9+pzfeeEMpKSmaN2+eTp48OfiZAQCAqBH2pzVv2LBBy5cvV2FhoSSpqqpKr776qrZs2aKVK1eeU79x40bl5OSouLhYkrRmzRrV19dr06ZNqqqqkiT97d/+7Tnf44UXXtCBAwd01113hT0p4FLr71N0+QRdABhaYa2w9PT0qKmpSS6X64sDxMbK5XLJ5/OF3Mfn8wXVS5Lb7T5vfU9PjzZv3qzExESlp6eHrOnu7lZnZ2fQBgAAoldYKyynTp1Sb2+vkpOTg8aTk5N16NChkPv4/f6Q9X6/P2hs586dWrBggT7++GONHz9e9fX1SkpKCnnM8vJyPfXUU+G0DgC4DFh9xKVizF1Cd955p/bv36/Gxkbl5OTogQceOO91MSUlJero6AhsR48evczdAgCAyymswJKUlKS4uDi1tbUFjbe1tcnhcITcx+FwDKh+1KhRuvnmm3X77bfrhRde0IgRI/TCCy+EPKbNZtPo0aODNgAAEL3CCizx8fHKzMyU1+sNjPX19cnr9crpdIbcx+l0BtVLUn19/Xnr//y43d3d4bQHAACiVNh3CXk8Hi1ZskSzZs1SVlaWKioq1NXVFbhraPHixZo4caLKy8slSStWrNDcuXO1fv165ebmqrq6Wvv27dPmzZslSV1dXXr66ad1zz33aPz48Tp16pQqKyt17Ngx5efnD+FUAQDAcBV2YCkoKNDJkydVWloqv9+vjIwM1dXVBS6sbW1tVWzsFws3c+bM0fbt27V69WqtWrVKU6ZMUW1trWbMmCFJiouL06FDh7Rt2zadOnVKY8eO1ezZs/X666/r1ltvHaJpAgCA4SzswCJJRUVFKioqCvlaQ0PDOWP5+fnnXS2x2+36+c9/Ppg2gCtaf3djSNyRASB6GHOXEAAAwPkQWAAAgPEILAAAwHgEFgAAYDwCCwAAMN6g7hICEL24+wiAiVhhAQAAxiOwAAAA4xFYAACA8biGBcCgcb0LgMuFFRYAAGA8AgsAADAegQUAABiPa1gAg3BNCACExgoLAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAeT7oF/r/+njLLE2YBIHJYYQEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjMeD4wAAEcHDGhEOVlgAAIDxCCwAAMB4BBYAAGA8AgsAADAeF90CwBWmv4tdJS54hXlYYQEAAMZjhQXAZcH/1QO4GKywAAAA4xFYAACA8QYVWCorK5WSkiK73a7s7Gzt3bv3gvU1NTVKTU2V3W5XWlqadu3aFXjt7Nmzevzxx5WWlqZRo0ZpwoQJWrx4sY4fPz6Y1gAAQBQK+xqWHTt2yOPxqKqqStnZ2aqoqJDb7dbhw4c1bty4c+obGxu1cOFClZeX65vf/Ka2b9+uvLw8NTc3a8aMGfr444/V3NysJ554Qunp6frjH/+oFStW6J577tG+ffuGZJKIHlwHAQBXprBXWDZs2KDly5ersLBQ06dPV1VVlRISErRly5aQ9Rs3blROTo6Ki4s1bdo0rVmzRjNnztSmTZskSYmJiaqvr9cDDzygqVOn6vbbb9emTZvU1NSk1tbWi5sdAACICmEFlp6eHjU1Ncnlcn1xgNhYuVwu+Xy+kPv4fL6geklyu93nrZekjo4OxcTE6Nprrw35end3tzo7O4M2AAAQvcJ6S+jUqVPq7e1VcnJy0HhycrIOHToUch+/3x+y3u/3h6z/5JNP9Pjjj2vhwoUaPXp0yJry8nI99dRT4bSOKxSfBgsA0cGou4TOnj2rBx54QJZl6cc//vF560pKStTR0RHYjh49ehm7BAAAl1tYKyxJSUmKi4tTW1tb0HhbW5scDkfIfR
wOx4DqPw8r77//vn7961+fd3VFkmw2m2w2WzitAwCAYSysFZb4+HhlZmbK6/UGxvr6+uT1euV0OkPu43Q6g+olqb6+Pqj+87Dy9ttv61e/+pXGjh0bTlsAACDKhX1bs8fj0ZIlSzRr1ixlZWWpoqJCXV1dKiwslCQtXrxYEydOVHl5uSRpxYoVmjt3rtavX6/c3FxVV1dr37592rx5s6TPwsrf/M3fqLm5WTt37lRvb2/g+pbrrrtO8fHxQzVXAAAwTIUdWAoKCnTy5EmVlpbK7/crIyNDdXV1gQtrW1tbFRv7xcLNnDlztH37dq1evVqrVq3SlClTVFtbqxkzZkiSjh07pldeeUWSlJGREfS9XnvtNX3jG98Y5NQAANGCZzBhUB9+WFRUpKKiopCvNTQ0nDOWn5+v/Pz8kPUpKSmyLGswbQAAgCuEUXcJAQAAhEJgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgvEHd1owrE89BAABECoHlCkcIAQAMBwQWRByhCQDQH65hAQAAxiOwAAAA4xFYAACA8biGBQAQVbguLjqxwgIAAIxHYAEAAMbjLSFcEizJAhgO+F01fLDCAgAAjEdgAQAAxiOwAAAA43ENSxTiPVkAQLRhhQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeCMi3cCVLGXlq/3WHFmbexk6AQDAbKywAAAA47HCAgDAALAqHlmssAAAAOMRWAAAgPEILAAAwHgEFgAAYDwuugUAIIK4mHdgBrXCUllZqZSUFNntdmVnZ2vv3r0XrK+pqVFqaqrsdrvS0tK0a9euoNd//vOfa968eRo7dqxiYmK0f//+wbQFAACiVNiBZceOHfJ4PCorK1Nzc7PS09PldrvV3t4esr6xsVELFy7UsmXL1NLSory8POXl5engwYOBmq6uLt1xxx165plnBj8TAAAQtcIOLBs2bNDy5ctVWFio6dOnq6qqSgkJCdqyZUvI+o0bNyonJ0fFxcWaNm2a1qxZo5kzZ2rTpk2BmkWLFqm0tFQul2vwMwEAAFErrMDS09OjpqamoGARGxsrl8sln88Xch+fz3dOEHG73eetH4ju7m51dnYGbQAAIHqFFVhOnTql3t5eJScnB40nJyfL7/eH3Mfv94dVPxDl5eVKTEwMbJMmTRr0sQAAgPmG5W3NJSUl6ujoCGxHjx6NdEsAAOASCuu25qSkJMXFxamtrS1ovK2tTQ6HI+Q+DocjrPqBsNlsstlsg95/OOK2NwDAlSysFZb4+HhlZmbK6/UGxvr6+uT1euV0OkPu43Q6g+olqb6+/rz1AAAAfynsB8d5PB4tWbJEs2bNUlZWlioqKtTV1aXCwkJJ0uLFizVx4kSVl5dLklasWKG5c+dq/fr1ys3NVXV1tfbt26fNmzcHjvmHP/xBra2tOn78uCTp8OHDkj5bnbmYlRgAAK5U0bYyH3ZgKSgo0MmTJ1VaWiq/36+MjAzV1dUFLqxtbW1VbOwXCzdz5szR9u3btXr1aq1atUpTpkxRbW2tZsyYEah55ZVXAoFHkhYsWCBJKisr05NPPjnYuQEAgCgxqEfzFxUVqaioKORrDQ0N54zl5+crPz//vMdbunSpli5dOphWAADAFWBY3iUEAACuLAQWAABgPD6tGQCAIRZtF7yagBUWAABgPFZYAAAYJq7klRtWWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeDyaf4hdyY9NBgDgUmGFBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAI
xHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGG9QgaWyslIpKSmy2+3Kzs7W3r17L1hfU1Oj1NRU2e12paWladeuXUGvW5al0tJSjR8/XldddZVcLpfefvvtwbQGAACiUNiBZceOHfJ4PCorK1Nzc7PS09PldrvV3t4esr6xsVELFy7UsmXL1NLSory8POXl5engwYOBmmeffVY//OEPVVVVpT179mjUqFFyu9365JNPBj8zAAAQNcIOLBs2bNDy5ctVWFio6dOnq6qqSgkJCdqyZUvI+o0bNyonJ0fFxcWaNm2a1qxZo5kzZ2rTpk2SPltdqaio0OrVqzV//nzddttt+tnPfqbjx4+rtrb2oiYHAACiw4hwint6etTU1KSSkpLAWGxsrFwul3w+X8h9fD6fPB5P0Jjb7Q6Ekffee09+v18ulyvwemJiorKzs+Xz+bRgwYJzjtnd3a3u7u7A1x0dHZKkzs7OcKYzYH3dH/db8/n3ptac2oHUm1D75/XUhv8zbEK/0Vw7kHoTav+8nlpzfj4HekzLsvovtsJw7NgxS5LV2NgYNF5cXGxlZWWF3GfkyJHW9u3bg8YqKyutcePGWZZlWb/97W8tSdbx48eDavLz860HHngg5DHLysosSWxsbGxsbGxRsB09erTfDBLWCospSkpKglZt+vr69Ic//EFjx45VTEzMJf3enZ2dmjRpko4eParRo0df0u8VCdE8P+Y2PEXz3KTonh9zG74u1/wsy9KZM2c0YcKEfmvDCixJSUmKi4tTW1tb0HhbW5scDkfIfRwOxwXrP/9nW1ubxo8fH1STkZER8pg2m002my1o7Nprrw1nKhdt9OjRUfkf6eeieX7MbXiK5rlJ0T0/5jZ8XY75JSYmDqgurItu4+PjlZmZKa/XGxjr6+uT1+uV0+kMuY/T6Qyql6T6+vpA/eTJk+VwOIJqOjs7tWfPnvMeEwAAXFnCfkvI4/FoyZIlmjVrlrKyslRRUaGuri4VFhZKkhYvXqyJEyeqvLxckrRixQrNnTtX69evV25urqqrq7Vv3z5t3rxZkhQTE6Pvfve7+od/+AdNmTJFkydP1hNPPKEJEyYoLy9v6GYKAACGrbADS0FBgU6ePKnS0lL5/X5lZGSorq5OycnJkqTW1lbFxn6xcDNnzhxt375dq1ev1qpVqzRlyhTV1tZqxowZgZrvf//76urq0re//W2dPn1ad9xxh+rq6mS324dgikPLZrOprKzsnLekokU0z4+5DU/RPDcpuufH3IYvE+cXY1kDuZcIAAAgcvgsIQAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwDNCRI0e0bNkyTZ48WVdddZW+/OUvq6ysTD09PUF1Bw4c0Ne+9jXZ7XZNmjRJzz77bIQ6Dt/TTz+tOXPmKCEh4bxPDo6JiTlnq66uvryNDsJA5tba2qrc3FwlJCRo3LhxKi4u1qeffnp5Gx0CKSkp55yjtWvXRrqtQausrFRKSorsdruys7O1d+/eSLd00Z588slzzlFqamqk2xq03/zmN/rrv/5rTZgwQTExMYEPt/2cZVkqLS3V+PHjddVVV8nlcuntt9+OTLNh6m9uS5cuPedc5uTkRKbZMJWXl2v27Nm65pprNG7cOOXl5enw4cNBNZ988okefvhhjR07VldffbXuv//+c55ef7kQWAbo0KFD6uvr009+8hP9/ve/13PPPaeqqiqtWrUqUNPZ2al58+bpxhtvVFNTk9atW6cnn3wy8JA80/X09Cg/P18PPfTQBeu2bt2qEydOBLbh8IC//ubW29ur3Nxc9fT0qLGxUdu2bdOLL76o0tLSy9zp0PjBD34QdI4eeeSRSLc0KDt27JDH41FZWZmam5uVnp4ut9ut9vb2SLd20W699dagc/TGG29EuqVB6+rqUnp6uiorK0O+/uyzz+qHP/yhqqqqtGfPHo0aNUput1uffPLJZe40fP3NTZ
JycnKCzuVLL710GTscvN27d+vhhx/Wm2++qfr6ep09e1bz5s1TV1dXoOZ73/uefvGLX6impka7d+/W8ePHdd9990Wm4X4/HhHn9eyzz1qTJ08OfP2jH/3IGjNmjNXd3R0Ye/zxx62pU6dGor1B27p1q5WYmBjyNUnWyy+/fFn7GUrnm9uuXbus2NhYy+/3B8Z+/OMfW6NHjw46n8PBjTfeaD333HORbmNIZGVlWQ8//HDg697eXmvChAlWeXl5BLu6eGVlZVZ6enqk27gk/vJ3RF9fn+VwOKx169YFxk6fPm3ZbDbrpZdeikCHgxfq99+SJUus+fPnR6Sfodbe3m5Jsnbv3m1Z1mfnaeTIkVZNTU2g5q233rIkWT6f77L3xwrLRejo6NB1110X+Nrn8+nrX/+64uPjA2Nut1uHDx/WH//4x0i0eEk8/PDDSkpKUlZWlrZs2SIrCp496PP5lJaWFnhis/TZuevs7NTvf//7CHY2OGvXrtXYsWP1la98RevWrRuWb2319PSoqalJLpcrMBYbGyuXyyWfzxfBzobG22+/rQkTJuimm27St771LbW2tka6pUvivffek9/vDzqPiYmJys7OjorzKEkNDQ0aN26cpk6dqoceekgffvhhpFsalI6ODkkK/F1ramrS2bNng85damqqbrjhhoicu7AfzY/PvPPOO3r++ef1T//0T4Exv9+vyZMnB9V9/gfQ7/drzJgxl7XHS+EHP/iB/uqv/koJCQn65S9/qe985zv66KOP9Oijj0a6tYvi9/uDwooUfO6Gk0cffVQzZ87Uddddp8bGRpWUlOjEiRPasGFDpFsLy6lTp9Tb2xvyvBw6dChCXQ2N7Oxsvfjii5o6dapOnDihp556Sl/72td08OBBXXPNNZFub0h9/vMT6jwOt5+tUHJycnTfffdp8uTJevfdd7Vq1Srdfffd8vl8iouLi3R7A9bX16fvfve7+upXvxr46By/36/4+PhzrvuL1Lm74ldYVq5cGfJC0j/f/vKX47Fjx5STk6P8/HwtX748Qp0PzGDmdyFPPPGEvvrVr+orX/mKHn/8cX3/+9/XunXrLuEMzm+o52aycObq8Xj0jW98Q7fddpsefPBBrV+/Xs8//7y6u7sjPAt87u6771Z+fr5uu+02ud1u7dq1S6dPn9a//du/Rbo1hGnBggW65557lJaWpry8PO3cuVP//d//rYaGhki3FpaHH35YBw8eNPomiit+heWxxx7T0qVLL1hz0003Bf79+PHjuvPOOzVnzpxzLqZ1OBznXD39+dcOh2NoGg5TuPMLV3Z2ttasWaPu7u7L/iFZQzk3h8Nxzt0nkT53f+5i5pqdna1PP/1UR44c0dSpUy9Bd5dGUlKS4uLiQv5MmXBOhtK1116rW265Re+8806kWxlyn5+rtrY2jR8/PjDe1tamjIyMCHV16dx0001KSkrSO++8o7vuuivS7QxIUVGRdu7cqd/85jf60pe+FBh3OBzq6enR6dOng1ZZIvUzeMUHluuvv17XX3/9gGqPHTumO++8U5mZmdq6dWvQp1JLktPp1N///d/r7NmzGjlypCSpvr5eU6dOjdjbQeHMbzD279+vMWPGROQTPYdybk6nU08//bTa29s1btw4SZ+du9GjR2v69OlD8j0uxsXMdf/+/YqNjQ3Ma7iIj49XZmamvF5v4E60vr4+eb1eFRUVRba5IfbRRx/p3Xff1aJFiyLdypCbPHmyHA6HvF5vIKB0dnZqz549/d6ROBx98MEH+vDDD4PCmaksy9Ijjzyil19+WQ0NDedc0pCZmamRI0fK6/Xq/vvvlyQdPnxYra2tcjqdEWkYA/DBBx9YN998s3XXXXdZH3zwgXXixInA9rnTp09bycnJ1qJFi6yDBw9a1dXVVkJCgvWTn/wkgp0P3Pvvv2+1tLRYTz31lHX11VdbLS0tVktLi3XmzBnLsizrlVdesX76059av/vd76y3337b+tGPfmQlJC
RYpaWlEe68f/3N7dNPP7VmzJhhzZs3z9q/f79VV1dnXX/99VZJSUmEOw9PY2Oj9dxzz1n79++33n33Xetf/uVfrOuvv95avHhxpFsblOrqastms1kvvvii9b//+7/Wt7/9bevaa68NuptrOHrssceshoYG67333rN++9vfWi6Xy0pKSrLa29sj3dqgnDlzJvAzJcnasGGD1dLSYr3//vuWZVnW2rVrrWuvvdb6j//4D+vAgQPW/PnzrcmTJ1t/+tOfItx5/y40tzNnzlh/93d/Z/l8Puu9996zfvWrX1kzZ860pkyZYn3yySeRbr1fDz30kJWYmGg1NDQE/U37+OOPAzUPPvigdcMNN1i//vWvrX379llOp9NyOp0R6ZfAMkBbt261JIXc/tz//M//WHfccYdls9msiRMnWmvXro1Qx+FbsmRJyPm99tprlmVZ1n/+539aGRkZ1tVXX22NGjXKSk9Pt6qqqqze3t7INj4A/c3NsizryJEj1t13321dddVVVlJSkvXYY49ZZ8+ejVzTg9DU1GRlZ2dbiYmJlt1ut6ZNm2b94z/+47D45Xk+zz//vHXDDTdY8fHxVlZWlvXmm29GuqWLVlBQYI0fP96Kj4+3Jk6caBUUFFjvvPNOpNsatNdeey3kz9eSJUssy/rs1uYnnnjCSk5Otmw2m3XXXXdZhw8fjmzTA3ShuX388cfWvHnzrOuvv94aOXKkdeONN1rLly8fNoH6fH/Ttm7dGqj505/+ZH3nO9+xxowZYyUkJFj33ntv0P+oX04x/79pAAAAY13xdwkBAADzEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHj/D+wKfjR9046xAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Change player and age tokens here!\n",
"# You can find these values in player_tokens.csv and age_tokens.csv.\n",
"# Each team needs exactly players_per_team (8) player tokens and 8 age tokens.\n",
"\n",
"# Denver Nuggets first game of 2023-24 season roster\n",
"home_player_tokens = [5035, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
"home_age_tokens = [14, 16, 13, 12, 10, 19, 8, 8]\n",
"\n",
"# Uncomment to take Jokic off team, replace with Peyton Watson\n",
"# home_player_tokens = [4331, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
"# home_age_tokens = [6, 16, 13, 12, 10, 19, 8, 8]\n",
"\n",
"# Boston Celtics final game of 2023-24 season roster\n",
"away_player_tokens = [5042, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
"away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
"\n",
"# Uncomment to take Tatum off team, replace with Pritchard\n",
"# away_player_tokens = [4999, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
"# away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
"\n",
"# The model usually gives the home team a bump in win probability.\n",
"# Set this to True to swap home and away teams.\n",
"swap_home_away = False\n",
"if swap_home_away:\n",
"    home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
"    home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens\n",
"\n",
"for name, tokens in [\n",
"    ('home_player_tokens', home_player_tokens),\n",
"    ('home_age_tokens', home_age_tokens),\n",
"    ('away_player_tokens', away_player_tokens),\n",
"    ('away_age_tokens', away_age_tokens),\n",
"]:\n",
"    assert len(tokens) == players_per_team, f\"{name} has {len(tokens)} entries, expected {players_per_team}\"\n",
"\n",
"def as_row(token_ids):\n",
"    \"\"\"Return token_ids as a (1, len(token_ids)) int32 tensor on `device`.\"\"\"\n",
"    return tensor(token_ids, dtype=int32, device=device).unsqueeze(0)\n",
"\n",
"# Only the home sequences are prefixed with id num_*_tokens + 1.\n",
"# NOTE(review): presumably a sequence-start marker the model was trained with -- confirm in model.py.\n",
"batch = {\n",
"    'home_player_tokens': as_row([num_player_tokens + 1] + home_player_tokens),\n",
"    'home_age_tokens': as_row([num_age_tokens + 1] + home_age_tokens),\n",
"    'away_player_tokens': as_row(away_player_tokens),\n",
"    'away_age_tokens': as_row(away_age_tokens),\n",
"}\n",
"\n",
"# Inference only: no_grad() stops autograd from recording the forward pass.\n",
"with no_grad():\n",
"    logits, _ = model(**batch)\n",
"\n",
"# Upcast before softmax: the model is configured with dtype=bfloat16, which keeps only\n",
"# about 3 significant decimal digits -- coarse for summing many small probabilities.\n",
"label_probs = logits.squeeze().float().softmax(dim=0).tolist()\n",
"assert len(label_probs) == num_net_score_tokens + 2, len(label_probs)\n",
"\n",
"# Label id i encodes a net score of i - margin_offset (positive = home win).\n",
"# Assumes num_net_score_tokens is odd, i.e. the net-score range is symmetric around 0.\n",
"max_margin = num_net_score_tokens // 2  # 20 for 41 tokens\n",
"margin_offset = max_margin + 1  # 21: label id 0 sits below the net-score range\n",
"probs = {\n",
"    i - margin_offset: p\n",
"    for i, p in enumerate(label_probs)\n",
"    # keep margins -max_margin..-1 and 1..max_margin; this drops label ids 0 and\n",
"    # num_net_score_tokens + 1 as well as the zero-margin label\n",
"    if 0 < abs(i - margin_offset) <= max_margin\n",
"}\n",
"win_prob = sum(p for margin, p in probs.items() if margin > 0)\n",
"loss_prob = sum(p for margin, p in probs.items() if margin < 0)\n",
"\n",
"print(f\"Home team win probability: {win_prob:.2f}\")\n",
"print(f\"Away team win probability: {loss_prob:.2f}\")\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.bar(list(probs.keys()), list(probs.values()))\n",
"ax.set(title='Predicted net score', xlabel='Net score (positive = home win)', ylabel='Probability');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "nba",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|