Commit c6e31cd · Parent(s): 43e364e

PPO playing Walker2DBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
This view is limited to 50 files because the commit contains too many changes.
- README.md +17 -14
- pyproject.toml +23 -2
- replay.meta.json +1 -1
- replay.mp4 +2 -2
- rl_algo_impls/a2c/a2c.py +13 -19
- rl_algo_impls/a2c/optimize.py +1 -1
- rl_algo_impls/benchmark_publish.py +2 -2
- rl_algo_impls/compare_runs.py +2 -1
- rl_algo_impls/dqn/policy.py +14 -7
- rl_algo_impls/dqn/q_net.py +6 -6
- rl_algo_impls/huggingface_publish.py +1 -1
- rl_algo_impls/hyperparams/a2c.yml +17 -13
- rl_algo_impls/hyperparams/dqn.yml +1 -1
- rl_algo_impls/hyperparams/ppo.yml +125 -5
- rl_algo_impls/hyperparams/vpg.yml +4 -4
- rl_algo_impls/optimize.py +5 -4
- rl_algo_impls/ppo/ppo.py +248 -227
- rl_algo_impls/runner/config.py +9 -3
- rl_algo_impls/runner/evaluate.py +2 -2
- rl_algo_impls/runner/running_utils.py +33 -18
- rl_algo_impls/runner/train.py +11 -10
- rl_algo_impls/shared/actor/__init__.py +2 -0
- rl_algo_impls/shared/actor/actor.py +42 -0
- rl_algo_impls/shared/actor/categorical.py +64 -0
- rl_algo_impls/shared/actor/gaussian.py +61 -0
- rl_algo_impls/shared/actor/gridnet.py +108 -0
- rl_algo_impls/shared/actor/gridnet_decoder.py +80 -0
- rl_algo_impls/shared/actor/make_actor.py +95 -0
- rl_algo_impls/shared/actor/multi_discrete.py +101 -0
- rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py} +33 -143
- rl_algo_impls/shared/callbacks/eval_callback.py +26 -9
- rl_algo_impls/shared/encoder/__init__.py +2 -0
- rl_algo_impls/shared/encoder/cnn.py +72 -0
- rl_algo_impls/shared/encoder/encoder.py +73 -0
- rl_algo_impls/shared/encoder/gridnet_encoder.py +64 -0
- rl_algo_impls/shared/encoder/impala_cnn.py +92 -0
- rl_algo_impls/shared/encoder/microrts_cnn.py +45 -0
- rl_algo_impls/shared/encoder/nature_cnn.py +53 -0
- rl_algo_impls/shared/gae.py +29 -2
- rl_algo_impls/shared/module/feature_extractor.py +0 -215
- rl_algo_impls/shared/module/module.py +6 -3
- rl_algo_impls/shared/policy/critic.py +22 -10
- rl_algo_impls/shared/policy/on_policy.py +57 -34
- rl_algo_impls/shared/policy/policy.py +6 -1
- rl_algo_impls/shared/schedule.py +29 -1
- rl_algo_impls/shared/stats.py +24 -6
- rl_algo_impls/shared/vec_env/__init__.py +1 -0
- rl_algo_impls/shared/vec_env/make_env.py +66 -0
- rl_algo_impls/shared/vec_env/microrts.py +94 -0
- rl_algo_impls/shared/vec_env/microrts_compat.py +49 -0
README.md CHANGED

````diff
@@ -10,7 +10,7 @@ model-index:
   results:
   - metrics:
     - type: mean_reward
-      value: 
+      value: 1943.44 +/- 6.03
       name: mean_reward
     task:
       type: reinforcement-learning
@@ -23,17 +23,17 @@ model-index:
 
 This is a trained model of a **PPO** agent playing **Walker2DBulletEnv-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
 
-All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/
+All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/7lx79bf0.
 
 ## Training Results
 
-This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [
+This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
 
 | algo   | env                  |   seed |   reward_mean |   reward_std |   eval_episodes | best   | wandb_url                                                                    |
 |:-------|:---------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
-| ppo    | Walker2DBulletEnv-v0 |      1 |
-| ppo    | Walker2DBulletEnv-v0 |      2 |
-| ppo    | Walker2DBulletEnv-v0 |      3 |
+| ppo    | Walker2DBulletEnv-v0 |      1 |       1943.44 |      6.02595 |              16 | *      | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/rkpqqxbp) |
+| ppo    | Walker2DBulletEnv-v0 |      2 |       1821.93 |     13.1212  |              16 |        | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/2dxhttk3) |
+| ppo    | Walker2DBulletEnv-v0 |      3 |       2109.58 |    509.27    |              16 |        | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/ormofluw) |
 
 
 ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
 Note: While the model state dictionary and hyperaparameters are saved, the latest
 implementation could be sufficiently different to not be able to reproduce similar
 results. You might need to checkout the commit the agent was trained on:
-[
+[0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c).
 ```
 # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
-python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/
+python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/rkpqqxbp
 ```
 
 Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
 
 ## Training
 If you want the highest chance to reproduce these results, you'll want to checkout the
-commit the agent was trained on: [
+commit the agent was trained on: [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). While
 training is deterministic, different hardware will give different results.
 
 ```
-python train.py --algo ppo --env Walker2DBulletEnv-v0 --seed 
+python train.py --algo ppo --env Walker2DBulletEnv-v0 --seed 1
 ```
 
 Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,7 +83,7 @@ notebook.
 
 
 ## Benchmarking (with Lambda Labs instance)
-This and other models from https://api.wandb.ai/links/sgoodfriend/
+This and other models from https://api.wandb.ai/links/sgoodfriend/7lx79bf0 were generated by running a script on a Lambda
 Labs instance. In a Lambda Labs instance terminal:
 ```
 git clone git@github.com:sgoodfriend/rl-algo-impls.git
@@ -105,6 +105,7 @@ can be used. However, this requires a Google Colab Pro+ subscription and running
 This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
 close and has some additional data:
 ```
+additional_keys_to_log: []
 algo: ppo
 algo_hyperparams:
   batch_size: 128
@@ -134,13 +135,15 @@ policy_hyperparams:
   v_hidden_sizes:
   - 256
   - 256
-seed: 
+seed: 1
 use_deterministic_algorithms: true
 wandb_entity: null
 wandb_group: null
 wandb_project_name: rl-algo-impls-benchmarks
 wandb_tags:
-- 
-- 
+- benchmark_0511de3
+- host_152-67-249-42
+- branch_main
+- v0.0.8
 
 ```
````
    	
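The README's selection rule above ("mean - std") is easy to check against the table: among the three reevaluated best models, the submission keeps the one maximizing mean reward minus one standard deviation. The snippet below is only an illustration of that rule using the table's numbers, not code from the repo:

```python
# Pick the run maximizing mean_reward - reward_std, per the README's rule.
runs = {
    "rkpqqxbp": (1943.44, 6.02595),  # seed 1
    "2dxhttk3": (1821.93, 13.1212),  # seed 2
    "ormofluw": (2109.58, 509.27),   # seed 3
}
best = max(runs, key=lambda r: runs[r][0] - runs[r][1])
print(best)  # rkpqqxbp
```

Seed 1 wins (1943.44 - 6.03 ≈ 1937.4) even though seed 3 has the higher mean, because seed 3's large std (509.27) drags its score down to ≈ 1600.3; that is why the table marks seed 1 as best.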
pyproject.toml CHANGED

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "rl_algo_impls"
-version = "0.0.
+version = "0.0.8"
 description = "Implementations of reinforcement learning algorithms"
 authors = [
     {name = "Scott Goodfriend", email = "[email protected]"},
@@ -35,6 +35,7 @@ dependencies = [
     "dash",
     "kaleido",
     "PyYAML",
+    "scikit-learn",
 ]
 
 [tool.setuptools]
@@ -55,10 +56,30 @@ procgen = [
     "glfw >= 1.12.0, < 1.13",
     "procgen; platform_machine=='x86_64'",
 ]
+microrts-old = [
+    "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
+    "gym-microrts == 0.2.0", # Match ppo-implementation-details
+]
+microrts = [
+    "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
+    "gym-microrts == 0.3.2",
+]
+jupyter = [
+    "jupyter",
+    "notebook"
+]
+all = [
+    "rl-algo-impls[test]",
+    "rl-algo-impls[procgen]",
+    "rl-algo-impls[microrts]",
+]
 
 [project.urls]
 "Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
 
 [build-system]
 requires = ["setuptools==65.5.0", "setuptools-scm"]
-build-backend = "setuptools.build_meta"
+build-backend = "setuptools.build_meta"
+
+[tool.isort]
+profile = "black"
```
    	
replay.meta.json CHANGED

```diff
@@ -1 +1 @@
-{"content_type": "video/mp4", "encoder_version": {…}, "cmdline": […, "-r", "60", "/tmp/
+{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil      56. 31.100 / 56. 31.100\\nlibavcodec     58. 54.100 / 58. 54.100\\nlibavformat    58. 29.100 / 58. 29.100\\nlibavdevice    58.  8.100 / 58.  8.100\\nlibavfilter     7. 57.100 /  7. 57.100\\nlibavresample   4.  0.  0 /  4.  0.  0\\nlibswscale      5.  5.100 /  5.  5.100\\nlibswresample   3.  5.100 /  3.  5.100\\nlibpostproc    55.  5.100 / 55.  5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmpitcja6vi/ppo-Walker2DBulletEnv-v0/replay.mp4"]}, "episode": {"r": 1936.2642822265625, "l": 1000, "t": 28.126744}}
```
    	
replay.mp4 CHANGED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:5a4a32646a8d7dee1c9f24f52e21be8dcd4b1f79a86ba04667c6054d729f162c
+size 1180592
```
    	
rl_algo_impls/a2c/a2c.py CHANGED

```diff
@@ -10,6 +10,7 @@ from typing import Optional, TypeVar
 
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.gae import compute_advantages
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
 from rl_algo_impls.shared.schedule import schedule, update_learning_rate
 from rl_algo_impls.shared.stats import log_scalars
@@ -84,12 +85,12 @@ class A2C(Algorithm):
         obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
         actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
         rewards = np.zeros(epoch_dim, dtype=np.float32)
-        episode_starts = np.zeros(epoch_dim, dtype=np.
+        episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
         values = np.zeros(epoch_dim, dtype=np.float32)
         logprobs = np.zeros(epoch_dim, dtype=np.float32)
 
         next_obs = self.env.reset()
-        next_episode_starts = np.
+        next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
 
         timesteps_elapsed = start_timesteps
         while timesteps_elapsed < start_timesteps + train_timesteps:
@@ -126,23 +127,16 @@ class A2C(Algorithm):
                     clamped_action
                 )
 
-            advantages = 
-            ...
-                rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
-            )
-            last_gae_lam = (
-                delta
-                + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
-            )
-            advantages[t] = last_gae_lam
+            advantages = compute_advantages(
+                rewards,
+                values,
+                episode_starts,
+                next_episode_starts,
+                next_obs,
+                self.policy,
+                self.gamma,
+                self.gae_lambda,
+            )
             returns = advantages + values
 
             b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
```
    	
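The a2c.py change replaces an inline GAE loop with compute_advantages from the new rl_algo_impls/shared/gae.py. Below is a minimal sketch of that computation, assuming (n_steps, n_envs)-shaped rollout arrays; unlike the repo's function, which takes next_obs and the policy so it can bootstrap the final value itself, this sketch takes the bootstrap value directly as next_value:

```python
import numpy as np

def gae_advantages(
    rewards: np.ndarray,              # (n_steps, n_envs)
    values: np.ndarray,               # (n_steps, n_envs) value estimates V(s_t)
    episode_starts: np.ndarray,       # (n_steps, n_envs) bool, True at episode boundaries
    next_episode_starts: np.ndarray,  # (n_envs,) bool for the step after the rollout
    next_value: np.ndarray,           # (n_envs,) bootstrap value for the final step
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    n_steps = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    last_gae_lam = np.zeros(rewards.shape[1], dtype=np.float32)
    # Work backwards: A_t = delta_t + gamma * lambda * A_{t+1}, zeroed across episode ends.
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            next_nonterminal = 1.0 - next_episode_starts.astype(np.float32)
            nv = next_value
        else:
            next_nonterminal = 1.0 - episode_starts[t + 1].astype(np.float32)
            nv = values[t + 1]
        delta = rewards[t] + gamma * nv * next_nonterminal - values[t]
        last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
        advantages[t] = last_gae_lam
    return advantages
```

The kept line `returns = advantages + values` then recovers the lambda-return targets for the value loss.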
rl_algo_impls/a2c/optimize.py CHANGED

```diff
@@ -3,7 +3,7 @@ import optuna
 from copy import deepcopy
 
 from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
-from rl_algo_impls.
+from rl_algo_impls.shared.vec_env import make_eval_env
 from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
 from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
 
```
    	
rl_algo_impls/benchmark_publish.py CHANGED

```diff
@@ -54,8 +54,8 @@ def benchmark_publish() -> None:
         "--virtual-display", action="store_true", help="Use headless virtual display"
     )
     # parser.set_defaults(
-    #     wandb_tags=["
-    #     wandb_report_url="https://api.wandb.ai/links/sgoodfriend/
+    #     wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
+    #     wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
     #     envs=[],
     #     exclude_envs=[],
     # )
```
    	
rl_algo_impls/compare_runs.py CHANGED

```diff
@@ -194,5 +194,6 @@ def compare_runs() -> None:
     df.loc["mean"] = df.mean(numeric_only=True)
     print(df.to_markdown())
 
+
 if __name__ == "__main__":
-    compare_runs()
+    compare_runs()
```
    	
rl_algo_impls/dqn/policy.py CHANGED

```diff
@@ -1,16 +1,16 @@
-import numpy as np
 import os
-import torch
-
 from typing import Optional, Sequence, TypeVar
 
+import numpy as np
+import torch
+
 from rl_algo_impls.dqn.q_net import QNetwork
 from rl_algo_impls.shared.policy.policy import Policy
 from rl_algo_impls.wrappers.vectorable_wrapper import (
     VecEnv,
     VecEnvObs,
-    single_observation_space,
     single_action_space,
+    single_observation_space,
 )
 
 DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
@@ -21,7 +21,7 @@ class DQNPolicy(Policy):
         self,
         env: VecEnv,
         hidden_sizes: Sequence[int] = [],
-        
+        cnn_flatten_dim: int = 512,
         cnn_style: str = "nature",
         cnn_layers_init_orthogonal: Optional[bool] = None,
         impala_channels: Sequence[int] = (16, 32, 32),
@@ -32,16 +32,23 @@ class DQNPolicy(Policy):
             single_observation_space(env),
             single_action_space(env),
             hidden_sizes,
-            
+            cnn_flatten_dim=cnn_flatten_dim,
             cnn_style=cnn_style,
             cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
             impala_channels=impala_channels,
         )
 
     def act(
-        self, 
+        self,
+        obs: VecEnvObs,
+        eps: float = 0,
+        deterministic: bool = True,
+        action_masks: Optional[np.ndarray] = None,
     ) -> np.ndarray:
         assert eps == 0 if deterministic else eps >= 0
+        assert (
+            action_masks is None
+        ), f"action_masks not currently supported in {self.__class__.__name__}"
         if not deterministic and np.random.random() < eps:
             return np.array(
                 [
```
    	
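The widened act() signature in dqn/policy.py implements epsilon-greedy action selection (note the assert: eps must be 0 when deterministic). Here is a hedged sketch of the idea, vectorized per environment, whereas the diff's act() draws a single random number to gate the whole batch; q_values is an illustrative stand-in for the Q-network's output, not the repo's API:

```python
import numpy as np

def epsilon_greedy(q_values: np.ndarray, eps: float, rng: np.random.Generator) -> np.ndarray:
    # q_values: (n_envs, n_actions). With probability eps act uniformly at
    # random, otherwise take the greedy argmax action.
    n_envs, n_actions = q_values.shape
    greedy = q_values.argmax(axis=1)
    random_actions = rng.integers(n_actions, size=n_envs)
    explore = rng.random(n_envs) < eps
    return np.where(explore, random_actions, greedy)
```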
rl_algo_impls/dqn/q_net.py CHANGED

```diff
@@ -1,11 +1,11 @@
+from typing import Optional, Sequence, Type
+
 import gym
 import torch as th
 import torch.nn as nn
-
 from gym.spaces import Discrete
-from typing import Optional, Sequence, Type
 
-from rl_algo_impls.shared.
+from rl_algo_impls.shared.encoder import Encoder
 from rl_algo_impls.shared.module.module import mlp
 
 
@@ -16,17 +16,17 @@ class QNetwork(nn.Module):
         action_space: gym.Space,
         hidden_sizes: Sequence[int] = [],
         activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
-        
+        cnn_flatten_dim: int = 512,
         cnn_style: str = "nature",
         cnn_layers_init_orthogonal: Optional[bool] = None,
         impala_channels: Sequence[int] = (16, 32, 32),
     ) -> None:
         super().__init__()
         assert isinstance(action_space, Discrete)
-        self._feature_extractor = 
+        self._feature_extractor = Encoder(
             observation_space,
             activation,
-            
+            cnn_flatten_dim=cnn_flatten_dim,
             cnn_style=cnn_style,
             cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
             impala_channels=impala_channels,
```
    	
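q_net.py now builds its feature extractor from the new shared Encoder package and feeds the flattened features (cnn_flatten_dim wide for CNN encoders) into an mlp head that outputs one Q-value per discrete action. A rough sketch of the head's shape only; QHead and feature_dim are illustrative names and assumptions, not the repo's verified API:

```python
import torch
import torch.nn as nn

class QHead(nn.Module):
    """Illustrative MLP head: flat features in, one Q-value per discrete action out."""

    def __init__(self, feature_dim: int, n_actions: int, hidden_sizes=(64,)):
        super().__init__()
        layers, prev = [], feature_dim
        for h in hidden_sizes:
            layers += [nn.Linear(prev, h), nn.ReLU()]
            prev = h
        layers.append(nn.Linear(prev, n_actions))  # Q(s, a) for each action
        self.net = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.net(features)  # (batch, n_actions)
```

Under this reading, cnn_flatten_dim plays the role of feature_dim (512 by default here, 256 in the tuned CarRacing config below).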
rl_algo_impls/huggingface_publish.py CHANGED

```diff
@@ -19,7 +19,7 @@ from pyvirtualdisplay.display import Display
 from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
 from rl_algo_impls.runner.config import EnvHyperparams
 from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
-from rl_algo_impls.
+from rl_algo_impls.shared.vec_env import make_eval_env
 from rl_algo_impls.shared.callbacks.eval_callback import evaluate
 from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
 
```
    	
rl_algo_impls/hyperparams/a2c.yml
CHANGED

@@ -97,31 +97,35 @@ Walker2DBulletEnv-v0:
 HopperBulletEnv-v0:
   <<: *pybullet-defaults
 
+# Tuned
 CarRacing-v0:
   n_timesteps: !!float 4e6
   env_hyperparams:
-    n_envs: …
+    n_envs: 16
     frame_stack: 4
     normalize: true
     normalize_kwargs:
       norm_obs: false
       norm_reward: true
   policy_hyperparams:
-    use_sde: …
-    log_std_init: -…
-    init_layers_orthogonal: …
-    activation_fn: …
+    use_sde: false
+    log_std_init: -1.3502584927786276
+    init_layers_orthogonal: true
+    activation_fn: tanh
     share_features_extractor: false
-    …
+    cnn_flatten_dim: 256
     hidden_sizes: [256]
   algo_hyperparams:
-    n_steps: …
-    learning_rate: …
-    …
-    …
-    …
-    …
-    vf_coef: 0.…
+    n_steps: 16
+    learning_rate: 0.000025630993245026736
+    learning_rate_decay: linear
+    gamma: 0.99957617037542
+    gae_lambda: 0.949455676599436
+    ent_coef: !!float 1.707983205298309e-7
+    vf_coef: 0.10428178193833336
+    max_grad_norm: 0.5406643389792273
+    normalize_advantage: true
+    use_rms_prop: false
 
 _atari: &atari-defaults
   n_timesteps: !!float 1e7
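The tuned CarRacing-v0 block above adds `learning_rate_decay: linear`. As a minimal sketch of how such a decay maps onto a schedule, assuming (as the `linear_schedule`/`constant_schedule` calls visible in the `ppo.py` diff later in this commit suggest) that a schedule is a function of training progress in [0, 1]; helper names here are illustrative, not a confirmed API:

```python
from typing import Callable

Schedule = Callable[[float], float]  # maps training progress in [0, 1] to a value

def linear_schedule(start_val: float, end_val: float) -> Schedule:
    # Linearly interpolates from start_val at progress 0 to end_val at progress 1.
    return lambda progress: start_val + (end_val - start_val) * progress

def constant_schedule(val: float) -> Schedule:
    return lambda progress: val

# learning_rate: 0.000025630993245026736 with learning_rate_decay: linear
lr_schedule = linear_schedule(0.000025630993245026736, 0.0)
assert lr_schedule(0.0) == 0.000025630993245026736
assert lr_schedule(1.0) == 0.0
```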
    	
rl_algo_impls/hyperparams/dqn.yml
CHANGED

@@ -108,7 +108,7 @@ _impala-atari: &impala-atari-defaults
   <<: *atari-defaults
   policy_hyperparams:
     cnn_style: impala
-    …
+    cnn_flatten_dim: 256
     init_layers_orthogonal: true
     cnn_layers_init_orthogonal: false
    	
rl_algo_impls/hyperparams/ppo.yml
CHANGED

@@ -112,7 +112,7 @@ CarRacing-v0: &carracing-defaults
     init_layers_orthogonal: false
     activation_fn: relu
     share_features_extractor: false
-    …
+    cnn_flatten_dim: 256
     hidden_sizes: [256]
   algo_hyperparams:
     n_steps: 512
@@ -152,7 +152,7 @@ _atari: &atari-defaults
     vec_env_class: async
   policy_hyperparams: &atari-policy-defaults
     activation_fn: relu
-  algo_hyperparams:
+  algo_hyperparams: &atari-algo-defaults
     n_steps: 128
     batch_size: 256
     n_epochs: 4
@@ -192,7 +192,7 @@ _impala-atari: &impala-atari-defaults
   policy_hyperparams:
     <<: *atari-policy-defaults
     cnn_style: impala
-    …
+    cnn_flatten_dim: 256
     init_layers_orthogonal: true
     cnn_layers_init_orthogonal: false
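Each of the hunks above replaces a key truncated in this view with `cnn_flatten_dim`, which, per the new `rl_algo_impls/shared/encoder/` modules added in this commit, sizes the linear projection applied after the CNN backbone's output is flattened. A minimal sketch of that idea (illustrative shapes and class name, assuming an 84x84 input; not the repo's exact encoder code):

```python
import torch
import torch.nn as nn

class CnnEncoder(nn.Module):
    def __init__(self, in_channels: int, cnn_flatten_dim: int = 256) -> None:
        super().__init__()
        # Nature-CNN-style backbone (illustrative); any conv stack works here.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened size with a dummy forward pass (84x84 assumed here).
        with torch.no_grad():
            n_flat = self.cnn(torch.zeros(1, in_channels, 84, 84)).shape[1]
        # cnn_flatten_dim controls the width of this projection.
        self.fc = nn.Sequential(nn.Linear(n_flat, cnn_flatten_dim), nn.ReLU())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.fc(self.cnn(obs))
```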
@@ -212,6 +212,126 @@ impala-QbertNoFrameskip-v4:
   <<: *impala-atari-defaults
   env_id: QbertNoFrameskip-v4
 
+_microrts: &microrts-defaults
+  <<: *atari-defaults
+  n_timesteps: !!float 2e6
+  env_hyperparams: &microrts-env-defaults
+    n_envs: 8
+    vec_env_class: sync
+    mask_actions: true
+  policy_hyperparams: &microrts-policy-defaults
+    <<: *atari-policy-defaults
+    cnn_style: microrts
+    cnn_flatten_dim: 128
+  algo_hyperparams: &microrts-algo-defaults
+    <<: *atari-algo-defaults
+    clip_range_decay: none
+    clip_range_vf: 0.1
+    ppo2_vf_coef_halving: true
+  eval_params:
+    deterministic: false # Good idea because MultiCategorical mode isn't great
+
+_no-mask-microrts: &no-mask-microrts-defaults
+  <<: *microrts-defaults
+  env_hyperparams:
+    <<: *microrts-env-defaults
+    mask_actions: false
+
+MicrortsMining-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsMining-v1
+
+MicrortsAttackShapedReward-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsAttackShapedReward-v1
+
+MicrortsRandomEnemyShapedReward3-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsRandomEnemyShapedReward3-v1
+
+_microrts_ai: &microrts-ai-defaults
+  <<: *microrts-defaults
+  n_timesteps: !!float 100e6
+  additional_keys_to_log: ["microrts_stats"]
+  env_hyperparams: &microrts-ai-env-defaults
+    n_envs: 24
+    env_type: microrts
+    make_kwargs:
+      num_selfplay_envs: 0
+      max_steps: 2000
+      render_theme: 2
+      map_path: maps/16x16/basesWorkers16x16.xml
+      reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
+  policy_hyperparams: &microrts-ai-policy-defaults
+    <<: *microrts-policy-defaults
+    cnn_flatten_dim: 256
+    actor_head_style: gridnet
+  algo_hyperparams: &microrts-ai-algo-defaults
+    <<: *microrts-algo-defaults
+    learning_rate: !!float 2.5e-4
+    learning_rate_decay: linear
+    n_steps: 512
+    batch_size: 3072
+    n_epochs: 4
+    ent_coef: 0.01
+    vf_coef: 0.5
+    max_grad_norm: 0.5
+    clip_range: 0.1
+    clip_range_vf: 0.1
+
+MicrortsAttackPassiveEnemySparseReward-v3:
+  <<: *microrts-ai-defaults
+  n_timesteps: !!float 2e6
+  env_id: MicrortsAttackPassiveEnemySparseReward-v3 # Workaround to keep model name simple
+  env_hyperparams:
+    <<: *microrts-ai-env-defaults
+    bots:
+      passiveAI: 24
+
+MicrortsDefeatRandomEnemySparseReward-v3: &microrts-random-ai-defaults
+  <<: *microrts-ai-defaults
+  n_timesteps: !!float 2e6
+  env_id: MicrortsDefeatRandomEnemySparseReward-v3 # Workaround to keep model name simple
+  env_hyperparams:
+    <<: *microrts-ai-env-defaults
+    bots:
+      randomBiasedAI: 24
+
+enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
+  <<: *microrts-random-ai-defaults
+  policy_hyperparams:
+    <<: *microrts-ai-policy-defaults
+    cnn_style: gridnet_encoder
+    actor_head_style: gridnet_decoder
+    v_hidden_sizes: [128]
+
+MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
+  <<: *microrts-ai-defaults
+  env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
+  n_timesteps: !!float 300e6
+  env_hyperparams: &microrts-coacai-env-defaults
+    <<: *microrts-ai-env-defaults
+    bots:
+      coacAI: 24
+
+MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
+  <<: *microrts-coacai-defaults
+  env_hyperparams:
+    <<: *microrts-coacai-env-defaults
+    bots:
+      coacAI: 18
+      randomBiasedAI: 2
+      lightRushAI: 2
+      workerRushAI: 2
+
+enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
+  <<: *microrts-diverse-defaults
+  policy_hyperparams:
+    <<: *microrts-ai-policy-defaults
+    cnn_style: gridnet_encoder
+    actor_head_style: gridnet_decoder
+    v_hidden_sizes: [128]
+
 HalfCheetahBulletEnv-v0: &pybullet-defaults
   n_timesteps: !!float 2e6
   env_hyperparams: &pybullet-env-defaults
@@ -282,7 +402,7 @@ _procgen: &procgen-defaults
   policy_hyperparams: &procgen-policy-defaults
     activation_fn: relu
     cnn_style: impala
-    …
+    cnn_flatten_dim: 256
     init_layers_orthogonal: true
     cnn_layers_init_orthogonal: false
   algo_hyperparams: &procgen-algo-defaults
@@ -368,7 +488,7 @@ procgen-starpilot-hard-2xIMPALA-fat:
   policy_hyperparams:
     <<: *procgen-policy-defaults
     impala_channels: [32, 64, 64]
-    …
+    cnn_flatten_dim: 512
   algo_hyperparams:
     <<: *procgen-hard-algo-defaults
     learning_rate: !!float 2.5e-4
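The new MicroRTS entries turn on `mask_actions: true` and evaluate with `deterministic: false`, since the mode of a MultiCategorical (per-grid-cell) action distribution tends to be a poor joint plan. A minimal sketch of invalid-action masking for a single categorical head (the standard technique; not the repo's exact `gridnet` code): logits of illegal actions are pushed to a large negative value before the distribution is built.

```python
import torch
from torch.distributions import Categorical

def masked_categorical(logits: torch.Tensor, mask: torch.Tensor) -> Categorical:
    # mask: True where the action is legal; illegal actions get ~zero probability.
    masked_logits = torch.where(mask, logits, torch.full_like(logits, -1e8))
    return Categorical(logits=masked_logits)

logits = torch.randn(4)
mask = torch.tensor([True, False, True, False])
dist = masked_categorical(logits, mask)
assert dist.probs[~mask].sum() < 1e-6  # illegal actions effectively excluded
```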
    	
rl_algo_impls/hyperparams/vpg.yml
CHANGED

@@ -110,7 +110,7 @@ CarRacing-v0:
     log_std_init: -2
     init_layers_orthogonal: false
     activation_fn: relu
-    …
+    cnn_flatten_dim: 256
     hidden_sizes: [256]
   algo_hyperparams:
     n_steps: 1000
@@ -175,9 +175,9 @@ FrozenLake-v1:
     save_best: true
 
 _atari: &atari-defaults
-  n_timesteps: !!float …
+  n_timesteps: !!float 10e6
   env_hyperparams:
-    n_envs: …
+    n_envs: 2
     frame_stack: 4
     no_reward_timeout_steps: 1000
     no_reward_fire_steps: 500
@@ -185,7 +185,7 @@ _atari: &atari-defaults
   policy_hyperparams:
     activation_fn: relu
   algo_hyperparams:
-    n_steps: …
+    n_steps: 3072
     pi_lr: !!float 5e-5
     gamma: 0.99
     gae_lambda: 0.95
    	
rl_algo_impls/optimize.py
CHANGED

@@ -17,7 +17,7 @@ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
 
 from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
 from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
-from rl_algo_impls.…
+from rl_algo_impls.shared.vec_env import make_env, make_eval_env
 from rl_algo_impls.runner.running_utils import (
     base_parser,
     load_hyperparams,
@@ -194,7 +194,7 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
     env = make_env(
         config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
     )
-    device = get_device(config…
+    device = get_device(config, env)
     policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
     algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
@@ -274,7 +274,7 @@ def stepwise_optimize(
             project=study_args.wandb_project_name,
             entity=study_args.wandb_entity,
             config=asdict(hyperparams),
-            name=f"{…
+            name=f"{str(trial.number)}-S{base_config.seed()}",
             tags=study_args.wandb_tags,
             group=study_args.wandb_group,
             save_code=True,
@@ -298,7 +298,7 @@ def stepwise_optimize(
                 normalize_load_path=config.model_dir_path() if i > 0 else None,
                 tb_writer=tb_writer,
             )
-            device = get_device(config…
+            device = get_device(config, env)
             policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
             if i > 0:
                 policy.load(config.model_dir_path())
@@ -433,6 +433,7 @@ def optimize() -> None:
 
     fig1 = plot_optimization_history(study)
     fig1.write_image("opt_history.png")
+
     fig2 = plot_param_importances(study)
     fig2.write_image("param_importances.png")
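Both call sites now pass the environment into `get_device(config, env)`; the diff shows only the call sites, not the helper itself. A sketch of the new contract, with the body an assumption based on common practice rather than the repo's actual code:

```python
import torch

def get_device(config, env) -> torch.device:
    # Assumption: config carries a device name, "auto" meaning pick the best available.
    device_name = getattr(config, "device", "auto")
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # env is available here so selection could account for the environment,
    # e.g. tiny observation spaces where CPU beats GPU transfer overhead.
    return torch.device(device_name)
```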
    	
rl_algo_impls/ppo/ppo.py
CHANGED

@@ -1,59 +1,26 @@
 import numpy as np
 import torch
 import torch.nn as nn
-
-from dataclasses import asdict, dataclass, field
-from time import perf_counter
 from torch.optim import Adam
 from torch.utils.tensorboard.writer import SummaryWriter
-from typing import List, Optional, NamedTuple, TypeVar
 
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.callback import Callback
-from rl_algo_impls.shared.gae import …
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
-from rl_algo_impls.shared.schedule import …
-…
+… (added import lines not rendered in this view)
 )
-from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
-from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
-
-
-@dataclass
-class PPOTrajectory(Trajectory):
-    logp_a: List[float] = field(default_factory=list)
-
-    def add(
-        self,
-        obs: np.ndarray,
-        act: np.ndarray,
-        next_obs: np.ndarray,
-        rew: float,
-        terminated: bool,
-        v: float,
-        logp_a: float,
-    ):
-        super().add(obs, act, next_obs, rew, terminated, v)
-        self.logp_a.append(logp_a)
-
-
-class PPOTrajectoryAccumulator(TrajectoryAccumulator):
-    def __init__(self, num_envs: int) -> None:
-        super().__init__(num_envs, PPOTrajectory)
-
-    def step(
-        self,
-        obs: VecEnvObs,
-        action: np.ndarray,
-        next_obs: VecEnvObs,
-        reward: np.ndarray,
-        done: np.ndarray,
-        val: np.ndarray,
-        logp_a: np.ndarray,
-    ) -> None:
-        super().step(obs, action, next_obs, reward, done, val, logp_a)
 
 
 class TrainStepStats(NamedTuple):
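The removed `PPOTrajectory`/`PPOTrajectoryAccumulator` machinery stored per-episode lists that `compute_rtg_and_advantage` (behind the truncated `shared.gae` import) consumed; the replacement, per this commit's expanded `shared/gae.py`, works over rollout arrays instead. For reference, a minimal sketch of generalized advantage estimation over fixed-size rollout arrays, using the standard recurrence; names are illustrative, not the new module's exact API:

```python
import numpy as np

def compute_advantages(
    rewards: np.ndarray,         # (n_steps, n_envs)
    values: np.ndarray,          # (n_steps, n_envs) value estimates per step
    episode_starts: np.ndarray,  # (n_steps, n_envs) 1.0 where an episode begins
    next_values: np.ndarray,     # (n_envs,) value estimates after the last step
    next_episode_starts: np.ndarray,  # (n_envs,) 1.0 if the env just reset
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> np.ndarray:
    n_steps = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    last_gae = np.zeros(rewards.shape[1])
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            nonterminal = 1.0 - next_episode_starts
            next_value = next_values
        else:
            nonterminal = 1.0 - episode_starts[t + 1]
            next_value = values[t + 1]
        # TD error, then the exponentially weighted GAE recurrence.
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * nonterminal * last_gae
        advantages[t] = last_gae
    return advantages  # return targets for the value loss are advantages + values
```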
         @@ -132,39 +99,31 @@ class PPO(Algorithm): 
     | 
|
| 132 | 
         
             
                    vf_coef: float = 0.5,
         
     | 
| 133 | 
         
             
                    ppo2_vf_coef_halving: bool = False,
         
     | 
| 134 | 
         
             
                    max_grad_norm: float = 0.5,
         
     | 
| 135 | 
         
            -
                    update_rtg_between_epochs: bool = False,
         
     | 
| 136 | 
         
             
                    sde_sample_freq: int = -1,
         
     | 
| 
         | 
|
| 
         | 
|
| 137 | 
         
             
                ) -> None:
         
     | 
| 138 | 
         
             
                    super().__init__(policy, env, device, tb_writer)
         
     | 
| 139 | 
         
             
                    self.policy = policy
         
     | 
| 
         | 
|
| 140 | 
         | 
| 141 | 
         
             
                    self.gamma = gamma
         
     | 
| 142 | 
         
             
                    self.gae_lambda = gae_lambda
         
     | 
| 143 | 
         
             
                    self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
         
     | 
| 144 | 
         
            -
                    self.lr_schedule = (
         
     | 
| 145 | 
         
            -
                        linear_schedule(learning_rate, 0)
         
     | 
| 146 | 
         
            -
                        if learning_rate_decay == "linear"
         
     | 
| 147 | 
         
            -
                        else constant_schedule(learning_rate)
         
     | 
| 148 | 
         
            -
                    )
         
     | 
| 149 | 
         
             
                    self.max_grad_norm = max_grad_norm
         
     | 
| 150 | 
         
            -
                    self.clip_range_schedule = (
         
     | 
| 151 | 
         
            -
                        linear_schedule(clip_range, 0)
         
     | 
| 152 | 
         
            -
                        if clip_range_decay == "linear"
         
     | 
| 153 | 
         
            -
                        else constant_schedule(clip_range)
         
     | 
| 154 | 
         
            -
                    )
         
     | 
| 155 | 
         
             
                    self.clip_range_vf_schedule = None
         
     | 
| 156 | 
         
             
                    if clip_range_vf:
         
     | 
| 157 | 
         
            -
                        self.clip_range_vf_schedule = (
         
     | 
| 158 | 
         
            -
             
     | 
| 159 | 
         
            -
             
     | 
| 160 | 
         
            -
             
     | 
| 161 | 
         
            -
             
     | 
| 
         | 
|
| 162 | 
         
             
                    self.normalize_advantage = normalize_advantage
         
     | 
| 163 | 
         
            -
             
     | 
| 164 | 
         
            -
             
     | 
| 165 | 
         
            -
                        if ent_coef_decay == "linear"
         
     | 
| 166 | 
         
            -
                        else constant_schedule(ent_coef)
         
     | 
| 167 | 
         
            -
                    )
         
     | 
| 168 | 
         
             
                    self.vf_coef = vf_coef
         
     | 
| 169 | 
         
             
                    self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
         
     | 
| 170 | 
         | 
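This hunk removes the inline `linear_schedule(...) if ... else constant_schedule(...)` blocks for the learning rate, clip ranges, and entropy coefficient (schedule construction moves out of the constructor; `shared/schedule.py` grows in this commit). However the schedule is built, the scheduled learning rate still has to be written into the Adam optimizer on each update, which is all that the `update_learning_rate` call in the removed `train()` below amounts to (sketch):

```python
import torch

def update_learning_rate(optimizer: torch.optim.Optimizer, learning_rate: float) -> None:
    # Overwrite the lr of every param group with the value from the schedule.
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
```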
@@ -173,181 +132,243 @@ class PPO(Algorithm):
         self.n_epochs = n_epochs
         self.sde_sample_freq = sde_sample_freq
 
-        self.…
 
     def learn(
         self: PPOSelf,
-        …
         callback: Optional[Callback] = None,
+… (added lines not rendered in this view)
     ) -> PPOSelf:
-        … (removed learn body truncated in this view)
         )
-        if …
-        …
-        return self
-
-    def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
-        self.policy.eval()
-        accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
-        self.policy.reset_noise()
-        for i in range(self.n_steps):
-            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
-                self.policy.reset_noise()
-            action, value, logp_a, clamped_action = self.policy.step(obs)
-            next_obs, reward, done, _ = self.env.step(clamped_action)
-            accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
-            obs = next_obs
-        return accumulator
-
-    def train(
-        self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
-    ) -> TrainStats:
-        self.policy.train()
-        learning_rate = self.lr_schedule(progress)
-        update_learning_rate(self.optimizer, learning_rate)
-        self.tb_writer.add_scalar(
-            "charts/learning_rate",
-            self.optimizer.param_groups[0]["lr"],
-            timesteps_elapsed,
         )
 
-        …
-        …
-            v_clip = self.clip_range_vf_schedule(progress)
-            self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
-        else:
-            v_clip = None
-        ent_coef = self.ent_coef_schedule(progress)
-        self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
-
-        obs = torch.as_tensor(
-            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
-        )
-        act = torch.as_tensor(
-            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
-        )
-        rtg, adv = compute_rtg_and_advantage(
-            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
-        )
-        orig_v = torch.as_tensor(
-            np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
-        )
-        orig_logp_a = torch.as_tensor(
-            np.concatenate([np.array(t.logp_a) for t in trajectories]),
-            device=self.device,
-        )
 
-        …
         else:
-            …
             )
-        …
-            mb_idxs = idxs[i : i + self.batch_size]
-            mb_adv = adv[mb_idxs]
-            if self.normalize_advantage:
-                mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
-            self.policy.reset_noise(self.batch_size)
-            step_stats.append(
-                self._train_step(
-                    pi_clip,
-                    v_clip,
-                    ent_coef,
-                    obs[mb_idxs],
-                    act[mb_idxs],
-                    rtg[mb_idxs],
-                    mb_adv,
-                    orig_v[mb_idxs],
-                    orig_logp_a[mb_idxs],
-                )
             )
 
-… (remaining removed lines truncated in this view)
+… (the replacement rollout-based implementation is not rendered in this view)
         )
            -
                         
     | 
| 340 | 
         
            -
                             
     | 
| 341 | 
         
            -
                            if v_clip
         
     | 
| 342 | 
         
            -
                            else 0
         
     | 
| 343 | 
         
             
                        )
         
     | 
| 344 | 
         | 
| 345 | 
         
            -
             
     | 
| 346 | 
         
            -
                         
     | 
| 347 | 
         
            -
                         
     | 
| 348 | 
         
            -
             
     | 
| 349 | 
         
            -
             
     | 
| 350 | 
         
            -
             
     | 
| 351 | 
         
            -
                         
     | 
| 352 | 
         
            -
             
     | 
| 353 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | + import logging
| 2 | + from dataclasses import asdict, dataclass
| 3 | + from time import perf_counter
| 4 | + from typing import List, NamedTuple, Optional, TypeVar
| 5 | +
| 6 |   import numpy as np
| 7 |   import torch
| 8 |   import torch.nn as nn
| 9 |   from torch.optim import Adam
| 10 |  from torch.utils.tensorboard.writer import SummaryWriter
| 11 |
| 12 |  from rl_algo_impls.shared.algorithm import Algorithm
| 13 |  from rl_algo_impls.shared.callbacks.callback import Callback
| 14 | + from rl_algo_impls.shared.gae import compute_advantages
| 15 |  from rl_algo_impls.shared.policy.on_policy import ActorCritic
| 16 | + from rl_algo_impls.shared.schedule import schedule, update_learning_rate
| 17 | + from rl_algo_impls.shared.stats import log_scalars
| 18 | + from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
| 19 | + from rl_algo_impls.wrappers.vectorable_wrapper import (
| 20 | +     VecEnv,
| 21 | +     single_action_space,
| 22 | +     single_observation_space,
| 23 |  )
| 24 |
| 25 |
| 26 |  class TrainStepStats(NamedTuple):
| 99 |          vf_coef: float = 0.5,
| 100 |         ppo2_vf_coef_halving: bool = False,
| 101 |         max_grad_norm: float = 0.5,
| 102 |         sde_sample_freq: int = -1,
| 103 | +       update_advantage_between_epochs: bool = True,
| 104 | +       update_returns_between_epochs: bool = False,
| 105 |     ) -> None:
| 106 |         super().__init__(policy, env, device, tb_writer)
| 107 |         self.policy = policy
| 108 | +       self.action_masker = find_action_masker(env)
| 109 |
| 110 |         self.gamma = gamma
| 111 |         self.gae_lambda = gae_lambda
| 112 |         self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
| 113 | +       self.lr_schedule = schedule(learning_rate_decay, learning_rate)
| 114 |         self.max_grad_norm = max_grad_norm
| 115 | +       self.clip_range_schedule = schedule(clip_range_decay, clip_range)
| 116 |         self.clip_range_vf_schedule = None
| 117 |         if clip_range_vf:
| 118 | +           self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
| 119 | +
| 120 | +       if normalize_advantage:
| 121 | +           assert (
| 122 | +               env.num_envs * n_steps > 1 and batch_size > 1
| 123 | +           ), f"Each minibatch must be larger than 1 to support normalization"
| 124 |         self.normalize_advantage = normalize_advantage
| 125 | +
| 126 | +       self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
| 127 |         self.vf_coef = vf_coef
| 128 |         self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
| 129 |
| 132 |         self.n_epochs = n_epochs
| 133 |         self.sde_sample_freq = sde_sample_freq
| 134 |
| 135 | +       self.update_advantage_between_epochs = update_advantage_between_epochs
| 136 | +       self.update_returns_between_epochs = update_returns_between_epochs
| 137 |
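The `schedule(decay, start_value)` calls above turn a decay name plus an initial value into a function of training progress. A minimal sketch of what such a helper can look like, assuming only `"none"` and `"linear"` decay modes (the repo's `shared/schedule.py` may support more; the mode names here are illustrative):

```python
from typing import Callable

from torch.optim import Optimizer

Schedule = Callable[[float], float]


def schedule(decay: str, start_value: float) -> Schedule:
    # progress runs from 0.0 (start of training) to 1.0 (end of training)
    if decay == "none":
        return lambda progress: start_value
    if decay == "linear":
        return lambda progress: start_value * (1 - progress)
    raise ValueError(f"Unsupported decay: {decay}")


def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
    # Overwrite the lr of every param group in place
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
```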
| 138 |     def learn(
| 139 |         self: PPOSelf,
| 140 | +       train_timesteps: int,
| 141 |         callback: Optional[Callback] = None,
| 142 | +       total_timesteps: Optional[int] = None,
| 143 | +       start_timesteps: int = 0,
| 144 |     ) -> PPOSelf:
| 145 | +       if total_timesteps is None:
| 146 | +           total_timesteps = train_timesteps
| 147 | +       assert start_timesteps + train_timesteps <= total_timesteps
| 148 | +
| 149 | +       epoch_dim = (self.n_steps, self.env.num_envs)
| 150 | +       step_dim = (self.env.num_envs,)
| 151 | +       obs_space = single_observation_space(self.env)
| 152 | +       act_space = single_action_space(self.env)
| 153 | +       act_shape = self.policy.action_shape
| 154 | +
| 155 | +       next_obs = self.env.reset()
| 156 | +       next_action_masks = (
| 157 | +           self.action_masker.action_masks() if self.action_masker else None
| 158 | +       )
| 159 | +       next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
| 160 | +
| 161 | +       obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)  # type: ignore
| 162 | +       actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype)  # type: ignore
| 163 | +       rewards = np.zeros(epoch_dim, dtype=np.float32)
| 164 | +       episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
| 165 | +       values = np.zeros(epoch_dim, dtype=np.float32)
| 166 | +       logprobs = np.zeros(epoch_dim, dtype=np.float32)
| 167 | +       action_masks = (
| 168 | +           np.zeros(
| 169 | +               (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
| 170 |             )
| 171 | +           if next_action_masks is not None
| 172 | +           else None
| 173 |         )
| 174 |
| 175 | +       timesteps_elapsed = start_timesteps
| 176 | +       while timesteps_elapsed < start_timesteps + train_timesteps:
| 177 | +           start_time = perf_counter()
| 178 |
| 179 | +           progress = timesteps_elapsed / total_timesteps
| 180 | +           ent_coef = self.ent_coef_schedule(progress)
| 181 | +           learning_rate = self.lr_schedule(progress)
| 182 | +           update_learning_rate(self.optimizer, learning_rate)
| 183 | +           pi_clip = self.clip_range_schedule(progress)
| 184 | +           chart_scalars = {
| 185 | +               "learning_rate": self.optimizer.param_groups[0]["lr"],
| 186 | +               "ent_coef": ent_coef,
| 187 | +               "pi_clip": pi_clip,
| 188 | +           }
| 189 | +           if self.clip_range_vf_schedule:
| 190 | +               v_clip = self.clip_range_vf_schedule(progress)
| 191 | +               chart_scalars["v_clip"] = v_clip
| 192 |             else:
| 193 | +               v_clip = None
| 194 | +           log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
| 195 | +
| 196 | +           self.policy.eval()
| 197 | +           self.policy.reset_noise()
| 198 | +           for s in range(self.n_steps):
| 199 | +               timesteps_elapsed += self.env.num_envs
| 200 | +               if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
| 201 | +                   self.policy.reset_noise()
| 202 | +
| 203 | +               obs[s] = next_obs
| 204 | +               episode_starts[s] = next_episode_starts
| 205 | +               if action_masks is not None:
| 206 | +                   action_masks[s] = next_action_masks
| 207 | +
| 208 | +               (
| 209 | +                   actions[s],
| 210 | +                   values[s],
| 211 | +                   logprobs[s],
| 212 | +                   clamped_action,
| 213 | +               ) = self.policy.step(next_obs, action_masks=next_action_masks)
| 214 | +               next_obs, rewards[s], next_episode_starts, _ = self.env.step(
| 215 | +                   clamped_action
| 216 |                 )
| 217 | +               next_action_masks = (
| 218 | +                   self.action_masker.action_masks() if self.action_masker else None
| 219 |                 )
| 220 |
| 221 | +           self.policy.train()
| 222 | +
| 223 | +           b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)  # type: ignore
| 224 | +           b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to(  # type: ignore
| 225 | +               self.device
| 226 | +           )
| 227 | +           b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
| 228 | +           b_action_masks = (
| 229 | +               torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to(  # type: ignore
| 230 | +                   self.device
| 231 | +               )
| 232 | +               if action_masks is not None
| 233 | +               else None
| 234 | +           )
| 235 | +
| 236 | +           y_pred = values.reshape(-1)
| 237 | +           b_values = torch.tensor(y_pred).to(self.device)
| 238 | +
| 239 | +           step_stats = []
| 240 | +           # Define variables that will definitely be set through the first epoch
| 241 | +           advantages: np.ndarray = None  # type: ignore
| 242 | +           b_advantages: torch.Tensor = None  # type: ignore
| 243 | +           y_true: np.ndarray = None  # type: ignore
| 244 | +           b_returns: torch.Tensor = None  # type: ignore
| 245 | +           for e in range(self.n_epochs):
| 246 | +               if e == 0 or self.update_advantage_between_epochs:
| 247 | +                   advantages = compute_advantages(
| 248 | +                       rewards,
| 249 | +                       values,
| 250 | +                       episode_starts,
| 251 | +                       next_episode_starts,
| 252 | +                       next_obs,
| 253 | +                       self.policy,
| 254 | +                       self.gamma,
| 255 | +                       self.gae_lambda,
| 256 | +                   )
| 257 | +                   b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
| 258 | +               if e == 0 or self.update_returns_between_epochs:
| 259 | +                   returns = advantages + values
| 260 | +                   y_true = returns.reshape(-1)
| 261 | +                   b_returns = torch.tensor(y_true).to(self.device)
| 262 | +
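`compute_advantages` (imported from `shared/gae.py`) implements generalized advantage estimation; recomputing it between epochs (`update_advantage_between_epochs`) refreshes the baseline with the just-updated value head. A minimal sketch of the GAE recurrence such a helper computes, with the bootstrap values for `next_obs` already evaluated as `next_values` (the signature is simplified relative to the actual helper, which takes the policy and raw observations):

```python
import numpy as np


def gae_advantages(
    rewards: np.ndarray,              # (n_steps, n_envs)
    values: np.ndarray,               # (n_steps, n_envs)
    episode_starts: np.ndarray,       # (n_steps, n_envs) bool
    next_episode_starts: np.ndarray,  # (n_envs,) bool
    next_values: np.ndarray,          # (n_envs,) value estimates for next_obs
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    advantages = np.zeros_like(rewards)
    last_gae = np.zeros(rewards.shape[1], dtype=np.float32)
    for t in reversed(range(rewards.shape[0])):
        if t == rewards.shape[0] - 1:
            non_terminal = 1.0 - next_episode_starts
            v_next = next_values
        else:
            non_terminal = 1.0 - episode_starts[t + 1]
            v_next = values[t + 1]
        # TD error, with bootstrapping zeroed across episode boundaries
        delta = rewards[t] + gamma * v_next * non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * non_terminal * last_gae
        advantages[t] = last_gae
    return advantages
```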
| 263 | +               b_idxs = torch.randperm(len(b_obs))
| 264 | +               # Only record last epoch's stats
| 265 | +               step_stats.clear()
| 266 | +               for i in range(0, len(b_obs), self.batch_size):
| 267 | +                   self.policy.reset_noise(self.batch_size)
| 268 | +
| 269 | +                   mb_idxs = b_idxs[i : i + self.batch_size]
| 270 | +
| 271 | +                   mb_obs = b_obs[mb_idxs]
| 272 | +                   mb_actions = b_actions[mb_idxs]
| 273 | +                   mb_values = b_values[mb_idxs]
| 274 | +                   mb_logprobs = b_logprobs[mb_idxs]
| 275 | +                   mb_action_masks = (
| 276 | +                       b_action_masks[mb_idxs] if b_action_masks is not None else None
| 277 | +                   )
| 278 |
| 279 | +                   mb_adv = b_advantages[mb_idxs]
| 280 | +                   if self.normalize_advantage:
| 281 | +                       mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
| 282 | +                   mb_returns = b_returns[mb_idxs]
| 283 |
| 284 | +                   new_logprobs, entropy, new_values = self.policy(
| 285 | +                       mb_obs, mb_actions, action_masks=mb_action_masks
| 286 | +                   )
| 287 | +
| 288 | +                   logratio = new_logprobs - mb_logprobs
| 289 | +                   ratio = torch.exp(logratio)
| 290 | +                   clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
| 291 | +                   pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()
| 292 | +
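Lines 288-291 are the PPO clipped surrogate objective: maximize E[min(r·A, clip(r, 1-ε, 1+ε)·A)], where r is the probability ratio between the new and rollout-time policies. Written as a loss to minimize, the min of the two terms becomes a max of their negations, which is exactly what line 291 computes. A standalone restatement of the same computation:

```python
import torch


def clipped_surrogate_loss(
    new_logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_range: float,
) -> torch.Tensor:
    ratio = torch.exp(new_logprobs - old_logprobs)
    clipped_ratio = torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    # maximizing min(r*A, clip(r)*A) == minimizing max(-r*A, -clip(r)*A)
    return torch.max(-ratio * advantages, -clipped_ratio * advantages).mean()
```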
| 293 | +                   v_loss_unclipped = (new_values - mb_returns) ** 2
| 294 | +                   if v_clip:
| 295 | +                       v_loss_clipped = (
| 296 | +                           mb_values
| 297 | +                           + torch.clamp(new_values - mb_values, -v_clip, v_clip)
| 298 | +                           - mb_returns
| 299 | +                       ) ** 2
| 300 | +                       v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
| 301 | +                   else:
| 302 | +                       v_loss = v_loss_unclipped.mean()
| 303 | +
| 304 | +                   if self.ppo2_vf_coef_halving:
| 305 | +                       v_loss *= 0.5
| 306 | +
| 307 | +                   entropy_loss = -entropy.mean()
| 308 | +
| 309 | +                   loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
| 310 | +
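The `v_clip` branch is PPO2-style value clipping: the new value prediction may only move `v_clip` away from the rollout-time prediction `mb_values`, and the pessimistic (larger) of the clipped and unclipped squared errors is used, so clipping can never reduce the loss. `ppo2_vf_coef_halving` multiplies by 0.5 to match implementations that define the value loss as half the MSE. A compact restatement:

```python
import torch


def clipped_value_loss(
    new_values: torch.Tensor,
    old_values: torch.Tensor,  # value predictions recorded during the rollout
    returns: torch.Tensor,
    v_clip: float,
) -> torch.Tensor:
    v_loss_unclipped = (new_values - returns) ** 2
    v_clamped = old_values + torch.clamp(new_values - old_values, -v_clip, v_clip)
    v_loss_clipped = (v_clamped - returns) ** 2
    # take the worse of the two, so clipping only ever increases the loss
    return torch.max(v_loss_unclipped, v_loss_clipped).mean()
```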
| 311 | +                   self.optimizer.zero_grad()
| 312 | +                   loss.backward()
| 313 | +                   nn.utils.clip_grad_norm_(
| 314 | +                       self.policy.parameters(), self.max_grad_norm
| 315 | +                   )
| 316 | +                   self.optimizer.step()
| 317 | +
| 318 | +                   with torch.no_grad():
| 319 | +                       approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
| 320 | +                       clipped_frac = (
| 321 | +                           ((ratio - 1).abs() > pi_clip)
| 322 | +                           .float()
| 323 | +                           .mean()
| 324 | +                           .cpu()
| 325 | +                           .numpy()
| 326 | +                           .item()
| 327 | +                       )
| 328 | +                       val_clipped_frac = (
| 329 | +                           ((new_values - mb_values).abs() > v_clip)
| 330 | +                           .float()
| 331 | +                           .mean()
| 332 | +                           .cpu()
| 333 | +                           .numpy()
| 334 | +                           .item()
| 335 | +                           if v_clip
| 336 | +                           else 0
| 337 | +                       )
| 338 | +
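`approx_kl` on line 319 is the low-variance `(r - 1) - log r` estimator of KL(old || new): it is nonnegative by construction and, for a small log-ratio x, behaves like x²/2. A quick numeric check with illustrative values:

```python
import torch

logratio = torch.tensor([0.05, -0.10, 0.20])
ratio = logratio.exp()
approx_kl = ((ratio - 1) - logratio).mean()
print(float(approx_kl))                   # ~0.0092
print(float((logratio ** 2).mean() / 2))  # ~0.0088, close per the x**2/2 approximation
```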
| 339 | +                   step_stats.append(
| 340 | +                       TrainStepStats(
| 341 | +                           loss.item(),
| 342 | +                           pi_loss.item(),
| 343 | +                           v_loss.item(),
| 344 | +                           entropy_loss.item(),
| 345 | +                           approx_kl,
| 346 | +                           clipped_frac,
| 347 | +                           val_clipped_frac,
| 348 | +                       )
| 349 | +                   )
| 350 | +
| 351 | +           var_y = np.var(y_true).item()
| 352 | +           explained_var = (
| 353 | +               np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
| 354 |             )
| 355 | +           TrainStats(step_stats, explained_var).write_to_tensorboard(
| 356 | +               self.tb_writer, timesteps_elapsed
| 357 |             )
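`explained_var` is the standard explained-variance diagnostic for the value function: 1 means the value head predicts the empirical returns perfectly, 0 means it does no better than predicting the mean return, and negative values mean it does worse. Restated as a small helper:

```python
import numpy as np


def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    # 1 - Var[y_true - y_pred] / Var[y_true]
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)
```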
| 358 |
| 359 | +           end_time = perf_counter()
| 360 | +           rollout_steps = self.n_steps * self.env.num_envs
| 361 | +           self.tb_writer.add_scalar(
| 362 | +               "train/steps_per_second",
| 363 | +               rollout_steps / (end_time - start_time),
| 364 | +               timesteps_elapsed,
| 365 | +           )
| 366 | +
| 367 | +           if callback:
| 368 | +               if not callback.on_step(timesteps_elapsed=rollout_steps):
| 369 | +                   logging.info(
| 370 | +                       f"Callback terminated training at {timesteps_elapsed} timesteps"
| 371 | +                   )
| 372 | +                   break
| 373 | +
| 374 | +       return self
    	
rl_algo_impls/runner/config.py CHANGED

@@ -2,12 +2,10 @@ import dataclasses
 import inspect
 import itertools
 import os
-
-from datetime import datetime
 from dataclasses import dataclass
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Type, TypeVar, Union
-
 
 RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
 
 
@@ -50,6 +48,9 @@ class EnvHyperparams:
     video_step_interval: Union[int, float] = 1_000_000
     initial_steps_to_truncate: Optional[int] = None
     clip_atari_rewards: bool = True
+    normalize_type: Optional[str] = None
+    mask_actions: bool = False
+    bots: Optional[Dict[str, int]] = None
 
 
 HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
@@ -64,6 +65,7 @@ class Hyperparams:
     algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
     eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
     env_id: Optional[str] = None
+    additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
 
     @classmethod
     def from_dict_with_extra_fields(
@@ -119,6 +121,10 @@ class Config:
     def env_id(self) -> str:
         return self.hyperparams.env_id or self.args.env
 
+    @property
+    def additional_keys_to_log(self) -> List[str]:
+        return self.hyperparams.additional_keys_to_log
+
     def model_name(self, include_seed: bool = True) -> str:
         # Use arg env name instead of environment name
         parts = [self.algo, self.args.env]
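The new `additional_keys_to_log` field lets a hyperparameter file name extra info-dict keys (for example per-episode stats emitted by an env wrapper) that should be surfaced in logging, and `Config` forwards it from `Hyperparams` so call sites only need the `Config`. A minimal sketch of that forwarding pattern (the real classes have many more fields, and `"microrts_stats"` is an illustrative key, not one defined by this commit):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Hyperparams:
    additional_keys_to_log: List[str] = field(default_factory=list)


@dataclass
class Config:
    hyperparams: Hyperparams

    @property
    def additional_keys_to_log(self) -> List[str]:
        # Forward so callers never reach into hyperparams directly
        return self.hyperparams.additional_keys_to_log


config = Config(Hyperparams(additional_keys_to_log=["microrts_stats"]))
print(config.additional_keys_to_log)  # ['microrts_stats']
```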
    	
rl_algo_impls/runner/evaluate.py CHANGED

@@ -4,7 +4,7 @@ import shutil
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
 
-from rl_algo_impls.
+from rl_algo_impls.shared.vec_env import make_eval_env
 from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
 from rl_algo_impls.runner.running_utils import (
     load_hyperparams,
@@ -75,7 +75,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
         render=args.render,
         normalize_load_path=model_path,
     )
-    device = get_device(config
+    device = get_device(config, env)
     policy = make_policy(
         args.algo,
         env,
    	
        rl_algo_impls/runner/running_utils.py
    CHANGED
    
    | 
         @@ -1,32 +1,32 @@ 
     | 
|
| 1 | 
         
             
            import argparse
         
     | 
| 2 | 
         
            -
            import gym
         
     | 
| 3 | 
         
             
            import json
         
     | 
| 4 | 
         
            -
            import matplotlib.pyplot as plt
         
     | 
| 5 | 
         
            -
            import numpy as np
         
     | 
| 6 | 
         
             
            import os
         
     | 
| 7 | 
         
             
            import random
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 8 | 
         
             
            import torch
         
     | 
| 9 | 
         
             
            import torch.backends.cudnn
         
     | 
| 10 | 
         
             
            import yaml
         
     | 
| 11 | 
         
            -
             
     | 
| 12 | 
         
            -
            from dataclasses import asdict
         
     | 
| 13 | 
         
             
            from gym.spaces import Box, Discrete
         
     | 
| 14 | 
         
            -
            from pathlib import Path
         
     | 
| 15 | 
         
             
            from torch.utils.tensorboard.writer import SummaryWriter
         
     | 
| 16 | 
         
            -
            from typing import Dict, Optional, Type, Union
         
     | 
| 17 | 
         
            -
             
     | 
| 18 | 
         
            -
            from rl_algo_impls.runner.config import Hyperparams
         
     | 
| 19 | 
         
            -
            from rl_algo_impls.shared.algorithm import Algorithm
         
     | 
| 20 | 
         
            -
            from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
         
     | 
| 21 | 
         
            -
            from rl_algo_impls.shared.policy.on_policy import ActorCritic
         
     | 
| 22 | 
         
            -
            from rl_algo_impls.shared.policy.policy import Policy
         
     | 
| 23 | 
         | 
| 24 | 
         
             
            from rl_algo_impls.a2c.a2c import A2C
         
     | 
| 25 | 
         
             
            from rl_algo_impls.dqn.dqn import DQN
         
     | 
| 26 | 
         
             
            from rl_algo_impls.dqn.policy import DQNPolicy
         
     | 
| 27 | 
         
             
            from rl_algo_impls.ppo.ppo import PPO
         
     | 
| 28 | 
         
            -
            from rl_algo_impls. 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 29 | 
         
             
            from rl_algo_impls.vpg.policy import VPGActorCritic
         
     | 
| 
         | 
|
| 30 | 
         
             
            from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
         
     | 
| 31 | 
         | 
| 32 | 
         
             
            ALGOS: Dict[str, Type[Algorithm]] = {
         
     | 
| 
         @@ -81,16 +81,19 @@ def load_hyperparams(algo: str, env_id: str) -> Hyperparams: 
     | 
|
| 81 | 
         
             
                if env_id in hyperparams_dict:
         
     | 
| 82 | 
         
             
                    return Hyperparams(**hyperparams_dict[env_id])
         
     | 
| 83 | 
         | 
| 84 | 
         
            -
                 
     | 
| 85 | 
         
            -
                    import pybullet_envs
         
     | 
| 86 | 
         
             
                spec = gym.spec(env_id)
         
     | 
| 87 | 
         
            -
                 
     | 
| 
         | 
|
| 88 | 
         
             
                    return Hyperparams(**hyperparams_dict["_atari"])
         
     | 
| 
         | 
|
| 
         | 
|
| 89 | 
         
             
                else:
         
     | 
| 90 | 
         
             
                    raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
         
     | 
| 91 | 
         | 
| 92 | 
         | 
| 93 | 
         
            -
            def get_device( 
     | 
| 
         | 
|
| 94 | 
         
             
                # cuda by default
         
     | 
| 95 | 
         
             
                if device == "auto":
         
     | 
| 96 | 
         
             
                    device = "cuda"
         
     | 
| 
         @@ -108,6 +111,16 @@ def get_device(device: str, env: VecEnv) -> torch.device: 
     | 
|
| 108 | 
         
             
                        device = "cpu"
         
     | 
| 109 | 
         
             
                    elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
         
     | 
| 110 | 
         
             
                        device = "cpu"
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 111 | 
         
             
                print(f"Device: {device}")
         
     | 
| 112 | 
         
             
                return torch.device(device)
         
     | 
| 113 | 
         | 
| 
         @@ -187,6 +200,8 @@ def hparam_dict( 
     | 
|
| 187 | 
         
             
                                flattened[key] = str(sv)
         
     | 
| 188 | 
         
             
                            else:
         
     | 
| 189 | 
         
             
                                flattened[key] = sv
         
     | 
| 
         | 
|
| 
         | 
|
| 190 | 
         
             
                    else:
         
     | 
| 191 | 
         
             
                        flattened[k] = v  # type: ignore
         
     | 
| 192 | 
         
             
                return flattened  # type: ignore
         
     | 
| 
         | 
|
| 1 | 
         
             
            import argparse
         
     | 
| 
         | 
|
| 2 | 
         
             
            import json
         
     | 
| 
         | 
|
| 
         | 
|
| 3 | 
         
             
            import os
         
     | 
| 4 | 
         
             
            import random
         
     | 
| 5 | 
         
            +
            from dataclasses import asdict
         
     | 
| 6 | 
         
            +
            from pathlib import Path
         
     | 
| 7 | 
         
            +
            from typing import Dict, Optional, Type, Union
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import gym
         
     | 
| 10 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 11 | 
         
            +
            import numpy as np
         
     | 
| 12 | 
         
             
            import torch
         
     | 
| 13 | 
         
             
            import torch.backends.cudnn
         
     | 
| 14 | 
         
             
            import yaml
         
     | 
| 
         | 
|
| 
         | 
|
| 15 | 
         
             
            from gym.spaces import Box, Discrete
         
     | 
| 
 from torch.utils.tensorboard.writer import SummaryWriter
 
 from rl_algo_impls.a2c.a2c import A2C
 from rl_algo_impls.dqn.dqn import DQN
 from rl_algo_impls.dqn.policy import DQNPolicy
 from rl_algo_impls.ppo.ppo import PPO
+from rl_algo_impls.runner.config import Config, Hyperparams
+from rl_algo_impls.shared.algorithm import Algorithm
+from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
+from rl_algo_impls.shared.policy.on_policy import ActorCritic
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
 from rl_algo_impls.vpg.policy import VPGActorCritic
+from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
 from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
 
 ALGOS: Dict[str, Type[Algorithm]] = {
@@ ... @@
     if env_id in hyperparams_dict:
         return Hyperparams(**hyperparams_dict[env_id])
 
+    import_for_env_id(env_id)
     spec = gym.spec(env_id)
+    entry_point_name = str(spec.entry_point)  # type: ignore
+    if "AtariEnv" in entry_point_name and "_atari" in hyperparams_dict:
         return Hyperparams(**hyperparams_dict["_atari"])
+    elif "gym_microrts" in entry_point_name and "_microrts" in hyperparams_dict:
+        return Hyperparams(**hyperparams_dict["_microrts"])
     else:
         raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
 
 
+def get_device(config: Config, env: VecEnv) -> torch.device:
+    device = config.device
     # cuda by default
     if device == "auto":
         device = "cuda"
@@ ... @@
             device = "cpu"
         elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
             device = "cpu"
+        if is_microrts(config):
+            try:
+                from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
+
+                # Models that move more than one unit at a time should use mps
+                if not isinstance(env.unwrapped, MicroRTSGridModeVecEnv):
+                    device = "cpu"
+            except ModuleNotFoundError:
+                # Likely on gym_microrts v0.0.2 to match ppo-implementation-details
+                device = "cpu"
     print(f"Device: {device}")
     return torch.device(device)
 
@@ ... @@
                     flattened[key] = str(sv)
                 else:
                     flattened[key] = sv
+        elif isinstance(v, list):
+            flattened[k] = json.dumps(v)
         else:
             flattened[k] = v  # type: ignore
     return flattened  # type: ignore
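With these changes, `load_hyperparams` falls back to a shared `_atari` or `_microrts` block (keyed off the env's entry point) when an env id has no exact entry, and `get_device` now receives the whole `Config` so it can special-case MicroRTS. A hedged usage sketch; the env id is illustrative and the `load_hyperparams(algo, env_id)` argument order is inferred from the `ValueError` message above:

```python
from rl_algo_impls.runner.running_utils import load_hyperparams

# No exact "PongNoFrameskip-v4" entry is needed in ppo.yml: the AtariEnv
# entry point resolves the lookup to the shared "_atari" block.
hyperparams = load_hyperparams("ppo", "PongNoFrameskip-v4")
```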
    	
rl_algo_impls/runner/train.py
CHANGED

@@ -5,26 +5,26 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
 import dataclasses
 import shutil
-import wandb
-import yaml
-
 from dataclasses import asdict, dataclass
-from torch.utils.tensorboard.writer import SummaryWriter
 from typing import Any, Dict, Optional, Sequence
 
-
+import yaml
+from torch.utils.tensorboard.writer import SummaryWriter
+
+import wandb
 from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
-from rl_algo_impls.runner.env import make_env, make_eval_env
 from rl_algo_impls.runner.running_utils import (
     ALGOS,
-    load_hyperparams,
-    set_seeds,
     get_device,
+    hparam_dict,
+    load_hyperparams,
     make_policy,
     plot_eval_callback,
-
+    set_seeds,
 )
+from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
 from rl_algo_impls.shared.stats import EpisodesStats
+from rl_algo_impls.shared.vec_env import make_env, make_eval_env
 
 
 @dataclass
@@ -65,7 +65,7 @@ def train(args: TrainArgs):
     env = make_env(
         config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
     )
-    device = get_device(config.device, env)
+    device = get_device(config, env)
     policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
     algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
 
@@ -94,6 +94,7 @@ def train(args: TrainArgs):
         if record_best_videos
         else None,
         best_video_dir=config.best_videos_dir,
+        additional_keys_to_log=config.additional_keys_to_log,
     )
     algo.learn(config.n_timesteps, callback=callback)
 
    	
rl_algo_impls/shared/actor/__init__.py
ADDED

@@ -0,0 +1,2 @@
+from rl_algo_impls.shared.actor.actor import Actor, PiForward
+from rl_algo_impls.shared.actor.make_actor import actor_head
    	
rl_algo_impls/shared/actor/actor.py
ADDED

@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import NamedTuple, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.distributions import Distribution
+
+
+class PiForward(NamedTuple):
+    pi: Distribution
+    logp_a: Optional[torch.Tensor]
+    entropy: Optional[torch.Tensor]
+
+
+class Actor(nn.Module, ABC):
+    @abstractmethod
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        ...
+
+    def sample_weights(self, batch_size: int = 1) -> None:
+        pass
+
+    @property
+    @abstractmethod
+    def action_shape(self) -> Tuple[int, ...]:
+        ...
+
+    def pi_forward(
+        self, distribution: Distribution, actions: Optional[torch.Tensor] = None
+    ) -> PiForward:
+        logp_a = None
+        entropy = None
+        if actions is not None:
+            logp_a = distribution.log_prob(actions)
+            entropy = distribution.entropy()
+        return PiForward(distribution, logp_a, entropy)
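To make the `Actor` contract concrete, here is a toy subclass (not part of the repo): a fixed uniform policy over `n` actions. It shows how `pi_forward` leaves `logp_a` and `entropy` as `None` during rollouts and fills both in when stored actions are re-evaluated during training:

```python
import torch
from torch.distributions import Categorical

from rl_algo_impls.shared.actor.actor import Actor, PiForward


class UniformHead(Actor):
    """Toy Actor emitting a uniform distribution over n actions (illustration only)."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.n = n

    def forward(self, obs, actions=None, action_masks=None) -> PiForward:
        logits = torch.zeros(obs.shape[0], self.n)  # uniform over n actions
        return self.pi_forward(Categorical(logits=logits), actions)

    @property
    def action_shape(self):
        return ()


head = UniformHead(4)
obs = torch.randn(8, 3)
pi, logp_a, entropy = head(obs)            # rollout: logp_a and entropy are None
a = pi.sample()
_, logp_a, entropy = head(obs, actions=a)  # training: both are (8,) tensors
```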
    	
rl_algo_impls/shared/actor/categorical.py
ADDED

@@ -0,0 +1,64 @@
+from typing import Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+from torch.distributions import Categorical
+
+from rl_algo_impls.shared.actor import Actor, PiForward
+from rl_algo_impls.shared.module.module import mlp
+
+
+class MaskedCategorical(Categorical):
+    def __init__(
+        self,
+        probs=None,
+        logits=None,
+        validate_args=None,
+        mask: Optional[torch.Tensor] = None,
+    ):
+        if mask is not None:
+            assert logits is not None, "mask requires logits and not probs"
+            logits = torch.where(mask, logits, -1e8)
+        self.mask = mask
+        super().__init__(probs, logits, validate_args)
+
+    def entropy(self) -> torch.Tensor:
+        if self.mask is None:
+            return super().entropy()
+        # If mask set, then use approximation for entropy
+        p_log_p = self.logits * self.probs  # type: ignore
+        masked = torch.where(self.mask, p_log_p, 0)
+        return -masked.sum(-1)
+
+
+class CategoricalActorHead(Actor):
+    def __init__(
+        self,
+        act_dim: int,
+        in_dim: int,
+        hidden_sizes: Tuple[int, ...] = (32,),
+        activation: Type[nn.Module] = nn.Tanh,
+        init_layers_orthogonal: bool = True,
+    ) -> None:
+        super().__init__()
+        layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
+        self._fc = mlp(
+            layer_sizes,
+            activation,
+            init_layers_orthogonal=init_layers_orthogonal,
+            final_layer_gain=0.01,
+        )
+
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        logits = self._fc(obs)
+        pi = MaskedCategorical(logits=logits, mask=action_masks)
+        return self.pi_forward(pi, actions)
+
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return ()
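A minimal sketch of `MaskedCategorical` in isolation (toy logits and mask): masked-out actions get a logit of `-1e8`, so their probability underflows to zero, and `entropy()` sums `-p * log p` over the allowed entries only:

```python
import torch

from rl_algo_impls.shared.actor.categorical import MaskedCategorical

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[True, False, True, False]])

dist = MaskedCategorical(logits=logits, mask=mask)
dist.sample()   # only actions 0 and 2 are ever drawn
dist.entropy()  # sums -p * log(p) over the two allowed actions only
```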
    	
rl_algo_impls/shared/actor/gaussian.py
ADDED

@@ -0,0 +1,61 @@
+from typing import Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+from torch.distributions import Distribution, Normal
+
+from rl_algo_impls.shared.actor.actor import Actor, PiForward
+from rl_algo_impls.shared.module.module import mlp
+
+
+class GaussianDistribution(Normal):
+    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
+        return super().log_prob(a).sum(axis=-1)
+
+    def sample(self) -> torch.Tensor:
+        return self.rsample()
+
+
+class GaussianActorHead(Actor):
+    def __init__(
+        self,
+        act_dim: int,
+        in_dim: int,
+        hidden_sizes: Tuple[int, ...] = (32,),
+        activation: Type[nn.Module] = nn.Tanh,
+        init_layers_orthogonal: bool = True,
+        log_std_init: float = -0.5,
+    ) -> None:
+        super().__init__()
+        self.act_dim = act_dim
+        layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
+        self.mu_net = mlp(
+            layer_sizes,
+            activation,
+            init_layers_orthogonal=init_layers_orthogonal,
+            final_layer_gain=0.01,
+        )
+        self.log_std = nn.Parameter(
+            torch.ones(act_dim, dtype=torch.float32) * log_std_init
+        )
+
+    def _distribution(self, obs: torch.Tensor) -> Distribution:
+        mu = self.mu_net(obs)
+        std = torch.exp(self.log_std)
+        return GaussianDistribution(mu, std)
+
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            not action_masks
+        ), f"{self.__class__.__name__} does not support action_masks"
+        pi = self._distribution(obs)
+        return self.pi_forward(pi, actions)
+
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return (self.act_dim,)
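A shape sketch for `GaussianActorHead` with made-up observation/action sizes: `GaussianDistribution.log_prob` sums the per-dimension log densities, so `logp_a` is one scalar per batch element, and `sample` delegates to the reparameterized `rsample`, so sampled actions stay differentiable:

```python
import torch

from rl_algo_impls.shared.actor.gaussian import GaussianActorHead

head = GaussianActorHead(act_dim=2, in_dim=4)  # toy sizes
obs = torch.randn(8, 4)

pi, _, _ = head(obs)
a = pi.sample()                      # (8, 2), differentiable via rsample
_, logp_a, _ = head(obs, actions=a)
logp_a.shape                         # torch.Size([8]): summed over action dims
```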
    	
rl_algo_impls/shared/actor/gridnet.py
ADDED

@@ -0,0 +1,108 @@
+from typing import Dict, Optional, Tuple, Type
+
+import numpy as np
+import torch
+import torch.nn as nn
+from numpy.typing import NDArray
+from torch.distributions import Distribution, constraints
+
+from rl_algo_impls.shared.actor import Actor, PiForward
+from rl_algo_impls.shared.actor.categorical import MaskedCategorical
+from rl_algo_impls.shared.encoder import EncoderOutDim
+from rl_algo_impls.shared.module.module import mlp
+
+
+class GridnetDistribution(Distribution):
+    def __init__(
+        self,
+        map_size: int,
+        action_vec: NDArray[np.int64],
+        logits: torch.Tensor,
+        masks: torch.Tensor,
+        validate_args: Optional[bool] = None,
+    ) -> None:
+        self.map_size = map_size
+        self.action_vec = action_vec
+
+        masks = masks.view(-1, masks.shape[-1])
+        split_masks = torch.split(masks[:, 1:], action_vec.tolist(), dim=1)
+
+        grid_logits = logits.reshape(-1, action_vec.sum())
+        split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
+        self.categoricals = [
+            MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
+            for lg, m in zip(split_logits, split_masks)
+        ]
+
+        batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
+        super().__init__(batch_shape=batch_shape, validate_args=validate_args)
+
+    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
+        prob_stack = torch.stack(
+            [
+                c.log_prob(a)
+                for a, c in zip(action.view(-1, action.shape[-1]).T, self.categoricals)
+            ],
+            dim=-1,
+        )
+        logprob = prob_stack.view(-1, self.map_size, len(self.action_vec))
+        return logprob.sum(dim=(1, 2))
+
+    def entropy(self) -> torch.Tensor:
+        ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
+        ent = ent.view(-1, self.map_size, len(self.action_vec))
+        return ent.sum(dim=(1, 2))
+
+    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+        s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
+        return s.view(-1, self.map_size, len(self.action_vec))
+
+    @property
+    def mode(self) -> torch.Tensor:
+        m = torch.stack([c.mode for c in self.categoricals], dim=-1)
+        return m.view(-1, self.map_size, len(self.action_vec))
+
+    @property
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
+        # Constraints handled by child distributions in dist
+        return {}
+
+
+class GridnetActorHead(Actor):
+    def __init__(
+        self,
+        map_size: int,
+        action_vec: NDArray[np.int64],
+        in_dim: EncoderOutDim,
+        hidden_sizes: Tuple[int, ...] = (32,),
+        activation: Type[nn.Module] = nn.ReLU,
+        init_layers_orthogonal: bool = True,
+    ) -> None:
+        super().__init__()
+        self.map_size = map_size
+        self.action_vec = action_vec
+        assert isinstance(in_dim, int)
+        layer_sizes = (in_dim,) + hidden_sizes + (map_size * action_vec.sum(),)
+        self._fc = mlp(
+            layer_sizes,
+            activation,
+            init_layers_orthogonal=init_layers_orthogonal,
+            final_layer_gain=0.01,
+        )
+
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            action_masks is not None
+        ), f"No mask case unhandled in {self.__class__.__name__}"
+        logits = self._fc(obs)
+        pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
+        return self.pi_forward(pi, actions)
+
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return (self.map_size, len(self.action_vec))
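A toy shape walk-through for `GridnetDistribution` (illustrative sizes, not the MicroRTS defaults): each of the `map_size` cells gets one independent `MaskedCategorical` per action component, and the constructor drops the first mask column (`masks[:, 1:]`), so the mask's last dimension is `1 + action_vec.sum()`:

```python
import numpy as np
import torch

from rl_algo_impls.shared.actor.gridnet import GridnetDistribution

map_size = 4                   # e.g. a 2x2 grid, flattened
action_vec = np.array([3, 2])  # two discrete components per cell

logits = torch.randn(8, map_size * action_vec.sum())  # (8, 20)
masks = torch.ones(8, map_size, 1 + action_vec.sum(), dtype=torch.bool)

dist = GridnetDistribution(map_size, action_vec, logits, masks)
a = dist.sample()       # (8, 4, 2): one component pair per cell
dist.log_prob(a).shape  # torch.Size([8]): summed over cells and components
```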
    	
rl_algo_impls/shared/actor/gridnet_decoder.py
ADDED

@@ -0,0 +1,80 @@
+from typing import Optional, Tuple, Type
+
+import numpy as np
+import torch
+import torch.nn as nn
+from numpy.typing import NDArray
+
+from rl_algo_impls.shared.actor import Actor, PiForward
+from rl_algo_impls.shared.actor.categorical import MaskedCategorical
+from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
+from rl_algo_impls.shared.encoder import EncoderOutDim
+from rl_algo_impls.shared.module.module import layer_init
+
+
+class Transpose(nn.Module):
+    def __init__(self, permutation: Tuple[int, ...]) -> None:
+        super().__init__()
+        self.permutation = permutation
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.permute(self.permutation)
+
+
+class GridnetDecoder(Actor):
+    def __init__(
+        self,
+        map_size: int,
+        action_vec: NDArray[np.int64],
+        in_dim: EncoderOutDim,
+        activation: Type[nn.Module] = nn.ReLU,
+        init_layers_orthogonal: bool = True,
+    ) -> None:
+        super().__init__()
+        self.map_size = map_size
+        self.action_vec = action_vec
+        assert isinstance(in_dim, tuple)
+        self.deconv = nn.Sequential(
+            layer_init(
+                nn.ConvTranspose2d(
+                    in_dim[0], 128, 3, stride=2, padding=1, output_padding=1
+                ),
+                init_layers_orthogonal=init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(
+                nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
+                init_layers_orthogonal=init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(
+                nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
+                init_layers_orthogonal=init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(
+                nn.ConvTranspose2d(
+                    32, action_vec.sum(), 3, stride=2, padding=1, output_padding=1
+                ),
+                init_layers_orthogonal=init_layers_orthogonal,
+                std=0.01,
+            ),
+            Transpose((0, 2, 3, 1)),
+        )
+
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            action_masks is not None
+        ), f"No mask case unhandled in {self.__class__.__name__}"
+        logits = self.deconv(obs)
+        pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
+        return self.pi_forward(pi, actions)
+
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return (self.map_size, len(self.action_vec))
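A shape sketch for `GridnetDecoder` with toy sizes: the four stride-2 `ConvTranspose2d` layers each double the spatial resolution, so a `(64, 2, 2)` encoder output decodes to a 32x32 grid of `action_vec.sum()` per-cell logits, transposed to channels-last before `GridnetDistribution` consumes them:

```python
import numpy as np
import torch

from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder

action_vec = np.array([6, 4])  # illustrative component sizes
decoder = GridnetDecoder(
    map_size=32 * 32,
    action_vec=action_vec,
    in_dim=(64, 2, 2),  # (channels, height, width) from the encoder
)

obs = torch.randn(1, 64, 2, 2)
masks = torch.ones(1, 32 * 32, 1 + action_vec.sum(), dtype=torch.bool)
pi, _, _ = decoder(obs, action_masks=masks)
pi.sample().shape  # torch.Size([1, 1024, 2])
```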
    	
rl_algo_impls/shared/actor/make_actor.py
ADDED

@@ -0,0 +1,95 @@
+from typing import Tuple, Type
+
+import gym
+import torch.nn as nn
+from gym.spaces import Box, Discrete, MultiDiscrete
+
+from rl_algo_impls.shared.actor.actor import Actor
+from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
+from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
+from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
+from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
+from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
+from rl_algo_impls.shared.actor.state_dependent_noise import (
+    StateDependentNoiseActorHead,
+)
+from rl_algo_impls.shared.encoder import EncoderOutDim
+
+
+def actor_head(
+    action_space: gym.Space,
+    in_dim: EncoderOutDim,
+    hidden_sizes: Tuple[int, ...],
+    init_layers_orthogonal: bool,
+    activation: Type[nn.Module],
+    log_std_init: float = -0.5,
+    use_sde: bool = False,
+    full_std: bool = True,
+    squash_output: bool = False,
+    actor_head_style: str = "single",
+) -> Actor:
+    assert not use_sde or isinstance(
+        action_space, Box
+    ), "use_sde only valid if Box action_space"
+    assert not squash_output or use_sde, "squash_output only valid if use_sde"
+    if isinstance(action_space, Discrete):
+        assert isinstance(in_dim, int)
+        return CategoricalActorHead(
+            action_space.n,  # type: ignore
+            in_dim=in_dim,
+            hidden_sizes=hidden_sizes,
+            activation=activation,
+            init_layers_orthogonal=init_layers_orthogonal,
+        )
+    elif isinstance(action_space, Box):
+        assert isinstance(in_dim, int)
+        if use_sde:
+            return StateDependentNoiseActorHead(
+                action_space.shape[0],  # type: ignore
+                in_dim=in_dim,
+                hidden_sizes=hidden_sizes,
+                activation=activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+                log_std_init=log_std_init,
+                full_std=full_std,
+                squash_output=squash_output,
+            )
+        else:
+            return GaussianActorHead(
+                action_space.shape[0],  # type: ignore
+                in_dim=in_dim,
+                hidden_sizes=hidden_sizes,
+                activation=activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+                log_std_init=log_std_init,
+            )
+    elif isinstance(action_space, MultiDiscrete):
+        if actor_head_style == "single":
+            return MultiDiscreteActorHead(
+                action_space.nvec,  # type: ignore
+                in_dim=in_dim,
+                hidden_sizes=hidden_sizes,
+                activation=activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+            )
+        elif actor_head_style == "gridnet":
+            return GridnetActorHead(
+                action_space.nvec[0],  # type: ignore
+                action_space.nvec[1:],  # type: ignore
+                in_dim=in_dim,
+                hidden_sizes=hidden_sizes,
+                activation=activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+            )
+        elif actor_head_style == "gridnet_decoder":
+            return GridnetDecoder(
+                action_space.nvec[0],  # type: ignore
+                action_space.nvec[1:],  # type: ignore
+                in_dim=in_dim,
+                activation=activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+            )
         
     | 
| 92 | 
         
            +
                    else:
         
     | 
| 93 | 
         
            +
                        raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
         
     | 
| 94 | 
         
            +
                else:
         
     | 
| 95 | 
         
            +
                    raise ValueError(f"Unsupported action space: {action_space}")
         
    	
        rl_algo_impls/shared/actor/multi_discrete.py
    ADDED
    
@@ -0,0 +1,101 @@
from typing import Dict, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.distributions import Distribution, constraints

from rl_algo_impls.shared.actor.actor import Actor, PiForward
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.module import mlp


class MultiCategorical(Distribution):
    def __init__(
        self,
        nvec: NDArray[np.int64],
        probs=None,
        logits=None,
        validate_args=None,
        masks: Optional[torch.Tensor] = None,
    ):
        # Either probs or logits should be set
        assert (probs is None) != (logits is None)
        masks_split = (
            torch.split(masks, nvec.tolist(), dim=1)
            if masks is not None
            else [None] * len(nvec)
        )
        if probs is not None:
            self.dists = [
                MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
                for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
            ]
            param = probs
        else:
            assert logits is not None
            self.dists = [
                MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
                for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
            ]
            param = logits
        batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
        prob_stack = torch.stack(
            [c.log_prob(a) for a, c in zip(action.T, self.dists)], dim=-1
        )
        return prob_stack.sum(dim=-1)

    def entropy(self) -> torch.Tensor:
        return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)

    @property
    def mode(self) -> torch.Tensor:
        return torch.stack([c.mode for c in self.dists], dim=-1)

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        # Constraints handled by child distributions in dist
        return {}


class MultiDiscreteActorHead(Actor):
    def __init__(
        self,
        nvec: NDArray[np.int64],
        in_dim: EncoderOutDim,
        hidden_sizes: Tuple[int, ...] = (32,),
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        self.nvec = nvec
        assert isinstance(in_dim, int)
        layer_sizes = (in_dim,) + hidden_sizes + (nvec.sum(),)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )

    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        logits = self._fc(obs)
        pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
        return self.pi_forward(pi, actions)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        return (len(self.nvec),)
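MultiCategorical factors a MultiDiscrete action into independent categorical components: the flat logit vector is split per nvec, and log_prob/entropy are summed across the components. A small sketch of that factorization, not from the commit, assuming MaskedCategorical with mask=None reduces to a plain Categorical:

import numpy as np
import torch

nvec = np.array([3, 2], dtype=np.int64)   # MultiDiscrete([3, 2])
logits = torch.zeros(1, int(nvec.sum()))  # batch of 1, uniform logits over 3 + 2 choices
dist = MultiCategorical(nvec, logits=logits)

action = dist.sample()        # shape (1, 2): one index per component
logp = dist.log_prob(action)  # log(1/3) + log(1/2) = -log(6) for uniform logits
ent = dist.entropy()          # log(3) + log(2), the sum of component entropies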
         
    	
        rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py}
    RENAMED
    
@@ -1,99 +1,13 @@
-import gym
+from typing import Optional, Tuple, Type, TypeVar, Union
+
 import torch
 import torch.nn as nn
-
-from abc import ABC, abstractmethod
-from gym.spaces import Box, Discrete
-from torch.distributions import Categorical, Distribution, Normal
-from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
+from torch.distributions import Distribution, Normal
 
+from rl_algo_impls.shared.actor.actor import Actor, PiForward
 from rl_algo_impls.shared.module.module import mlp
 
 
-class PiForward(NamedTuple):
-    pi: Distribution
-    logp_a: Optional[torch.Tensor]
-    entropy: Optional[torch.Tensor]
-
-
-class Actor(nn.Module, ABC):
-    @abstractmethod
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
-        ...
-
-
-class CategoricalActorHead(Actor):
-    def __init__(
-        self,
-        act_dim: int,
-        hidden_sizes: Sequence[int] = (32,),
-        activation: Type[nn.Module] = nn.Tanh,
-        init_layers_orthogonal: bool = True,
-    ) -> None:
-        super().__init__()
-        layer_sizes = tuple(hidden_sizes) + (act_dim,)
-        self._fc = mlp(
-            layer_sizes,
-            activation,
-            init_layers_orthogonal=init_layers_orthogonal,
-            final_layer_gain=0.01,
-        )
-
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
-        logits = self._fc(obs)
-        pi = Categorical(logits=logits)
-        logp_a = None
-        entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
-            entropy = pi.entropy()
-        return PiForward(pi, logp_a, entropy)
-
-
-class GaussianDistribution(Normal):
-    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
-        return super().log_prob(a).sum(axis=-1)
-
-    def sample(self) -> torch.Tensor:
-        return self.rsample()
-
-
-class GaussianActorHead(Actor):
-    def __init__(
-        self,
-        act_dim: int,
-        hidden_sizes: Sequence[int] = (32,),
-        activation: Type[nn.Module] = nn.Tanh,
-        init_layers_orthogonal: bool = True,
-        log_std_init: float = -0.5,
-    ) -> None:
-        super().__init__()
-        layer_sizes = tuple(hidden_sizes) + (act_dim,)
-        self.mu_net = mlp(
-            layer_sizes,
-            activation,
-            init_layers_orthogonal=init_layers_orthogonal,
-            final_layer_gain=0.01,
-        )
-        self.log_std = nn.Parameter(
-            torch.ones(act_dim, dtype=torch.float32) * log_std_init
-        )
-
-    def _distribution(self, obs: torch.Tensor) -> Distribution:
-        mu = self.mu_net(obs)
-        std = torch.exp(self.log_std)
-        return GaussianDistribution(mu, std)
-
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
-        pi = self._distribution(obs)
-        logp_a = None
-        entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
-            entropy = pi.entropy()
-        return PiForward(pi, logp_a, entropy)
-
-
 class TanhBijector:
     def __init__(self, epsilon: float = 1e-6) -> None:
         self.epsilon = epsilon
@@ -173,7 +87,8 @@ class StateDependentNoiseActorHead(Actor):
     def __init__(
         self,
        act_dim: int,
-        hidden_sizes: Sequence[int] = (32,),
+        in_dim: int,
+        hidden_sizes: Tuple[int, ...] = (32,),
         activation: Type[nn.Module] = nn.Tanh,
         init_layers_orthogonal: bool = True,
         log_std_init: float = -0.5,
@@ -183,7 +98,7 @@ class StateDependentNoiseActorHead(Actor):
     ) -> None:
         super().__init__()
         self.act_dim = act_dim
-        layer_sizes = tuple(hidden_sizes) + (act_dim,)
+        layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
         if len(layer_sizes) == 2:
             self.latent_net = nn.Identity()
         elif len(layer_sizes) > 2:
@@ -193,8 +108,6 @@ class StateDependentNoiseActorHead(Actor):
                 output_activation=activation,
                 init_layers_orthogonal=init_layers_orthogonal,
             )
-        else:
-            raise ValueError("hidden_sizes must be of at least length 1")
         self.mu_net = mlp(
             layer_sizes[-2:],
             activation,
@@ -202,7 +115,7 @@ class StateDependentNoiseActorHead(Actor):
             final_layer_gain=0.01,
         )
         self.full_std = full_std
-        std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
+        std_dim = (layer_sizes[-2], act_dim if self.full_std else 1)
         self.log_std = nn.Parameter(
             torch.ones(std_dim, dtype=torch.float32) * log_std_init
         )
@@ -249,14 +162,17 @@ class StateDependentNoiseActorHead(Actor):
             ones = ones.to(self.device)
         return ones * std
 
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            not action_masks
+        ), f"{self.__class__.__name__} does not support action_masks"
         pi = self._distribution(obs)
-        logp_a = None
-        entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
-            entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
-        return PiForward(pi, logp_a, entropy)
+        return self.pi_forward(pi, actions)
 
     def sample_weights(self, batch_size: int = 1) -> None:
         std = self._get_std()
@@ -265,46 +181,20 @@ class StateDependentNoiseActorHead(Actor):
         self.exploration_mat = weights_dist.rsample()
         self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
 
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return (self.act_dim,)
+
+    def pi_forward(
+        self, distribution: Distribution, actions: Optional[torch.Tensor] = None
+    ) -> PiForward:
+        logp_a = None
+        entropy = None
+        if actions is not None:
+            logp_a = distribution.log_prob(actions)
+            entropy = (
+                -logp_a
+                if self.bijector
+                else sum_independent_dims(distribution.entropy())
+            )
+        return PiForward(distribution, logp_a, entropy)
 
-def actor_head(
-    action_space: gym.Space,
-    hidden_sizes: Sequence[int],
-    init_layers_orthogonal: bool,
-    activation: Type[nn.Module],
-    log_std_init: float = -0.5,
-    use_sde: bool = False,
-    full_std: bool = True,
-    squash_output: bool = False,
-) -> Actor:
-    assert not use_sde or isinstance(
-        action_space, Box
-    ), "use_sde only valid if Box action_space"
-    assert not squash_output or use_sde, "squash_output only valid if use_sde"
-    if isinstance(action_space, Discrete):
-        return CategoricalActorHead(
-            action_space.n,
-            hidden_sizes=hidden_sizes,
-            activation=activation,
-            init_layers_orthogonal=init_layers_orthogonal,
-        )
-    elif isinstance(action_space, Box):
-        if use_sde:
-            return StateDependentNoiseActorHead(
-                action_space.shape[0],
-                hidden_sizes=hidden_sizes,
-                activation=activation,
-                init_layers_orthogonal=init_layers_orthogonal,
-                log_std_init=log_std_init,
-                full_std=full_std,
-                squash_output=squash_output,
-            )
-        else:
-            return GaussianActorHead(
-                action_space.shape[0],
-                hidden_sizes=hidden_sizes,
-                activation=activation,
-                init_layers_orthogonal=init_layers_orthogonal,
-                log_std_init=log_std_init,
-            )
-    else:
-        raise ValueError(f"Unsupported action space: {action_space}")
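After the rename only the gSDE head lives in this module, and the shared log-prob/entropy bookkeeping moves into the per-head pi_forward helper. With gSDE, exploration noise is a learned function of the latent features, so the noise weight matrices are resampled once per rollout rather than at every step. A hedged sketch of that cadence, assuming head is a constructed StateDependentNoiseActorHead and obs a (num_envs, in_dim) feature batch:

head.sample_weights(batch_size=obs.shape[0])  # refresh exploration matrices for the rollout
pi, logp_a, entropy = head(obs)               # PiForward; logp_a/entropy are None without actions
action = pi.sample()                          # noise drawn through the sampled weight matrices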
         
    	
        rl_algo_impls/shared/callbacks/eval_callback.py
    CHANGED
    
@@ -1,14 +1,15 @@
 import itertools
-import numpy as np
 import os
-
 from time import perf_counter
+from typing import Dict, List, Optional, Union
+
+import numpy as np
 from torch.utils.tensorboard.writer import SummaryWriter
-from typing import List, Optional, Union
 
 from rl_algo_impls.shared.callbacks.callback import Callback
 from rl_algo_impls.shared.policy.policy import Policy
 from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
+from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
 from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
 from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
 
@@ -20,6 +21,7 @@ class EvaluateAccumulator(EpisodeAccumulator):
         goal_episodes: int,
         print_returns: bool = True,
         ignore_first_episode: bool = False,
+        additional_keys_to_log: Optional[List[str]] = None,
     ):
         super().__init__(num_envs)
         self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
@@ -36,8 +38,11 @@ class EvaluateAccumulator(EpisodeAccumulator):
             self.should_record_done = should_record_done
         else:
             self.should_record_done = lambda idx: True
+        self.additional_keys_to_log = additional_keys_to_log
 
-    def on_done(self, ep_idx: int, episode: Episode) -> None:
+    def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
+        if self.additional_keys_to_log:
+            episode.info = {k: info[k] for k in self.additional_keys_to_log}
         if (
             self.should_record_done(ep_idx)
             and len(self.completed_episodes_by_env_idx[ep_idx])
@@ -74,19 +79,29 @@ def evaluate(
     deterministic: bool = True,
     print_returns: bool = True,
     ignore_first_episode: bool = False,
+    additional_keys_to_log: Optional[List[str]] = None,
 ) -> EpisodesStats:
     policy.sync_normalization(env)
     policy.eval()
 
     episodes = EvaluateAccumulator(
-        env.num_envs, n_episodes, print_returns, ignore_first_episode
+        env.num_envs,
+        n_episodes,
+        print_returns,
+        ignore_first_episode,
+        additional_keys_to_log=additional_keys_to_log,
     )
 
     obs = env.reset()
+    action_masker = find_action_masker(env)
     while not episodes.is_done():
-        act = policy.act(obs, deterministic=deterministic)
-        obs, rew, done, info = env.step(act)
-        episodes.step(rew, done, info)
+        act = policy.act(
+            obs,
+            deterministic=deterministic,
+            action_masks=action_masker.action_masks() if action_masker else None,
+        )
+        obs, rew, done, info = env.step(act)
+        episodes.step(rew, done, info)
         if render:
             env.render()
     stats = EpisodesStats(episodes.episodes)
@@ -111,6 +126,7 @@ class EvalCallback(Callback):
         best_video_dir: Optional[str] = None,
         max_video_length: int = 3600,
         ignore_first_episode: bool = False,
+        additional_keys_to_log: Optional[List[str]] = None,
     ) -> None:
         super().__init__()
         self.policy = policy
@@ -133,8 +149,8 @@ class EvalCallback(Callback):
             os.makedirs(best_video_dir, exist_ok=True)
         self.max_video_length = max_video_length
         self.best_video_base_path = None
-
         self.ignore_first_episode = ignore_first_episode
+        self.additional_keys_to_log = additional_keys_to_log
 
     def on_step(self, timesteps_elapsed: int = 1) -> bool:
         super().on_step(timesteps_elapsed)
@@ -153,6 +169,7 @@ class EvalCallback(Callback):
             deterministic=self.deterministic,
            print_returns=print_returns or False,
             ignore_first_episode=self.ignore_first_episode,
+            additional_keys_to_log=self.additional_keys_to_log,
         )
         end_time = perf_counter()
         self.tb_writer.add_scalar(
         
     | 
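The evaluation loop now threads optional action masks from the vectorized env into `policy.act`. A minimal sketch of what such a mask does downstream, assuming a discrete action space (the `-1e8` fill and tensor shapes are illustrative, not the repo's exact implementation):

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([1.0, 2.0, 0.5, -1.0])
mask = torch.tensor([True, False, True, True])  # stand-in for action_masker.action_masks()
# Disallowed actions get an effectively -inf logit before sampling.
masked_logits = torch.where(mask, logits, torch.full_like(logits, -1e8))
dist = Categorical(logits=masked_logits)
action = dist.sample()  # index 1 (masked out) is sampled with ~zero probability
```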
    	
rl_algo_impls/shared/encoder/__init__.py ADDED
@@ -0,0 +1,2 @@
+from rl_algo_impls.shared.encoder.cnn import EncoderOutDim
+from rl_algo_impls.shared.encoder.encoder import Encoder
    	
rl_algo_impls/shared/encoder/cnn.py ADDED
@@ -0,0 +1,72 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple, Type, Union
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+
+from rl_algo_impls.shared.module.module import layer_init
+
+EncoderOutDim = Union[int, Tuple[int, ...]]
+
+
+class CnnEncoder(nn.Module, ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.range_size = np.max(obs_space.high) - np.min(obs_space.low)  # type: ignore
+
+    def preprocess(self, obs: torch.Tensor) -> torch.Tensor:
+        if len(obs.shape) == 3:
+            obs = obs.unsqueeze(0)
+        return obs.float() / self.range_size
+
+    def forward(self, obs: torch.Tensor) -> torch.Tensor:
+        return self.preprocess(obs)
+
+    @property
+    @abstractmethod
+    def out_dim(self) -> EncoderOutDim:
+        ...
+
+
+class FlattenedCnnEncoder(CnnEncoder):
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        activation: Type[nn.Module],
+        linear_init_layers_orthogonal: bool,
+        cnn_flatten_dim: int,
+        cnn: nn.Module,
+        **kwargs,
+    ) -> None:
+        super().__init__(obs_space, **kwargs)
+        self.cnn = cnn
+        self.flattened_dim = cnn_flatten_dim
+        with torch.no_grad():
+            cnn_out = torch.flatten(
+                cnn(self.preprocess(torch.as_tensor(obs_space.sample()))), start_dim=1
+            )
+        self.fc = nn.Sequential(
+            nn.Flatten(),
+            layer_init(
+                nn.Linear(cnn_out.shape[1], cnn_flatten_dim),
+                linear_init_layers_orthogonal,
+            ),
+            activation(),
+        )
+
+    def forward(self, obs: torch.Tensor) -> torch.Tensor:
+        x = super().forward(obs)
+        x = self.cnn(x)
+        x = self.fc(x)
+        return x
+
+    @property
+    def out_dim(self) -> EncoderOutDim:
+        return self.flattened_dim
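A runnable sketch of the probe-then-size pattern `FlattenedCnnEncoder` uses: run a sample observation through the conv stack once, under `no_grad`, to size the linear head. The stand-in conv stack and 3x84x84 observation space below are assumptions for illustration; the repo's subclasses (`NatureCnn`, `ImpalaCnn`, `MicrortsCnn`) supply their own stacks.

```python
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Box

from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder

obs_space = Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
cnn = nn.Sequential(nn.Conv2d(3, 8, kernel_size=8, stride=4), nn.ReLU())  # stand-in stack
enc = FlattenedCnnEncoder(
    obs_space,
    activation=nn.ReLU,
    linear_init_layers_orthogonal=True,
    cnn_flatten_dim=128,
    cnn=cnn,
)
out = enc(torch.as_tensor(obs_space.sample()))
print(out.shape)  # torch.Size([1, 128]); preprocess added the batch dim and scaled by 1/255
```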
    	
rl_algo_impls/shared/encoder/encoder.py ADDED
@@ -0,0 +1,73 @@
+from typing import Dict, Optional, Sequence, Type
+
+import gym
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from gym.spaces import Box, Discrete
+from stable_baselines3.common.preprocessing import get_flattened_obs_dim
+
+from rl_algo_impls.shared.encoder.cnn import CnnEncoder
+from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
+from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
+from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
+from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
+from rl_algo_impls.shared.module.module import layer_init
+
+CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
+    "nature": NatureCnn,
+    "impala": ImpalaCnn,
+    "microrts": MicrortsCnn,
+    "gridnet_encoder": GridnetEncoder,
+}
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        activation: Type[nn.Module],
+        init_layers_orthogonal: bool = False,
+        cnn_flatten_dim: int = 512,
+        cnn_style: str = "nature",
+        cnn_layers_init_orthogonal: Optional[bool] = None,
+        impala_channels: Sequence[int] = (16, 32, 32),
+    ) -> None:
+        super().__init__()
+        if isinstance(obs_space, Box):
+            # Conv2D: (channels, height, width)
+            if len(obs_space.shape) == 3:  # type: ignore
+                self.preprocess = None
+                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
+                    obs_space,
+                    activation=activation,
+                    cnn_init_layers_orthogonal=cnn_layers_init_orthogonal,
+                    linear_init_layers_orthogonal=init_layers_orthogonal,
+                    cnn_flatten_dim=cnn_flatten_dim,
+                    impala_channels=impala_channels,
+                )
+                self.feature_extractor = cnn
+                self.out_dim = cnn.out_dim
+            elif len(obs_space.shape) == 1:  # type: ignore
+
+                def preprocess(obs: torch.Tensor) -> torch.Tensor:
+                    if len(obs.shape) == 1:
+                        obs = obs.unsqueeze(0)
+                    return obs.float()
+
+                self.preprocess = preprocess
+                self.feature_extractor = nn.Flatten()
+                self.out_dim = get_flattened_obs_dim(obs_space)
+            else:
+                raise ValueError(f"Unsupported observation space: {obs_space}")
+        elif isinstance(obs_space, Discrete):
+            self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
+            self.feature_extractor = nn.Flatten()
+            self.out_dim = obs_space.n  # type: ignore
+        else:
+            raise NotImplementedError
+
+    def forward(self, obs: torch.Tensor) -> torch.Tensor:
+        if self.preprocess:
+            obs = self.preprocess(obs)
+        return self.feature_extractor(obs)
     | 
    	
        rl_algo_impls/shared/encoder/gridnet_encoder.py
    ADDED
    
    | 
         @@ -0,0 +1,64 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Optional, Tuple, Type, Union
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import gym
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import torch.nn as nn
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
         
     | 
| 8 | 
         
            +
            from rl_algo_impls.shared.module.module import layer_init
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            class GridnetEncoder(CnnEncoder):
         
     | 
| 12 | 
         
            +
                """
         
     | 
| 13 | 
         
            +
                Encoder for encoder-decoder for Gym-MicroRTS
         
     | 
| 14 | 
         
            +
                """
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                def __init__(
         
     | 
| 17 | 
         
            +
                    self,
         
     | 
| 18 | 
         
            +
                    obs_space: gym.Space,
         
     | 
| 19 | 
         
            +
                    activation: Type[nn.Module] = nn.ReLU,
         
     | 
| 20 | 
         
            +
                    cnn_init_layers_orthogonal: Optional[bool] = None,
         
     | 
| 21 | 
         
            +
                    **kwargs
         
     | 
| 22 | 
         
            +
                ) -> None:
         
     | 
| 23 | 
         
            +
                    if cnn_init_layers_orthogonal is None:
         
     | 
| 24 | 
         
            +
                        cnn_init_layers_orthogonal = True
         
     | 
| 25 | 
         
            +
                    super().__init__(obs_space, **kwargs)
         
     | 
| 26 | 
         
            +
                    in_channels = obs_space.shape[0]  # type: ignore
         
     | 
| 27 | 
         
            +
                    self.encoder = nn.Sequential(
         
     | 
| 28 | 
         
            +
                        layer_init(
         
     | 
| 29 | 
         
            +
                            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
         
     | 
| 30 | 
         
            +
                            cnn_init_layers_orthogonal,
         
     | 
| 31 | 
         
            +
                        ),
         
     | 
| 32 | 
         
            +
                        nn.MaxPool2d(3, stride=2, padding=1),
         
     | 
| 33 | 
         
            +
                        activation(),
         
     | 
| 34 | 
         
            +
                        layer_init(
         
     | 
| 35 | 
         
            +
                            nn.Conv2d(32, 64, kernel_size=3, padding=1),
         
     | 
| 36 | 
         
            +
                            cnn_init_layers_orthogonal,
         
     | 
| 37 | 
         
            +
                        ),
         
     | 
| 38 | 
         
            +
                        nn.MaxPool2d(3, stride=2, padding=1),
         
     | 
| 39 | 
         
            +
                        activation(),
         
     | 
| 40 | 
         
            +
                        layer_init(
         
     | 
| 41 | 
         
            +
                            nn.Conv2d(64, 128, kernel_size=3, padding=1),
         
     | 
| 42 | 
         
            +
                            cnn_init_layers_orthogonal,
         
     | 
| 43 | 
         
            +
                        ),
         
     | 
| 44 | 
         
            +
                        nn.MaxPool2d(3, stride=2, padding=1),
         
     | 
| 45 | 
         
            +
                        activation(),
         
     | 
| 46 | 
         
            +
                        layer_init(
         
     | 
| 47 | 
         
            +
                            nn.Conv2d(128, 256, kernel_size=3, padding=1),
         
     | 
| 48 | 
         
            +
                            cnn_init_layers_orthogonal,
         
     | 
| 49 | 
         
            +
                        ),
         
     | 
| 50 | 
         
            +
                        nn.MaxPool2d(3, stride=2, padding=1),
         
     | 
| 51 | 
         
            +
                        activation(),
         
     | 
| 52 | 
         
            +
                    )
         
     | 
| 53 | 
         
            +
                    with torch.no_grad():
         
     | 
| 54 | 
         
            +
                        encoder_out = self.encoder(
         
     | 
| 55 | 
         
            +
                            self.preprocess(torch.as_tensor(obs_space.sample()))  # type: ignore
         
     | 
| 56 | 
         
            +
                        )
         
     | 
| 57 | 
         
            +
                        self._out_dim = encoder_out.shape[1:]
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                def forward(self, obs: torch.Tensor) -> torch.Tensor:
         
     | 
| 60 | 
         
            +
                    return self.encoder(super().forward(obs))
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                @property
         
     | 
| 63 | 
         
            +
                def out_dim(self) -> EncoderOutDim:
         
     | 
| 64 | 
         
            +
                    return self._out_dim
         
     | 
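Each of the four stride-2 max-pools roughly halves the spatial size, so `out_dim` depends on the map size. A sketch with a hypothetical 27x16x16 observation space (16 -> 8 -> 4 -> 2 -> 1):

```python
import numpy as np
import torch
from gym.spaces import Box

from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder

obs_space = Box(low=0, high=1, shape=(27, 16, 16), dtype=np.float32)
enc = GridnetEncoder(obs_space)
print(enc.out_dim)  # torch.Size([256, 1, 1]), probed from a sample observation
out = enc(torch.as_tensor(obs_space.sample()))
print(out.shape)    # torch.Size([1, 256, 1, 1])
```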
    	
rl_algo_impls/shared/encoder/impala_cnn.py ADDED
@@ -0,0 +1,92 @@
+from typing import Optional, Sequence, Type
+
+import gym
+import torch
+import torch.nn as nn
+
+from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
+from rl_algo_impls.shared.module.module import layer_init
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        init_layers_orthogonal: bool = False,
+    ) -> None:
+        super().__init__()
+        self.residual = nn.Sequential(
+            activation(),
+            layer_init(
+                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
+            ),
+            activation(),
+            layer_init(
+                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.residual(x)
+
+
+class ConvSequence(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        init_layers_orthogonal: bool = False,
+    ) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            layer_init(
+                nn.Conv2d(in_channels, out_channels, 3, padding=1),
+                init_layers_orthogonal,
+            ),
+            nn.MaxPool2d(3, stride=2, padding=1),
+            ResidualBlock(out_channels, activation, init_layers_orthogonal),
+            ResidualBlock(out_channels, activation, init_layers_orthogonal),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.seq(x)
+
+
+class ImpalaCnn(FlattenedCnnEncoder):
+    """
+    IMPALA-style CNN architecture
+    """
+
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        activation: Type[nn.Module],
+        cnn_init_layers_orthogonal: Optional[bool],
+        linear_init_layers_orthogonal: bool,
+        cnn_flatten_dim: int,
+        impala_channels: Sequence[int] = (16, 32, 32),
+        **kwargs,
+    ) -> None:
+        if cnn_init_layers_orthogonal is None:
+            cnn_init_layers_orthogonal = False
+        in_channels = obs_space.shape[0]  # type: ignore
+        sequences = []
+        for out_channels in impala_channels:
+            sequences.append(
+                ConvSequence(
+                    in_channels, out_channels, activation, cnn_init_layers_orthogonal
+                )
+            )
+            in_channels = out_channels
+        sequences.append(activation())
+        cnn = nn.Sequential(*sequences)
+        super().__init__(
+            obs_space,
+            activation,
+            linear_init_layers_orthogonal,
+            cnn_flatten_dim,
+            cnn,
+            **kwargs,
+        )
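The three default `ConvSequence` stages each max-pool by 2, so a hypothetical 3x64x64 input reaches the flatten at 32x8x8 = 2048 features, which the inherited probe wires to the linear head. A sketch under those assumed shapes:

```python
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Box

from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn

obs_space = Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8)
# 64 -> 32 -> 16 -> 8 spatially across the three (16, 32, 32)-channel sequences
enc = ImpalaCnn(obs_space, nn.ReLU, None, True, cnn_flatten_dim=256)
print(enc(torch.as_tensor(obs_space.sample())).shape)  # torch.Size([1, 256])
```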
    	
rl_algo_impls/shared/encoder/microrts_cnn.py ADDED
@@ -0,0 +1,45 @@
+from typing import Optional, Type
+
+import gym
+import torch
+import torch.nn as nn
+
+from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
+from rl_algo_impls.shared.module.module import layer_init
+
+
+class MicrortsCnn(FlattenedCnnEncoder):
+    """
+    Base CNN architecture for Gym-MicroRTS
+    """
+
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        activation: Type[nn.Module],
+        cnn_init_layers_orthogonal: Optional[bool],
+        linear_init_layers_orthogonal: bool,
+        cnn_flatten_dim: int,
+        **kwargs,
+    ) -> None:
+        if cnn_init_layers_orthogonal is None:
+            cnn_init_layers_orthogonal = True
+        in_channels = obs_space.shape[0]  # type: ignore
+        cnn = nn.Sequential(
+            layer_init(
+                nn.Conv2d(in_channels, 16, kernel_size=3, stride=2),
+                cnn_init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(nn.Conv2d(16, 32, kernel_size=2), cnn_init_layers_orthogonal),
+            activation(),
+            nn.Flatten(),
+        )
+        super().__init__(
+            obs_space,
+            activation,
+            linear_init_layers_orthogonal,
+            cnn_flatten_dim,
+            cnn,
+            **kwargs,
+        )
    	
rl_algo_impls/shared/encoder/nature_cnn.py ADDED
@@ -0,0 +1,53 @@
+from typing import Optional, Type
+
+import gym
+import torch.nn as nn
+
+from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
+from rl_algo_impls.shared.module.module import layer_init
+
+
+class NatureCnn(FlattenedCnnEncoder):
+    """
+    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
+    "Human-level control through deep reinforcement learning."
+    Nature 518.7540 (2015): 529-533.
+    """
+
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        activation: Type[nn.Module],
+        cnn_init_layers_orthogonal: Optional[bool],
+        linear_init_layers_orthogonal: bool,
+        cnn_flatten_dim: int,
+        **kwargs,
+    ) -> None:
+        if cnn_init_layers_orthogonal is None:
+            cnn_init_layers_orthogonal = True
+        in_channels = obs_space.shape[0]  # type: ignore
+        cnn = nn.Sequential(
+            layer_init(
+                nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
+                cnn_init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(
+                nn.Conv2d(32, 64, kernel_size=4, stride=2),
+                cnn_init_layers_orthogonal,
+            ),
+            activation(),
+            layer_init(
+                nn.Conv2d(64, 64, kernel_size=3, stride=1),
+                cnn_init_layers_orthogonal,
+            ),
+            activation(),
+        )
+        super().__init__(
+            obs_space,
+            activation,
+            linear_init_layers_orthogonal,
+            cnn_flatten_dim,
+            cnn,
+            **kwargs,
+        )
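A sketch under the standard Atari assumption of 4 stacked 84x84 frames (not enforced by the class); the comments trace the conv arithmetic that the sample-observation probe resolves automatically:

```python
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Box

from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn

# Spatial sizes: 84 -> (84-8)/4+1 = 20 -> (20-4)/2+1 = 9 -> (9-3)/1+1 = 7,
# so the hidden Linear is sized at 64*7*7 = 3136 inputs by the probe.
obs_space = Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
enc = NatureCnn(obs_space, nn.ReLU, None, True, cnn_flatten_dim=512)
print(enc(torch.as_tensor(obs_space.sample())).shape)  # torch.Size([1, 512])
```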
    	
rl_algo_impls/shared/gae.py CHANGED
@@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence
 
 from rl_algo_impls.shared.policy.on_policy import OnPolicy
 from rl_algo_impls.shared.trajectory import Trajectory
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
 
 
 class RtgAdvantage(NamedTuple):
@@ -19,7 +20,7 @@ def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
     return dc
 
 
-def compute_advantage(
+def compute_advantage_from_trajectories(
     trajectories: Sequence[Trajectory],
     policy: OnPolicy,
     gamma: float,
@@ -40,7 +41,7 @@ def compute_advantage(
     )
 
 
-def compute_rtg_and_advantage(
+def compute_rtg_and_advantage_from_trajectories(
    trajectories: Sequence[Trajectory],
    policy: OnPolicy,
    gamma: float,
@@ -65,3 +66,29 @@ def compute_rtg_and_advantage(
         ),
         torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
     )
+
+
+def compute_advantages(
+    rewards: np.ndarray,
+    values: np.ndarray,
+    episode_starts: np.ndarray,
+    next_episode_starts: np.ndarray,
+    next_obs: VecEnvObs,
+    policy: OnPolicy,
+    gamma: float,
+    gae_lambda: float,
+) -> np.ndarray:
+    advantages = np.zeros_like(rewards)
+    last_gae_lam = 0
+    n_steps = advantages.shape[0]
+    for t in reversed(range(n_steps)):
+        if t == n_steps - 1:
+            next_nonterminal = 1.0 - next_episode_starts
+            next_value = policy.value(next_obs)
+        else:
+            next_nonterminal = 1.0 - episode_starts[t + 1]
+            next_value = values[t + 1]
+        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
+        last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
+        advantages[t] = last_gae_lam
+    return advantages
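The new `compute_advantages` is the buffer-based GAE recursion: delta_t = r_t + gamma * V(s_{t+1}) * nonterminal - V(s_t), accumulated backward with gamma * lambda. A toy check with hand-picked numbers and a stub policy (all values below are illustrative; the stub ignores `next_obs`):

```python
import numpy as np

from rl_algo_impls.shared.gae import compute_advantages

class StubPolicy:
    def value(self, obs):
        return np.array([0.5])  # V of the observation after the last step

rewards = np.array([[1.0], [1.0]])         # (n_steps=2, n_envs=1)
values = np.array([[0.4], [0.45]])
episode_starts = np.array([[1.0], [0.0]])  # step 0 begins an episode
next_episode_starts = np.array([0.0])      # rollout did not end mid-episode
adv = compute_advantages(
    rewards, values, episode_starts, next_episode_starts,
    next_obs=None, policy=StubPolicy(), gamma=0.99, gae_lambda=0.95,
)
# t=1: delta = 1.0 + 0.99*0.5 - 0.45 = 1.045
# t=0: delta = 1.0 + 0.99*0.45 - 0.4 = 1.0455; A_0 = 1.0455 + 0.9405*1.045 ~= 2.0283
print(adv)  # [[~2.0283], [1.045]]
```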
    	
        rl_algo_impls/shared/module/feature_extractor.py
    DELETED
    
    | 
         @@ -1,215 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            import gym
         
     | 
| 2 | 
         
            -
            import torch
         
     | 
| 3 | 
         
            -
            import torch.nn as nn
         
     | 
| 4 | 
         
            -
            import torch.nn.functional as F
         
     | 
| 5 | 
         
            -
             
     | 
| 6 | 
         
            -
            from abc import ABC, abstractmethod
         
     | 
| 7 | 
         
            -
            from gym.spaces import Box, Discrete
         
     | 
| 8 | 
         
            -
            from stable_baselines3.common.preprocessing import get_flattened_obs_dim
         
     | 
| 9 | 
         
            -
            from typing import Dict, Optional, Sequence, Type
         
     | 
| 10 | 
         
            -
             
     | 
| 11 | 
         
            -
            from rl_algo_impls.shared.module.module import layer_init
         
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
             
     | 
| 14 | 
         
-class CnnFeatureExtractor(nn.Module, ABC):
-    @abstractmethod
-    def __init__(
-        self,
-        in_channels: int,
-        activation: Type[nn.Module] = nn.ReLU,
-        init_layers_orthogonal: Optional[bool] = None,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-
-
-class NatureCnn(CnnFeatureExtractor):
-    """
-    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
-    "Human-level control through deep reinforcement learning."
-    Nature 518.7540 (2015): 529-533.
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        activation: Type[nn.Module] = nn.ReLU,
-        init_layers_orthogonal: Optional[bool] = None,
-        **kwargs,
-    ) -> None:
-        if init_layers_orthogonal is None:
-            init_layers_orthogonal = True
-        super().__init__(in_channels, activation, init_layers_orthogonal)
-        self.cnn = nn.Sequential(
-            layer_init(
-                nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
-                init_layers_orthogonal,
-            ),
-            activation(),
-            layer_init(
-                nn.Conv2d(32, 64, kernel_size=4, stride=2),
-                init_layers_orthogonal,
-            ),
-            activation(),
-            layer_init(
-                nn.Conv2d(64, 64, kernel_size=3, stride=1),
-                init_layers_orthogonal,
-            ),
-            activation(),
-            nn.Flatten(),
-        )
-
-    def forward(self, obs: torch.Tensor) -> torch.Tensor:
-        return self.cnn(obs)
-
-
-class ResidualBlock(nn.Module):
-    def __init__(
-        self,
-        channels: int,
-        activation: Type[nn.Module] = nn.ReLU,
-        init_layers_orthogonal: bool = False,
-    ) -> None:
-        super().__init__()
-        self.residual = nn.Sequential(
-            activation(),
-            layer_init(
-                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
-            ),
-            activation(),
-            layer_init(
-                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
-            ),
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x + self.residual(x)
-
-
-class ConvSequence(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        activation: Type[nn.Module] = nn.ReLU,
-        init_layers_orthogonal: bool = False,
-    ) -> None:
-        super().__init__()
-        self.seq = nn.Sequential(
-            layer_init(
-                nn.Conv2d(in_channels, out_channels, 3, padding=1),
-                init_layers_orthogonal,
-            ),
-            nn.MaxPool2d(3, stride=2, padding=1),
-            ResidualBlock(out_channels, activation, init_layers_orthogonal),
-            ResidualBlock(out_channels, activation, init_layers_orthogonal),
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.seq(x)
-
-
-class ImpalaCnn(CnnFeatureExtractor):
-    """
-    IMPALA-style CNN architecture
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        activation: Type[nn.Module] = nn.ReLU,
-        init_layers_orthogonal: Optional[bool] = None,
-        impala_channels: Sequence[int] = (16, 32, 32),
-        **kwargs,
-    ) -> None:
-        if init_layers_orthogonal is None:
-            init_layers_orthogonal = False
-        super().__init__(in_channels, activation, init_layers_orthogonal)
-        sequences = []
-        for out_channels in impala_channels:
-            sequences.append(
-                ConvSequence(
-                    in_channels, out_channels, activation, init_layers_orthogonal
-                )
-            )
-            in_channels = out_channels
-        sequences.extend(
-            [
-                activation(),
-                nn.Flatten(),
-            ]
-        )
-        self.seq = nn.Sequential(*sequences)
-
-    def forward(self, obs: torch.Tensor) -> torch.Tensor:
-        return self.seq(obs)
-
-
-CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
-    "nature": NatureCnn,
-    "impala": ImpalaCnn,
-}
-
-
-class FeatureExtractor(nn.Module):
-    def __init__(
-        self,
-        obs_space: gym.Space,
-        activation: Type[nn.Module],
-        init_layers_orthogonal: bool = False,
-        cnn_feature_dim: int = 512,
-        cnn_style: str = "nature",
-        cnn_layers_init_orthogonal: Optional[bool] = None,
-        impala_channels: Sequence[int] = (16, 32, 32),
-    ) -> None:
-        super().__init__()
-        if isinstance(obs_space, Box):
-            # Conv2D: (channels, height, width)
-            if len(obs_space.shape) == 3:
-                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
-                    obs_space.shape[0],
-                    activation,
-                    init_layers_orthogonal=cnn_layers_init_orthogonal,
-                    impala_channels=impala_channels,
-                )
-
-                def preprocess(obs: torch.Tensor) -> torch.Tensor:
-                    if len(obs.shape) == 3:
-                        obs = obs.unsqueeze(0)
-                    return obs.float() / 255.0
-
-                with torch.no_grad():
-                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
-                self.preprocess = preprocess
-                self.feature_extractor = nn.Sequential(
-                    cnn,
-                    layer_init(
-                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
-                        init_layers_orthogonal,
-                    ),
-                    activation(),
-                )
-                self.out_dim = cnn_feature_dim
-            elif len(obs_space.shape) == 1:
-
-                def preprocess(obs: torch.Tensor) -> torch.Tensor:
-                    if len(obs.shape) == 1:
-                        obs = obs.unsqueeze(0)
-                    return obs.float()
-
-                self.preprocess = preprocess
-                self.feature_extractor = nn.Flatten()
-                self.out_dim = get_flattened_obs_dim(obs_space)
-            else:
-                raise ValueError(f"Unsupported observation space: {obs_space}")
-        elif isinstance(obs_space, Discrete):
-            self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
-            self.feature_extractor = nn.Flatten()
-            self.out_dim = obs_space.n
-        else:
-            raise NotImplementedError
-
-    def forward(self, obs: torch.Tensor) -> torch.Tensor:
-        if self.preprocess:
-            obs = self.preprocess(obs)
-        return self.feature_extractor(obs)
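For orientation, the deleted NatureCnn's shape arithmetic on a 4x84x84 Atari-style observation, as a self-contained sketch (the 4-channel input is an illustrative assumption, not something fixed by this commit):

import torch
import torch.nn as nn

# Spatial sizes through the Nature CNN (84x84 input, formula (W - K) // S + 1):
# 8x8 stride 4: (84 - 8) // 4 + 1 = 20 -> 32 x 20 x 20
# 4x4 stride 2: (20 - 4) // 2 + 1 = 9  -> 64 x 9 x 9
# 3x3 stride 1: (9 - 3) // 1 + 1 = 7   -> 64 x 7 x 7
cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Flatten(),
)
out = cnn(torch.zeros(1, 4, 84, 84))
assert out.shape == (1, 64 * 7 * 7)  # 3136 features before the Linear projection

FeatureExtractor then projected those 3136 features to cnn_feature_dim (512 by default); the rl_algo_impls/shared/encoder modules added in this commit appear to take over that role.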
    	
        rl_algo_impls/shared/module/module.py
    CHANGED
    
@@ -1,8 +1,8 @@
+from typing import Sequence, Type
+
 import numpy as np
 import torch.nn as nn
 
-from typing import Sequence, Type
-
 
 def mlp(
     layer_sizes: Sequence[int],
@@ -10,12 +10,15 @@ def mlp(
     output_activation: Type[nn.Module] = nn.Identity,
     init_layers_orthogonal: bool = False,
     final_layer_gain: float = np.sqrt(2),
+    hidden_layer_gain: float = np.sqrt(2),
 ) -> nn.Module:
     layers = []
     for i in range(len(layer_sizes) - 2):
         layers.append(
             layer_init(
-                nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
+                nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
+                init_layers_orthogonal,
+                std=hidden_layer_gain,
             )
         )
         layers.append(activation())
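The new hidden_layer_gain lands in layer_init's std argument, so hidden layers and the final layer can now be scaled independently. A minimal sketch of the orthogonal-init convention this implies (an illustrative layer_init, not the repo's exact body):

import numpy as np
import torch.nn as nn

def layer_init(
    layer: nn.Linear, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Linear:
    # Assumed sketch: orthogonal weights scaled by `std`, zero bias.
    if init_layers_orthogonal:
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.constant_(layer.bias, 0.0)
    return layer

CriticHead below passes hidden_layer_gain=1.0 alongside final_layer_gain=1.0, i.e. unit-gain orthogonal init throughout the value head.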
    	
        rl_algo_impls/shared/policy/critic.py
    CHANGED
    
@@ -1,27 +1,39 @@
-import …
+from typing import Sequence, Type
+
+import numpy as np
 import torch
 import torch.nn as nn
 
-from …
-
+from rl_algo_impls.shared.encoder import EncoderOutDim
 from rl_algo_impls.shared.module.module import mlp
 
 
 class CriticHead(nn.Module):
     def __init__(
         self,
-        …
+        in_dim: EncoderOutDim,
+        hidden_sizes: Sequence[int] = (),
         activation: Type[nn.Module] = nn.Tanh,
         init_layers_orthogonal: bool = True,
     ) -> None:
         super().__init__()
-        …
+        seq = []
+        if isinstance(in_dim, tuple):
+            seq.append(nn.Flatten())
+            in_channels = int(np.prod(in_dim))
+        else:
+            in_channels = in_dim
+        layer_sizes = (in_channels,) + tuple(hidden_sizes) + (1,)
+        seq.append(
+            mlp(
+                layer_sizes,
+                activation,
+                init_layers_orthogonal=init_layers_orthogonal,
+                final_layer_gain=1.0,
+                hidden_layer_gain=1.0,
+            )
         )
+        self._fc = nn.Sequential(*seq)
 
     def forward(self, obs: torch.Tensor) -> torch.Tensor:
         v = self._fc(obs)
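A usage sketch for the reworked CriticHead, assuming EncoderOutDim is an int-or-tuple alias (which the isinstance check implies): a tuple in_dim such as an unflattened 64x7x7 CNN map is flattened to 64 * 7 * 7 = 3136 MLP inputs, while an int in_dim feeds the MLP directly.

import torch
from rl_algo_impls.shared.policy.critic import CriticHead

# Tuple in_dim: nn.Flatten() runs first, so the MLP sees 64 * 7 * 7 = 3136 inputs.
cnn_head = CriticHead(in_dim=(64, 7, 7), hidden_sizes=(256,))
v = cnn_head(torch.zeros(8, 64, 7, 7))  # one value per batch element

# Int in_dim: the flat encoder output feeds the MLP as-is.
flat_head = CriticHead(in_dim=512, hidden_sizes=(64, 64))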
    	
        rl_algo_impls/shared/policy/on_policy.py
    CHANGED
    
@@ -1,24 +1,20 @@
+from abc import abstractmethod
+from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
+
 import gym
 import numpy as np
 import torch
-
-from abc import abstractmethod
 from gym.spaces import Box, Discrete, Space
-from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
 
-from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
-from rl_algo_impls.shared.policy.actor import (
-    PiForward,
-    StateDependentNoiseActorHead,
-    actor_head,
-)
+from rl_algo_impls.shared.actor import PiForward, actor_head
+from rl_algo_impls.shared.encoder import Encoder
 from rl_algo_impls.shared.policy.critic import CriticHead
 from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
 from rl_algo_impls.wrappers.vectorable_wrapper import (
     VecEnv,
     VecEnvObs,
-    single_observation_space,
     single_action_space,
+    single_observation_space,
 )
 
 
@@ -77,7 +73,12 @@ class OnPolicy(Policy):
         ...
 
     @abstractmethod
-    def step(self, obs: VecEnvObs) -> Step:
+    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
+        ...
+
+    @property
+    @abstractmethod
+    def action_shape(self) -> Tuple[int, ...]:
         ...
 
 
@@ -94,10 +95,11 @@ class ActorCritic(OnPolicy):
         full_std: bool = True,
         squash_output: bool = False,
         share_features_extractor: bool = True,
-        cnn_feature_dim: int = 512,
+        cnn_flatten_dim: int = 512,
         cnn_style: str = "nature",
         cnn_layers_init_orthogonal: Optional[bool] = None,
         impala_channels: Sequence[int] = (16, 32, 32),
+        actor_head_style: str = "single",
         **kwargs,
     ) -> None:
         super().__init__(env, **kwargs)
@@ -120,52 +122,56 @@ class ActorCritic(OnPolicy):
         self.action_space = action_space
         self.squash_output = squash_output
         self.share_features_extractor = share_features_extractor
-        self._feature_extractor = FeatureExtractor(
+        self._feature_extractor = Encoder(
             observation_space,
             activation,
             init_layers_orthogonal=init_layers_orthogonal,
-            cnn_feature_dim=cnn_feature_dim,
+            cnn_flatten_dim=cnn_flatten_dim,
             cnn_style=cnn_style,
             cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
             impala_channels=impala_channels,
         )
         self._pi = actor_head(
             self.action_space,
-            …
+            self._feature_extractor.out_dim,
+            tuple(pi_hidden_sizes),
             init_layers_orthogonal,
             activation,
             log_std_init=log_std_init,
             use_sde=use_sde,
             full_std=full_std,
             squash_output=squash_output,
+            actor_head_style=actor_head_style,
         )
 
         if not share_features_extractor:
-            self._v_feature_extractor = FeatureExtractor(
+            self._v_feature_extractor = Encoder(
                 observation_space,
                 activation,
                 init_layers_orthogonal=init_layers_orthogonal,
-                cnn_feature_dim=cnn_feature_dim,
+                cnn_flatten_dim=cnn_flatten_dim,
                 cnn_style=cnn_style,
                 cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
             )
-            …
-                v_hidden_sizes
-            )
+            critic_in_dim = self._v_feature_extractor.out_dim
         else:
             self._v_feature_extractor = None
-        …
+            critic_in_dim = self._feature_extractor.out_dim
         self._v = CriticHead(
+            in_dim=critic_in_dim,
             hidden_sizes=v_hidden_sizes,
             activation=activation,
             init_layers_orthogonal=init_layers_orthogonal,
         )
 
     def _pi_forward(
-        self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
+        self,
+        obs: torch.Tensor,
+        action_masks: Optional[torch.Tensor],
+        action: Optional[torch.Tensor] = None,
     ) -> Tuple[PiForward, torch.Tensor]:
         p_fe = self._feature_extractor(obs)
-        pi_forward = self._pi(p_fe, action)
+        pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks)
 
         return pi_forward, p_fe
 
@@ -173,8 +179,13 @@ class ActorCritic(OnPolicy):
         v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
         return self._v(v_fe)
 
-    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
-        (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
+    def forward(
+        self,
+        obs: torch.Tensor,
+        action: torch.Tensor,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> ACForward:
+        (_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action)
         v = self._v_forward(obs, p_fc)
 
         assert logp_a is not None
@@ -192,10 +203,11 @@ class ActorCritic(OnPolicy):
             v = self._v(fe)
         return v.cpu().numpy()
 
-    def step(self, obs: VecEnvObs) -> Step:
+    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
         o = self._as_tensor(obs)
+        a_masks = self._as_tensor(action_masks) if action_masks is not None else None
         with torch.no_grad():
-            (pi, _, _), p_fc = self._pi_forward(o)
+            (pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks)
             a = pi.sample()
             logp_a = pi.log_prob(a)
@@ -205,13 +217,21 @@ class ActorCritic(OnPolicy):
         clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
         return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
 
-    def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
+    def act(
+        self,
+        obs: np.ndarray,
+        deterministic: bool = True,
+        action_masks: Optional[np.ndarray] = None,
+    ) -> np.ndarray:
         if not deterministic:
-            return self.step(obs).clamped_a
+            return self.step(obs, action_masks=action_masks).clamped_a
         else:
             o = self._as_tensor(obs)
+            a_masks = (
+                self._as_tensor(action_masks) if action_masks is not None else None
+            )
             with torch.no_grad():
-                (pi, _, _), _ = self._pi_forward(o)
+                (pi, _, _), _ = self._pi_forward(o, action_masks=a_masks)
                 a = pi.mode
             return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
 
@@ -220,7 +240,10 @@ class ActorCritic(OnPolicy):
         self.reset_noise()
 
     def reset_noise(self, batch_size: Optional[int] = None) -> None:
-        if isinstance(self._pi, StateDependentNoiseActorHead):
-            self._pi.sample_weights(
-                batch_size=batch_size if batch_size else self.env.num_envs
-            )
+        self._pi.sample_weights(
+            batch_size=batch_size if batch_size else self.env.num_envs
+        )
+
+    @property
+    def action_shape(self) -> Tuple[int, ...]:
+        return self._pi.action_shape
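The new action_masks argument threads per-step validity masks from the environment through step/act and into the actor head. A minimal sketch of the usual discrete-action masking trick (an assumption about what the masked actor heads do, not code from this commit): invalid actions get a very negative logit, so they are never sampled and drop out of log-prob and entropy.

import torch
from torch.distributions import Categorical

def masked_categorical(
    logits: torch.Tensor, action_masks: torch.Tensor
) -> Categorical:
    # Where the mask is 0/False, replace the logit with a large negative number,
    # driving that action's probability to (numerically) zero.
    masked_logits = torch.where(
        action_masks.bool(), logits, torch.full_like(logits, -1e8)
    )
    return Categorical(logits=masked_logits)

pi = masked_categorical(
    torch.randn(2, 4), torch.tensor([[1, 1, 0, 1], [0, 1, 1, 0]])
)
a = pi.sample()  # only actions whose mask entry is 1 can be drawn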
    	
        rl_algo_impls/shared/policy/policy.py
    CHANGED
    
@@ -46,7 +46,12 @@ class Policy(nn.Module, ABC):
         return self
 
     @abstractmethod
-    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
+    def act(
+        self,
+        obs: VecEnvObs,
+        deterministic: bool = True,
+        action_masks: Optional[np.ndarray] = None,
+    ) -> np.ndarray:
         ...
 
     def save(self, path: str) -> None:
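Callers then pass masks through the same signature. A hypothetical evaluation rollout (the env's get_action_mask hook and the VecEnv step convention are stand-ins for illustration; only the act signature comes from this commit):

import numpy as np

def evaluate(policy, env, n_steps: int = 1000) -> float:
    total_reward = 0.0
    obs = env.reset()
    for _ in range(n_steps):
        # Hypothetical hook: a masked env would expose its current valid actions.
        action_masks = (
            env.get_action_mask() if hasattr(env, "get_action_mask") else None
        )
        action = policy.act(obs, deterministic=True, action_masks=action_masks)
        obs, reward, done, info = env.step(action)
        total_reward += float(np.sum(reward))
    return total_reward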
    	
rl_algo_impls/shared/schedule.py
CHANGED

@@ -20,10 +20,38 @@ def constant_schedule(val: float) -> Schedule:
     return lambda f: val
 
 
+def spike_schedule(
+    max_value: float,
+    start_fraction: float = 1e-2,
+    end_fraction: float = 1e-4,
+    peak_progress: float = 0.1,
+) -> Schedule:
+    assert 0 < peak_progress < 1
+
+    def func(progress_fraction: float) -> float:
+        if progress_fraction < peak_progress:
+            fraction = (
+                start_fraction
+                + (1 - start_fraction) * progress_fraction / peak_progress
+            )
+        else:
+            fraction = 1 + (end_fraction - 1) * (progress_fraction - peak_progress) / (
+                1 - peak_progress
+            )
+        return max_value * fraction
+
+    return func
+
+
 def schedule(name: str, start_val: float) -> Schedule:
     if name == "linear":
         return linear_schedule(start_val, 0)
-    return constant_schedule(start_val)
+    elif name == "none":
+        return constant_schedule(start_val)
+    elif name == "spike":
+        return spike_schedule(start_val)
+    else:
+        raise ValueError(f"Schedule {name} not supported")
 
 
 def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
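`spike_schedule` ramps linearly from `start_fraction * max_value` up to `max_value` at `peak_progress`, then decays linearly toward `end_fraction * max_value`. Evaluating the defaults makes the shape concrete (judging by the warmup branch, `progress_fraction` is the fraction of training elapsed, running 0 to 1):

```python
lr = spike_schedule(3e-4)  # defaults: start_fraction=1e-2, end_fraction=1e-4, peak at 0.1

print(lr(0.0))   # 3e-06      -> warmup starts at 1% of max_value
print(lr(0.05))  # 1.515e-04  -> halfway up the ramp
print(lr(0.1))   # 3e-04      -> full max_value at the spike
print(lr(1.0))   # 3e-08      -> decayed to 0.01% of max_value
```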
    	
rl_algo_impls/shared/stats.py
CHANGED

@@ -1,14 +1,17 @@
-import numpy as np
-
+import dataclasses
+from collections import defaultdict
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
+
+import numpy as np
 from torch.utils.tensorboard.writer import SummaryWriter
-from typing import Dict, List, Optional, Sequence, Union, TypeVar
 
 
 @dataclass
 class Episode:
     score: float = 0
     length: int = 0
+    info: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict)
 
 
 StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
@@ -75,12 +78,25 @@ class EpisodesStats:
     simple: bool
     score: Statistic
     length: Statistic
+    additional_stats: Dict[str, Statistic]
 
     def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
         self.episodes = episodes
         self.simple = simple
         self.score = Statistic(np.array([e.score for e in episodes]))
         self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
+        additional_values = defaultdict(list)
+        for e in self.episodes:
+            if e.info:
+                for k, v in e.info.items():
+                    if isinstance(v, dict):
+                        for k2, v2 in v.items():
+                            additional_values[f"{k}_{k2}"].append(v2)
+                    else:
+                        additional_values[k].append(v)
+        self.additional_stats = {
+            k: Statistic(np.array(values)) for k, values in additional_values.items()
+        }
 
     def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
         return self.score > o.score
@@ -118,6 +134,8 @@ class EpisodesStats:
                     "length": self.length.mean,
                 }
             )
+            for k, addl_stats in self.additional_stats.items():
+                stats[k] = addl_stats.mean
         for name, value in stats.items():
            tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
 
@@ -131,19 +149,19 @@ class EpisodeAccumulator:
     def episodes(self) -> List[Episode]:
         return self._episodes
 
-    def step(self, reward: np.ndarray, done: np.ndarray) -> None:
+    def step(self, reward: np.ndarray, done: np.ndarray, info: List[Dict]) -> None:
         for idx, current in enumerate(self.current_episodes):
             current.score += reward[idx]
             current.length += 1
             if done[idx]:
                 self._episodes.append(current)
                 self.current_episodes[idx] = Episode()
-                self.on_done(idx, current)
+                self.on_done(idx, current, info[idx])
 
     def __len__(self) -> int:
         return len(self.episodes)
 
-    def on_done(self, ep_idx: int, episode: Episode) -> None:
+    def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
         pass
 
     def stats(self) -> EpisodesStats:
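The new `info` field lets wrappers attach extra per-episode scalars, which `EpisodesStats` flattens into `additional_stats` (nested dict keys joined with an underscore) and reports alongside score and length. A quick sketch with made-up keys:

```python
episodes = [
    Episode(score=1.0, length=200, info={"microrts_stats": {"WinLoss": 1.0}}),
    Episode(score=0.5, length=180, info={"microrts_stats": {"WinLoss": -1.0}}),
]
stats = EpisodesStats(episodes)
# Nested dicts flatten to "<outer>_<inner>"; plain scalars keep their key as-is.
print(stats.additional_stats["microrts_stats_WinLoss"].mean)  # 0.0
```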
    	
rl_algo_impls/shared/vec_env/__init__.py
ADDED

@@ -0,0 +1 @@
+from rl_algo_impls.shared.vec_env.make_env import make_env, make_eval_env
    	
rl_algo_impls/shared/vec_env/make_env.py
ADDED

@@ -0,0 +1,66 @@
+from dataclasses import asdict
+from typing import Optional
+
+from torch.utils.tensorboard.writer import SummaryWriter
+
+from rl_algo_impls.runner.config import Config, EnvHyperparams
+from rl_algo_impls.shared.vec_env.microrts import make_microrts_env
+from rl_algo_impls.shared.vec_env.procgen import make_procgen_env
+from rl_algo_impls.shared.vec_env.vec_env import make_vec_env
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+
+
+def make_env(
+    config: Config,
+    hparams: EnvHyperparams,
+    training: bool = True,
+    render: bool = False,
+    normalize_load_path: Optional[str] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+) -> VecEnv:
+    if hparams.env_type == "procgen":
+        return make_procgen_env(
+            config,
+            hparams,
+            training=training,
+            render=render,
+            normalize_load_path=normalize_load_path,
+            tb_writer=tb_writer,
+        )
+    elif hparams.env_type in {"sb3vec", "gymvec"}:
+        return make_vec_env(
+            config,
+            hparams,
+            training=training,
+            render=render,
+            normalize_load_path=normalize_load_path,
+            tb_writer=tb_writer,
+        )
+    elif hparams.env_type == "microrts":
+        return make_microrts_env(
+            config,
+            hparams,
+            training=training,
+            render=render,
+            normalize_load_path=normalize_load_path,
+            tb_writer=tb_writer,
+        )
+    else:
+        raise ValueError(f"env_type {hparams.env_type} not supported")
+
+
+def make_eval_env(
+    config: Config,
+    hparams: EnvHyperparams,
+    override_n_envs: Optional[int] = None,
+    **kwargs,
+) -> VecEnv:
+    kwargs = kwargs.copy()
+    kwargs["training"] = False
+    if override_n_envs is not None:
+        hparams_kwargs = asdict(hparams)
+        hparams_kwargs["n_envs"] = override_n_envs
+        if override_n_envs == 1:
+            hparams_kwargs["vec_env_class"] = "sync"
+        hparams = EnvHyperparams(**hparams_kwargs)
+    return make_env(config, hparams, **kwargs)
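Call sites only need the two re-exported helpers. A rough wiring sketch, assuming `config` (with an `env_hyperparams` dict) and `tb_writer` already exist, mirroring how `runner/train.py` uses these:

```python
from rl_algo_impls.runner.config import EnvHyperparams
from rl_algo_impls.shared.vec_env import make_env, make_eval_env

env = make_env(config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer)
# Evaluation forces training=False; overriding to one env also switches to the
# sync vec-env class, which keeps rendering and video capture simple.
eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1)
```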
    	
rl_algo_impls/shared/vec_env/microrts.py
ADDED

@@ -0,0 +1,94 @@
+from dataclasses import astuple
+from typing import Optional
+
+import gym
+import numpy as np
+from torch.utils.tensorboard.writer import SummaryWriter
+
+from rl_algo_impls.runner.config import Config, EnvHyperparams
+from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
+from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
+from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
+from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
+from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+
+
+def make_microrts_env(
+    config: Config,
+    hparams: EnvHyperparams,
+    training: bool = True,
+    render: bool = False,
+    normalize_load_path: Optional[str] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+) -> VecEnv:
+    import gym_microrts
+    from gym_microrts import microrts_ai
+
+    from rl_algo_impls.shared.vec_env.microrts_compat import (
+        MicroRTSGridModeVecEnvCompat,
+    )
+
+    (
+        _,  # env_type
+        n_envs,
+        _,  # frame_stack
+        make_kwargs,
+        _,  # no_reward_timeout_steps
+        _,  # no_reward_fire_steps
+        _,  # vec_env_class
+        _,  # normalize
+        _,  # normalize_kwargs
+        rolling_length,
+        _,  # train_record_video
+        _,  # video_step_interval
+        _,  # initial_steps_to_truncate
+        _,  # clip_atari_rewards
+        _,  # normalize_type
+        _,  # mask_actions
+        bots,
+    ) = astuple(hparams)
+
+    seed = config.seed(training=training)
+
+    make_kwargs = make_kwargs or {}
+    if "num_selfplay_envs" not in make_kwargs:
+        make_kwargs["num_selfplay_envs"] = 0
+    if "num_bot_envs" not in make_kwargs:
+        make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"]
+    if "reward_weight" in make_kwargs:
+        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
+    if bots:
+        ai2s = []
+        for ai_name, n in bots.items():
+            for _ in range(n):
+                if len(ai2s) >= make_kwargs["num_bot_envs"]:
+                    break
+                ai = getattr(microrts_ai, ai_name)
+                assert ai, f"{ai_name} not in microrts_ai"
+                ai2s.append(ai)
+    else:
+        ai2s = [microrts_ai.randomAI for _ in range(make_kwargs["num_bot_envs"])]
+    make_kwargs["ai2s"] = ai2s
+    envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
+    envs = HwcToChwObservation(envs)
+    envs = IsVectorEnv(envs)
+    envs = MicrortsMaskWrapper(envs)
+
+    if seed is not None:
+        envs.action_space.seed(seed)
+        envs.observation_space.seed(seed)
+
+    envs = gym.wrappers.RecordEpisodeStatistics(envs)
+    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
+    if training:
+        assert tb_writer
+        envs = EpisodeStatsWriter(
+            envs,
+            tb_writer,
+            training=training,
+            rolling_length=rolling_length,
+            additional_keys_to_log=config.additional_keys_to_log,
+        )
+
+    return envs
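Note how opponents are resolved: `bots` maps `microrts_ai` attribute names to counts, filled in order until `num_bot_envs` is reached; with no `bots` entry, every bot env faces `randomAI`. A sketch with made-up counts:

```python
from gym_microrts import microrts_ai

bots = {"coacAI": 2, "randomBiasedAI": 2}  # hypothetical hyperparams
num_bot_envs = 4

ai2s = []
for ai_name, n in bots.items():
    for _ in range(n):
        if len(ai2s) >= num_bot_envs:
            break
        ai2s.append(getattr(microrts_ai, ai_name))
# ai2s == [coacAI, coacAI, randomBiasedAI, randomBiasedAI]: one opponent per bot env
```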
    	
rl_algo_impls/shared/vec_env/microrts_compat.py
ADDED

@@ -0,0 +1,49 @@
+from typing import TypeVar
+
+import numpy as np
+from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
+from jpype.types import JArray, JInt
+
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvStepReturn
+
+MicroRTSGridModeVecEnvCompatSelf = TypeVar(
+    "MicroRTSGridModeVecEnvCompatSelf", bound="MicroRTSGridModeVecEnvCompat"
+)
+
+
+class MicroRTSGridModeVecEnvCompat(MicroRTSGridModeVecEnv):
+    def step(self, action: np.ndarray) -> VecEnvStepReturn:
+        indexed_actions = np.concatenate(
+            [
+                np.expand_dims(
+                    np.stack(
+                        [np.arange(0, action.shape[1]) for i in range(self.num_envs)]
+                    ),
+                    axis=2,
+                ),
+                action,
+            ],
+            axis=2,
+        )
+        action_mask = np.array(self.vec_client.getMasks(0), dtype=np.bool8).reshape(
+            indexed_actions.shape[:-1] + (-1,)
+        )
+        valid_action_mask = action_mask[:, :, 0]
+        valid_actions_counts = valid_action_mask.sum(1)
+        valid_actions = indexed_actions[valid_action_mask]
+        valid_actions_idx = 0
+
+        all_valid_actions = []
+        for env_act_cnt in valid_actions_counts:
+            env_valid_actions = []
+            for _ in range(env_act_cnt):
+                env_valid_actions.append(JArray(JInt)(valid_actions[valid_actions_idx]))
+                valid_actions_idx += 1
+            all_valid_actions.append(JArray(JArray(JInt))(env_valid_actions))
+        return super().step(JArray(JArray(JArray(JInt)))(all_valid_actions))  # type: ignore
+
+    @property
+    def unwrapped(
+        self: MicroRTSGridModeVecEnvCompatSelf,
+    ) -> MicroRTSGridModeVecEnvCompatSelf:
+        return self
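The shim exists because the Java vec client wants, per env, a ragged list of actions only for cells that can act, each prefixed with the cell's flat grid index. A shape-only sketch of the index-prepending step (sizes are made up):

```python
import numpy as np

num_envs, num_cells, action_plane = 2, 16, 7  # e.g. a 4x4 map
action = np.zeros((num_envs, num_cells, action_plane), dtype=np.int64)

# Prepend each cell's flat index along the last axis:
cell_index = np.expand_dims(
    np.stack([np.arange(num_cells) for _ in range(num_envs)]), axis=2
)  # (num_envs, num_cells, 1)
indexed = np.concatenate([cell_index, action], axis=2)
assert indexed.shape == (num_envs, num_cells, action_plane + 1)
# step() keeps only the rows whose first mask entry is set, then converts
# them to JArray(JArray(JArray(JInt))) for the Java client.
```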