Stoub commited on
Commit
9de8ea5
·
verified ·
1 Parent(s): 94b78bf

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. replay.mp4 +0 -0
  2. results.json +1 -1
  3. unit4.ipynb +53 -27
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
results.json CHANGED
@@ -1 +1 @@
1
- {"env_id": "CartPole-v1", "mean_reward": 500.0, "n_evaluation_episodes": 10, "eval_datetime": "2024-09-25T19:18:52.143176"}
 
1
+ {"env_id": "CartPole-v1", "mean_reward": 500.0, "n_evaluation_episodes": 10, "eval_datetime": "2024-09-25T19:28:40.813449"}
unit4.ipynb CHANGED
@@ -308,7 +308,7 @@
308
  },
309
  {
310
  "cell_type": "code",
311
- "execution_count": 2,
312
  "metadata": {
313
  "id": "V8oadoJSWp7C"
314
  },
@@ -351,7 +351,7 @@
351
  },
352
  {
353
  "cell_type": "code",
354
- "execution_count": 3,
355
  "metadata": {
356
  "id": "kaJu5FeZxXGY"
357
  },
@@ -362,7 +362,7 @@
362
  },
363
  {
364
  "cell_type": "code",
365
- "execution_count": 4,
366
  "metadata": {
367
  "id": "U5TNYa14aRav"
368
  },
@@ -440,7 +440,7 @@
440
  },
441
  {
442
  "cell_type": "code",
443
- "execution_count": 5,
444
  "metadata": {
445
  "id": "POOOk15_K6KA"
446
  },
@@ -463,7 +463,7 @@
463
  },
464
  {
465
  "cell_type": "code",
466
- "execution_count": 6,
467
  "metadata": {
468
  "id": "FMLFrjiBNLYJ"
469
  },
@@ -475,7 +475,7 @@
475
  "_____OBSERVATION SPACE_____ \n",
476
  "\n",
477
  "The State Space is: 4\n",
478
- "Sample observation [ 1.5935603e+00 -1.3312091e+38 3.4518537e-01 3.0665638e+38]\n"
479
  ]
480
  }
481
  ],
@@ -487,7 +487,7 @@
487
  },
488
  {
489
  "cell_type": "code",
490
- "execution_count": 7,
491
  "metadata": {
492
  "id": "Lu6t4sRNNWkN"
493
  },
@@ -500,7 +500,7 @@
500
  " _____ACTION SPACE_____ \n",
501
  "\n",
502
  "The Action Space is: 2\n",
503
- "Action Space Sample 1\n"
504
  ]
505
  }
506
  ],
@@ -539,7 +539,7 @@
539
  },
540
  {
541
  "cell_type": "code",
542
- "execution_count": 8,
543
  "metadata": {
544
  "id": "w2LHcHhVZvPZ"
545
  },
@@ -768,7 +768,7 @@
768
  },
769
  {
770
  "cell_type": "code",
771
- "execution_count": 9,
772
  "metadata": {
773
  "id": "iOdv8Q9NfLK7"
774
  },
@@ -969,7 +969,7 @@
969
  },
970
  {
971
  "cell_type": "code",
972
- "execution_count": 10,
973
  "metadata": {
974
  "id": "utRe1NgtVBYF"
975
  },
@@ -1046,7 +1046,7 @@
1046
  },
1047
  {
1048
  "cell_type": "code",
1049
- "execution_count": 11,
1050
  "metadata": {
1051
  "id": "3FamHmxyhBEU"
1052
  },
@@ -1140,7 +1140,7 @@
1140
  },
1141
  {
1142
  "cell_type": "code",
1143
- "execution_count": 12,
1144
  "metadata": {
1145
  "id": "LIVsvlW_8tcw"
1146
  },
@@ -1165,7 +1165,7 @@
1165
  },
1166
  {
1167
  "cell_type": "code",
1168
- "execution_count": 13,
1169
  "metadata": {
1170
  "id": "Lo4JH45if81z"
1171
  },
@@ -1188,12 +1188,8 @@
1188
  " # Set up the video writer\n",
1189
  " fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for the video\n",
1190
  " video_writer = cv2.VideoWriter(out_directory, fourcc, fps, (width, height))\n",
1191
- " print(out_directory)\n",
1192
- " step=0\n",
1193
  " \n",
1194
  " while not done:\n",
1195
- " print(step)\n",
1196
- " step+=1\n",
1197
  " # Take the action determined by the policy\n",
1198
  " action, _ = policy.act(state)\n",
1199
  " state, reward, terminated, truncated, _ = env.step(action)\n",
@@ -1213,7 +1209,36 @@
1213
  },
