File size: 634 Bytes
9839b09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dataclasses import dataclass, field
from typing import List

import numpy as np
import torch


@dataclass
class Trajectory:
    """Append-only per-step record of a rollout.

    ``add`` appends one entry to each of ``obs``, ``act``, ``rew`` and ``v``
    together; ``len()`` reports the number of recorded observations.

    ``Trajectory()`` creates an empty record whose lists are fresh per
    instance; the fields may also be supplied as keyword arguments.

    NOTE(review): the dataclass-generated ``__eq__`` compares the lists of
    arrays element-wise, which numpy rejects ("truth value of an array is
    ambiguous") for distinct multi-element arrays — avoid ``==`` on instances.
    """

    # Per-step observations, stored as passed to ``add`` (not copied).
    obs: List[np.ndarray] = field(default_factory=list)
    # Per-step actions, stored as passed to ``add`` (not copied).
    act: List[np.ndarray] = field(default_factory=list)
    # Per-step scalar rewards.
    rew: List[float] = field(default_factory=list)
    # Per-step scalar ``v`` values (presumably value estimates — confirm with caller).
    v: List[float] = field(default_factory=list)
    # Set by callers; ``add`` never changes it. Presumably marks that the
    # episode ended by termination (as opposed to truncation) — confirm.
    terminated: bool = False

    def add(self, obs: np.ndarray, act: np.ndarray, rew: float, v: float) -> None:
        """Append one step's observation, action, reward and ``v`` value."""
        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.v.append(v)

    def __len__(self) -> int:
        """Return the number of recorded steps (the length of ``obs``)."""
        return len(self.obs)