# Source: PyCIL/utils/rl_utils/rl_utils.py
# Author: HungNP
# Commit: cb80c28 ("New single commit message")
from tqdm import tqdm
import numpy as np
import torch
import collections
import random
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Each stored item is a 5-tuple ``(state, action, reward, next_state,
    done)``. Storage is a ``collections.deque`` created with ``maxlen``, so
    once the buffer is full every new transition evicts the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the newest end of the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them by field.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``.
            ``states``, ``actions`` and ``next_states`` are NumPy arrays;
            ``rewards`` and ``dones`` are left as plain tuples.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        states = np.array(states)
        actions = np.array(actions)
        next_states = np.array(next_states)
        return states, actions, rewards, next_states, dones

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)