1214
  {
1215
  "cell_type": "code",
1216
- "execution_count": 15,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1217
  "metadata": {
1218
  "id": "_TPdq47D7_f_"
1219
  },
@@ -1329,7 +1354,6 @@
1329
  " # Step 6: Record a video\n",
1330
  " video_path = current_directory / \"replay.mp4\"\n",
1331
  " record_video(render_env, model, video_path, video_fps)\n",
1332
- " print(\"video recorded\")\n",
1333
  "\n",
1334
  " # Step 7. Push everything to the Hub\n",
1335
  " api.upload_folder(\n",
@@ -1377,7 +1401,7 @@
1377
  },
1378
  {
1379
  "cell_type": "code",
1380
- "execution_count": 16,
1381
  "metadata": {
1382
  "id": "QB5nIcxR8paT"
1383
  },
@@ -1385,7 +1409,7 @@
1385
  {
1386
  "data": {
1387
  "application/vnd.jupyter.widget-view+json": {
1388
- "model_id": "5cda3e494c2e4cc3bf833eb935891576",
1389
  "version_major": 2,
1390
  "version_minor": 0
1391
  },
@@ -1421,7 +1445,7 @@
1421
  },
1422
  {
1423
  "cell_type": "code",
1424
- "execution_count": 17,
1425
  "metadata": {},
1426
  "outputs": [],
1427
  "source": [
@@ -1447,7 +1471,7 @@
1447
  },
1448
  {
1449
  "cell_type": "code",
1450
- "execution_count": 18,
1451
  "metadata": {},
1452
  "outputs": [
1453
  {
@@ -1461,7 +1485,7 @@
1461
  "name": "stderr",
1462
  "output_type": "stream",
1463
  "text": [
1464
- "C:\\Users\\Utilisateur\\AppData\\Local\\Temp\\ipykernel_472\\2361131237.py:16: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
1465
  " model = torch.load(model_path)\n"
1466
  ]
1467
  }
@@ -1476,7 +1500,7 @@
1476
  },
1477
  {
1478
  "cell_type": "code",
1479
- "execution_count": null,
1480
  "metadata": {
1481
  "id": "UNwkTS65Uq3Q"
1482
  },
@@ -1985,7 +2009,9 @@
1985
  "496\n",
1986
  "497\n",
1987
  "498\n",
1988
- "499\n"
 
 
1989
  ]
1990
  }
1991
  ],
 
308
  },
309
  {
310
  "cell_type": "code",
311
+ "execution_count": 1,
312
  "metadata": {
313
  "id": "V8oadoJSWp7C"
314
  },
 
351
  },
352
  {
353
  "cell_type": "code",
354
+ "execution_count": 2,
355
  "metadata": {
356
  "id": "kaJu5FeZxXGY"
357
  },
 
362
  },
363
  {
364
  "cell_type": "code",
365
+ "execution_count": 3,
366
  "metadata": {
367
  "id": "U5TNYa14aRav"
368
  },
 
440
  },
441
  {
442
  "cell_type": "code",
443
+ "execution_count": 4,
444
  "metadata": {
445
  "id": "POOOk15_K6KA"
446
  },
 
463
  },
464
  {
465
  "cell_type": "code",
466
+ "execution_count": 5,
467
  "metadata": {
468
  "id": "FMLFrjiBNLYJ"
469
  },
 
475
  "_____OBSERVATION SPACE_____ \n",
476
  "\n",
477
  "The State Space is: 4\n",
478
+ "Sample observation [1.0296973e+00 2.7594529e+38 3.3057278e-01 3.1155016e+38]\n"
479
  ]
480
  }
481
  ],
 
487
  },
488
  {
489
  "cell_type": "code",
490
+ "execution_count": 6,
491
  "metadata": {
492
  "id": "Lu6t4sRNNWkN"
493
  },
 
500
  " _____ACTION SPACE_____ \n",
501
  "\n",
502
  "The Action Space is: 2\n",
503
+ "Action Space Sample 0\n"
504
  ]
505
  }
506
  ],
 
539
  },
540
  {
541
  "cell_type": "code",
542
+ "execution_count": 7,
543
  "metadata": {
544
  "id": "w2LHcHhVZvPZ"
545
  },
 
768
  },
769
  {
770
  "cell_type": "code",
771
+ "execution_count": 8,
772
  "metadata": {
773
  "id": "iOdv8Q9NfLK7"
774
  },
 
969
  },
