# Spaces: Runtime error (page status banner captured with the source; not part of the program)
import collections
import random

import numpy as np
import torch
from tqdm import tqdm
class ReplayBuffer:
    """Fixed-capacity store of (state, action, reward, next_state, done) tuples.

    Transitions are kept in a ``collections.deque`` created with
    ``maxlen=capacity``, so once the buffer is full each new transition
    pushes out the oldest one.
    """

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done) -> None:
        """Append one transition to the buffer as a 5-tuple."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> tuple:
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(state, action, reward, next_state, done)`` where ``state``,
            ``action`` and ``next_state`` are ``np.ndarray`` batches and
            ``reward`` and ``done`` are plain tuples, each of length
            ``batch_size`` and aligned by position.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                is larger than the number of stored transitions.
        """
        transitions = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), np.array(action), reward, np.array(next_state), done

    def size(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)