Source code for ns_gym.benchmark_algorithms.DDPG

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from ns_gym import base

import math


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, action_dim)
        self.activation = nn.ReLU()  # Hidden layers activation

    def forward(self, state):
        """
        Forward pass through the actor network.

        Args:
            state: Tensor of shape (batch_size, state_dim).

        Returns:
            action: Tensor of shape (batch_size, action_dim).
        """

        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, dtype=torch.float32) 

        x = self.activation(self.fc1(state))
        x = self.activation(self.fc2(x))
        action = self.fc_out(x)
        return action
    
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=256):
        super(Critic, self).__init__()
        # First layer processes both state and action
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.activation = nn.ReLU()  # Non-linear activation

    def forward(self, state, action):
        """
        Forward pass through the Critic network.

        Args:
            state: Tensor of shape (batch_size, state_dim).
            action: Tensor of shape (batch_size, action_dim).

        Returns:
            Q-value: Tensor of shape (batch_size, 1).
        """
        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, dtype=torch.float32)
        if not isinstance(action, torch.Tensor):
            action = torch.tensor(action, dtype=torch.float32)
        
        # Concatenate state and action
        x = torch.cat([state, action], dim=1)

        # Forward pass through layers
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)  # No activation in the output layer
        return x
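
# A minimal shape-check sketch (hypothetical helper, not called anywhere in this
# module): it assumes an 8-dimensional state and 2-dimensional action space,
# matching the DDPG defaults below, and passes a random batch through both
# networks to illustrate the expected tensor shapes.
def _actor_critic_shape_check(batch_size=4, state_dim=8, action_dim=2):
    actor = Actor(state_dim, action_dim)
    critic = Critic(state_dim, action_dim)
    states = torch.randn(batch_size, state_dim)
    actions = actor(states)             # shape: (batch_size, action_dim)
    q_values = critic(states, actions)  # shape: (batch_size, 1)
    return actions.shape, q_values.shape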

class ReplayBuffer(Dataset):
    def __init__(self, capacity):
        """
        A replay buffer that is compatible with PyTorch DataLoader.

        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
        """
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # Tracks the position to overwrite old data

    def add(self, state, action, reward, next_state, done):
        """
        Add a new transition to the buffer.

        Args:
            state: The current state (torch tensor or numpy array).
            action: The action taken.
            reward: The reward received.
            next_state: The next state (torch tensor or numpy array).
            done: Boolean indicating if the episode ended.
        """
        # Convert all inputs to tensors
        transition = (
            torch.tensor(state, dtype=torch.float32),
            torch.tensor(action, dtype=torch.float32),
            torch.tensor(reward, dtype=torch.float32),
            torch.tensor(next_state, dtype=torch.float32),
            torch.tensor(done, dtype=torch.float32),
        )

        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition

        self.position = (self.position + 1) % self.capacity

    def __len__(self):
        """
        Return the current size of the buffer.
        """
        return len(self.buffer)

    def __getitem__(self, idx):
        """
        Retrieve a single transition from the buffer.

        Args:
            idx (int): Index of the transition.

        Returns:
            tuple: A single transition (state, action, reward, next_state, done).
        """
        return self.buffer[idx]
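
# A brief usage sketch (hypothetical helper, not used elsewhere in this module):
# fill the buffer with a few random transitions and draw one mini-batch through
# a DataLoader, which is the same sampling pattern DDPG.train uses below.
def _replay_buffer_example(num_transitions=16, batch_size=4, state_dim=8, action_dim=2):
    buffer = ReplayBuffer(capacity=100)
    for _ in range(num_transitions):
        state = torch.randn(state_dim).tolist()
        action = torch.randn(action_dim).tolist()
        next_state = torch.randn(state_dim).tolist()
        buffer.add(state, action, reward=0.0, next_state=next_state, done=False)
    loader = DataLoader(buffer, batch_size=batch_size, shuffle=True)
    states, actions, rewards, next_states, dones = next(iter(loader))
    return states.shape  # (batch_size, state_dim)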


class DDPG(base.Agent):
    """Deep Deterministic Policy Gradient (DDPG) algorithm.

    Args:
        state_dim (int): Dimension of the state space.
        action_dim (int): Dimension of the action space.
        hidden_size (int): Number of hidden units in each layer of the networks.
        lr_actor (float): Learning rate for the actor network.
        lr_critic (float): Learning rate for the critic network.

    Warning:
        This implementation works, though the StableBaselines3 implementation is likely better optimized.
    """

    def __init__(self, state_dim=8, action_dim=2, hidden_size=256, lr_actor=0.001, lr_critic=0.001):
        ####### Initialize the Actor and Critic Networks ########
        self.actor = Actor(state_dim, action_dim, hidden_size)
        actor_weights = self.actor.state_dict()

        # Make sure the actor target has the same initial weights
        self.actor_target = Actor(state_dim, action_dim, hidden_size)
        self.actor_target.load_state_dict(actor_weights)

        self.best_score = -math.inf

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor)

        self.critic = Critic(state_dim, action_dim, hidden_size)

        # Make sure the critic target has the same initial weights
        self.critic_target = Critic(state_dim, action_dim, hidden_size)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr_critic)

        ######### Initialize the Replay Buffer #########
        self.replay_buffer = ReplayBuffer(capacity=1000000)

    def train(self, env, num_episodes=10000, batch_size=64, gamma=0.99, tau=0.005, warmup_episodes=300, save_path="models/"):
        self.warmup(env, warmup_episodes)  # Warm up the replay buffer

        ep_reward_list = []
        self.actor.train()
        self.critic.train()

        initial_noise_scale = 0.1
        decay_rate = 0.99
        min_noise_scale = 0.01

        ep = 0
        self.best_score = -math.inf

        while ep < num_episodes:
            state, info = env.reset()
            t = 0
            done = False
            truncated = False
            reward_list = []
            critic_loss_list = []
            actor_loss_list = []

            while not done and not truncated:
                # Take an action in the environment with decaying exploration noise
                noise_scale = max(initial_noise_scale * (decay_rate ** ep), min_noise_scale)
                noise = torch.randn(2) * noise_scale

                action = self.actor(torch.tensor(state)) + noise
                action = torch.clamp(action, -1, 1)  # Clip action to [-1, 1]

                with torch.no_grad():
                    action = action.squeeze(0).cpu().numpy()  # Remove batch dimension and convert to numpy

                next_state, reward, done, truncated, info = env.step(action)

                # Store the reward
                reward_list.append(reward)

                # Add the transition to the replay buffer
                self.replay_buffer.add(state, action, reward, next_state, done)
                state = next_state

                # Sample a batch from the replay buffer
                dl = DataLoader(self.replay_buffer, batch_size=batch_size, shuffle=True)
                sample = next(iter(dl))
                states, actions, rewards, next_states, dones = sample

                # Update the actor and critic networks
                critic_loss, actor_loss = self.update(states, actions, rewards, next_states, dones, gamma, tau)
                critic_loss_list.append(critic_loss)
                actor_loss_list.append(actor_loss)

                t += 1

            ep += 1
            ep_reward_list.append(sum(reward_list))

            if ep % 100 == 0:
                print(f"Episode: {ep}, Reward: {sum(reward_list)}, Critic Loss: {sum(critic_loss_list)/len(critic_loss_list)}, Actor Loss: {sum(actor_loss_list)/len(actor_loss_list)}")
                print("ave reward: ", sum(ep_reward_list[-100:])/100)

            if len(ep_reward_list) > 100:
                avg_reward = sum(ep_reward_list[-100:])/100
                if avg_reward > self.best_score:
                    self.best_score = avg_reward
                    torch.save(self.actor.state_dict(), save_path + "actor.pth")
                    torch.save(self.critic.state_dict(), save_path + "critic.pth")

        return self.best_score

    def update(self, states, actions, rewards, next_states, dones, gamma=0.99, tau=0.001):
        """
        Update the actor and critic networks for one training step in DDPG.

        Args:
            states: Batch of current states.
            actions: Batch of actions taken.
            rewards: Batch of rewards received.
            next_states: Batch of next states.
            dones: Batch of done flags (indicating episode termination).
            gamma: Discount factor.
            tau: Target network soft update parameter.
        """
        # Convert dones to a mask for termination
        dones = dones.float()

        # 1. Update Critic Network
        # Compute target Q-value
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            target_q_values = self.critic_target(next_states, next_actions)
            target_q_values = target_q_values.squeeze(1)
            y = rewards + gamma * target_q_values * (1 - dones)  # Mask out terminal states

        # Compute critic loss
        current_q_values = self.critic(states, actions)
        current_q_values = current_q_values.squeeze(1)
        critic_loss = nn.MSELoss()(current_q_values, y)

        # Optimize critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 2. Update Actor Network
        actor_loss = -self.critic(states, self.actor(states)).mean()

        # Optimize actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 3. Soft Update Target Networks
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        return critic_loss.item(), actor_loss.item()
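
    # For reference, the update above implements the standard DDPG rules
    # (written out here as comments, matching the code exactly):
    #   critic target:  y = r + gamma * (1 - done) * Q_target(s', mu_target(s'))
    #   actor loss:     L_actor = -mean( Q(s, mu(s)) )
    #   soft update:    theta_target <- tau * theta + (1 - tau) * theta_target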

    def warmup(self, env, warmup_episodes=300):
        for i in range(warmup_episodes):
            state, info = env.reset()
            done = False
            truncated = False
            t = 0
            print(f"Warmup Episode: {i}")
            while not (done or truncated):
                action = self.actor(torch.tensor(state)) + torch.randn(2) * 0.15  # Add noise to the action
                with torch.no_grad():
                    action = action.squeeze(0).cpu().numpy()  # Remove batch dimension and convert to numpy
                next_state, reward, done, truncated, info = env.step(action)
                self.replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                t += 1
        print("Warmup Done")

    def run_eval_episode(self, env, T=100, visualize=False):
        self.actor.eval()
        reward_list = []
        state, info = env.reset()
        done = False
        truncated = False
        t = 0
        while not done and not truncated:
            action = self.actor(state)
            with torch.no_grad():
                action = action.squeeze(0).cpu().numpy()  # Remove batch dimension and convert to numpy
            next_state, reward, done, truncated, info = env.step(action)
            if visualize:
                env.render()
            if truncated:
                print("Truncated")
            reward_list.append(reward)
            state = next_state
        return sum(reward_list), reward_list

    def act(self, observation, clip=False):
        if clip:
            return torch.clamp(self.actor(observation), -1, 1)
        return self.actor(observation)

if __name__ == '__main__':
    import gymnasium as gym

    env = gym.make("LunarLanderContinuous-v3")
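
    # A minimal usage sketch (illustrative only; the hyperparameters below are
    # assumptions, not values taken from the original script): train a DDPG agent
    # on the environment created above, then run a single evaluation episode.
    # LunarLanderContinuous has an 8-dimensional observation and a 2-dimensional action.
    import os
    os.makedirs("models", exist_ok=True)  # train() saves checkpoints under this path

    agent = DDPG(state_dim=8, action_dim=2)
    agent.train(env, num_episodes=1000, warmup_episodes=10, save_path="models/")
    eval_return, _ = agent.run_eval_episode(env)
    print(f"Evaluation return: {eval_return}")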