970
  {
971
  "cell_type": "code",
972
+ "execution_count": 9,
973
  "metadata": {
974
  "id": "utRe1NgtVBYF"
975
  },
 
1046
  },
1047
  {
1048
  "cell_type": "code",
1049
+ "execution_count": 10,
1050
  "metadata": {
1051
  "id": "3FamHmxyhBEU"
1052
  },
 
1140
  },
1141
  {
1142
  "cell_type": "code",
1143
+ "execution_count": 11,
1144
  "metadata": {
1145
  "id": "LIVsvlW_8tcw"
1146
  },
 
1165
  },
1166
  {
1167
  "cell_type": "code",
1168
+ "execution_count": 20,
1169
  "metadata": {
1170
  "id": "Lo4JH45if81z"
1171
  },
 
1188
  " # Set up the video writer\n",
1189
  " fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for the video\n",
1190
  " video_writer = cv2.VideoWriter(out_directory, fourcc, fps, (width, height))\n",
 
 
1191
  " \n",
1192
  " while not done:\n",
 
 
1193
  " # Take the action determined by the policy\n",
1194
  " action, _ = policy.act(state)\n",
1195
  " state, reward, terminated, truncated, _ = env.step(action)\n",
 
1209
  },
1210
  {
1211
  "cell_type": "code",
1212
+ "execution_count": null,
1213
+ "metadata": {},
1214
+ "outputs": [],
1215
+ "source": [
1216
+ "def record_video(env, policy, out_directory, fps=30):\n",
1217
+ " \"\"\"\n",
1218
+ " Generate a replay video of the agent\n",
1219
+ " :param env\n",
1220
+ " :param Qtable: Qtable of our agent\n",
1221
+ " :param out_directory\n",
1222
+ " :param fps: how many frame per seconds (with taxi-v3 and frozenlake-v1 we use 1)\n",
1223
+ " \"\"\"\n",
1224
+ " images = []\n",
1225
+ " done = False\n",
1226
+ " state = env.reset()\n",
1227
+ " img = env.render()\n",
1228
+ " images.append(img)\n",
1229
+ " while not done:\n",
1230
+ " # Take the action (index) that have the maximum expected future reward given that state\n",
1231
+ " action, _ = policy.act(state)\n",
1232
+ " state, reward, terminated, truncated, _ = env.step(action) # We directly put next_state = state for recording logic\n",
1233
+ " done = \n",
1234
+ " img = env.render()\n",
1235
+ " images.append(img)\n",
1236
+ " imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)"
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": 22,
1242
  "metadata": {
1243
  "id": "_TPdq47D7_f_"
1244
  },
 
1354
  " # Step 6: Record a video\n",
1355
  " video_path = current_directory / \"replay.mp4\"\n",
1356
  " record_video(render_env, model, video_path, video_fps)\n",
 
1357
  "\n",
1358
  " # Step 7. Push everything to the Hub\n",
1359
  " api.upload_folder(\n",
 
1401
  },
1402
  {
1403
  "cell_type": "code",
1404
+ "execution_count": 14,
1405
  "metadata": {
1406
  "id": "QB5nIcxR8paT"
1407
  },
 
1409
  {
1410
  "data": {
1411
  "application/vnd.jupyter.widget-view+json": {
1412
+ "model_id": "c3448403be744d20bba1601212f90dd2",
1413
  "version_major": 2,
1414
  "version_minor": 0
1415
  },
 
1445
  },
1446
  {
1447
  "cell_type": "code",
1448
+ "execution_count": 15,
1449
  "metadata": {},
1450
  "outputs": [],
1451
  "source": [
 
1471
  },
1472
  {
1473
  "cell_type": "code",
1474
+ "execution_count": 16,
1475
  "metadata": {},
1476
  "outputs": [
1477
  {
 
1485
  "name": "stderr",
1486
  "output_type": "stream",
1487
  "text": [
1488
+ "C:\\Users\\Utilisateur\\AppData\\Local\\Temp\\ipykernel_2256\\2361131237.py:16: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
1489
  " model = torch.load(model_path)\n"
1490
  ]
1491
  }
 
1500
  },
1501
  {
1502
  "cell_type": "code",
1503
+ "execution_count": 19,
1504
  "metadata": {
1505
  "id": "UNwkTS65Uq3Q"
1506
  },
 
2009
  "496\n",
2010
  "497\n",
2011
  "498\n",
2012
+ "499\n",
2013
+ "video recorded\n",
2014
+ "Your model is pushed to the Hub. You can view your model here: https://huggingface.co/Stoub/Reinforce-Cartpole-v1\n"
2015
  ]
2016
  }
2017
  ],