Source code for ns_gym.benchmark_algorithms.MCTS

import gymnasium as gym
import numpy as np
from copy import deepcopy
import ns_gym as nsg

import ns_gym.base as base
import random


"""
MCTS with Chance Nodes to handle stochastic environments. This implementation uses global tables to store the Q values and visit counts for state-action pairs and states. Compatible with OpenAI Gym environments.
"""

class DecisionNode:
    """
    Decision node class, labelled by a state.
    """
    def __init__(self, parent, state, weight, is_terminal, reward):
        """
        Args:
            parent (ChanceNode): The parent node of the decision node.
            state (Union[int,np.ndarray]): Environment state.
            weight (float): Probability to occur given the parent (state-action pair)
            is_terminal (bool): Is the state terminal.
            reward (float): immediate reward for reaching this state.

        Attributes:
            children (list): List of child nodes.
            value (float): Value of the state.
            weighted_value (float): Weighted value of the state.

        """
        self.parent = parent
        state, _ = nsg.utils.type_mismatch_checker(observation=state,reward=None)

        assert not isinstance(state, dict), "State is still a dict after type checking."

        self.state = state
        self.weight = weight  # Probability to occur
        self.is_terminal = is_terminal
        if self.parent is None:  # Root node
            self.depth = 0
        else:  # Non root node
            self.depth = parent.depth + 1
        self.children = []
        self.value = 0  # value of state
        self.reward = reward  # immediate reward
        self.weighted_value = self.weight * self.value

class ChanceNode:
    """
    Chance node class, labelled by a state-action pair.
    The state is accessed via the parent attribute.
    """
    def __init__(self, parent, action):
        """
        Args:
            parent (DecisionNode): Parent node of the chance node, a decision node.
            action (int): Action taken from the parent node, e.g. decision node state_1 has the child chance node (state_1, action_1).
        
        Attributes:
            children (list): List of child nodes (DecisionNode)
            value (float): Value of the state-action pair.
            depth (int): Depth of the node in the tree.
        """
        self.parent = parent
        self.action = action
        self.depth = parent.depth
        self.children = []
        self.value = 0

class MCTS(base.Agent):
    """Vanilla MCTS with Chance Nodes. Compatible with OpenAI Gym environments.

    Selection and expansion are combined into the "tree policy" method.
    The rollout/simulation is the "default" policy.

    Args:
        env (gym.Env): The environment to run the MCTS on.
        state (Union[int, np.ndarray]): The state to start the MCTS from.
        d (int): The depth of the MCTS.
        m (int): The number of simulations to run.
        c (float): The exploration constant.
        gamma (float): The discount factor.

    Attributes:
        v0 (DecisionNode): The root node of the tree.
        possible_actions (list): List of possible actions in the environment.
        Qsa (dict): Dictionary to store Q values for state-action pairs.
        Nsa (dict): Dictionary to store visit counts for state-action pairs.
        Ns (dict): Dictionary to store visit counts for states.
    """
    def __init__(self, env: gym.Env, state, d, m, c, gamma) -> None:
        self.env = env  # This is the current state of the mdp
        self.d = d  # depth
        self.m = m  # number of simulations
        self.c = c  # exploration constant

        state, _ = nsg.utils.type_mismatch_checker(observation=state, reward=None)
        self.v0 = DecisionNode(parent=None, state=state, weight=1, is_terminal=False, reward=0)

        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("Only discrete action spaces are supported")

        self.possible_actions = [x for x in range(env.action_space.n)]
        self.gamma = gamma

        self.Qsa = {}  # stores Q values for s,a pairs, defaults to Qsa of 0
        self.Nsa = {}  # stores visit counts for s,a pairs, default to Nsa of 0
        self.Ns = {}   # stores visit counts for states, default to Ns of 0

    def search(self):
        """Do the MCTS by doing m simulations from the current state s.

        After doing m simulations we simply choose the action that maximizes the estimate of Q(s,a).

        Returns:
            best_action (int): best action to take.
            action_values (list): list of Q values for each action.
        """
        for k in range(self.m):
            self.sim_env = deepcopy(self.env)  # make a deep copy of the original env at the root node
            vl = self._tree_policy(self.v0)  # vl is the last node visited by the tree search (a chance node)
            expanded_node = self._expand(vl)
            if type(expanded_node) == ChanceNode:
                expanded_node = self._expand(expanded_node)  # DecisionNode
            R = self._default_policy(expanded_node)  # R is the reward from the simulation (default policy)
            self._backpropagation(R, expanded_node)

        # Q values for the (root state, action) pairs; actions never visited default to 0
        action_values = [self.Qsa.get((self.v0.state, a), 0) for a in self.possible_actions]
        ba = np.argmax(action_values)  # best action
        return ba, action_values

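    # Note on search(): each iteration of the loop performs the classic MCTS
    # phases on a fresh copy of the environment:
    #   1. selection + expansion -> _tree_policy / _expand
    #   2. simulation (rollout)  -> _default_policy
    #   3. backpropagation       -> _backpropagation
    # The chance-node layer is what lets the tree represent stochastic
    # transitions: each (state, action) pair can fan out to several next states.
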
    def _tree_policy(self, node) -> ChanceNode:
        """Tree policy for MCTS. Traverse the tree from the root node to a leaf node.

        Args:
            node (DecisionNode): The root node of the tree.

        Returns:
            ChanceNode: The leaf node reached by the tree policy.
        """
        while node.children:
            if type(node) == DecisionNode:
                node = self._selection(node)
                assert type(node) == ChanceNode
            else:  # chance node
                assert type(node) == ChanceNode
                node = self._expand(node)
                assert type(node) == DecisionNode, f"got {type(node)} instead of DecisionNode"
        return node

    def _default_policy(self, v: DecisionNode):
        """Simulate/playout step.

        While the state is non-terminal, choose an action uniformly at random and transition
        to a new state. Return the discounted sum of rewards collected along the rollout.

        Args:
            v (DecisionNode): The node to start the simulation from.
        """
        if v.is_terminal:
            return v.reward

        tot_reward = 0
        terminated = False
        truncated = False
        depth = 0
        while not terminated and depth < self.d and not truncated:
            action = np.random.choice(self.possible_actions)
            observation, reward, terminated, truncated, info = self.sim_env.step(action)
            observation, reward = self.type_checker(observation, reward)
            tot_reward += reward * self.gamma**depth
            depth += 1
        return tot_reward

    def _selection(self, v: DecisionNode):
        """Pick the next node to go down in the search tree based on the UCT formula."""
        best_child = self.best_child(v)
        return best_child

    def _expand(self, node):
        """Expand the tree by adding a new node to the tree. Handles both decision and chance nodes."""
        if type(node) == DecisionNode:
            if node.is_terminal:
                return node
            for a in range(self.sim_env.action_space.n):
                new_node = ChanceNode(parent=node, action=a)
                node.children.append(new_node)
            return np.random.choice(node.children)
        else:  # chance node
            action = node.action
            assert type(node) == ChanceNode
            obs, reward, term, _, info = self.sim_env.step(action)
            obs, reward = self.type_checker(obs, reward)

            existing_child = [child for child in node.children if child.state == obs]
            if existing_child:
                return existing_child[0]
            else:
                if "prob" in info:
                    w = info["prob"]
                else:
                    w = 1
                new_node = DecisionNode(parent=node, state=obs, weight=w, is_terminal=term, reward=reward)
                node.children.append(new_node)
                return new_node

    def _backpropagation(self, R, v, depth=0):
        """Backtrack to update the number of times a node has been visited and the value
        of a node until we reach the root node.
        """
        depth = 0
        while v:
            v.value += R
            if type(v) == ChanceNode:
                assert not isinstance(v.parent.state, dict), "Parent state is still a dict after type checking."
                self.update_metrics_chance_node(v.parent.state, v.action, R)
            else:
                self.update_metrics_decision_node(v.state)
            R = R * (self.gamma**depth)
            depth += 1
            v = v.parent

    def update_metrics_chance_node(self, state, action, reward):
        """Update the Q values and visit counts for state-action pairs and states.

        Args:
            state (Union[int,np.ndarray]): The state.
            action (Union[int,float,np.ndarray]): Action taken at the state.
            reward (float): The reward received after taking the action at the state.
        """
        if isinstance(state, np.ndarray):
            state = tuple(state)
        if isinstance(action, np.ndarray):
            action = tuple(action)

        sa = (state, action)
        if sa in self.Qsa:
            self.Qsa[sa] = (self.Qsa[sa] * self.Nsa[sa] + reward) / (self.Nsa[sa] + 1)
            self.Nsa[sa] += 1
        else:
            self.Qsa[sa] = reward
            self.Nsa[sa] = 1

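    # The update above is the usual incremental mean: if Q_n is the average of
    # the first n returns observed for (s, a) and r is the (n+1)-th return, then
    #
    #     Q_{n+1} = (n * Q_n + r) / (n + 1)
    #
    # so Qsa always holds the running average return and Nsa the sample count.
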
    def update_metrics_decision_node(self, state):
        """Update the visit counts for states."""
        if state in self.Ns:
            self.Ns[state] += 1
        else:
            self.Ns[state] = 1

    def type_checker(self, observation, reward):
        """Converts the observation and reward from dict and base.Reward type to the correct type if they are not already.

        Args:
            observation (Union[dict, np.ndarray]): Observation to convert.
            reward (Union[float, base.Reward]): Reward to convert.

        Returns:
            (int,np.ndarray): Converted observation.
            (float): Converted reward.
        """
        if isinstance(observation, dict) and 'state' in observation:
            observation = observation['state']
        if isinstance(observation, np.ndarray):
            observation = tuple(observation)
        if isinstance(reward, base.Reward):
            reward = reward.reward

        # DEBUGGING ASSERTION
        assert not isinstance(observation, dict), "Observation is still a dict after type checking."

        return observation, reward

    def best_child(self, v):
        """Find the best child node based on the UCT value. This method is only called for decision nodes.

        Args:
            v (DecisionNode): The decision node whose children (chance nodes) are scored.

        Returns:
            ChanceNode: The best child node based on the UCT value, or None if the node has no children.
        """
        best_value = -np.inf
        best_nodes = []
        children = v.children

        for child in children:
            sa = (child.parent.state, child.action)
            if sa in self.Qsa:
                ucb_value = self.Qsa[sa] + self.c * np.sqrt(
                    np.log(self.Ns.get(sa[0], 1)) / self.Nsa[sa])
            else:
                # Unvisited state-action pairs get an infinite score so they are explored first
                ucb_value = np.inf

            if ucb_value > best_value:
                best_value = ucb_value
                best_nodes = [child]
            elif ucb_value == best_value:
                best_nodes.append(child)

        return random.choice(best_nodes) if best_nodes else None

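    # Note on the scoring rule in best_child(): it is the standard UCB1/UCT formula,
    #
    #     UCT(s, a) = Q(s, a) + c * sqrt( ln N(s) / N(s, a) )
    #
    # where Q(s, a) and N(s, a) are read from the global Qsa/Nsa tables and N(s)
    # from Ns. Ties are broken uniformly at random among the best-scoring children.
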
    def best_action(self, v):
        """Select the best action based on the Q values of the state-action pairs.

        Returns:
            best_action (int): best action to take.
        """
        best_action = None
        best_avg_value = -np.inf
        s = v.state  # root is Type[Node]

        # Iterate through all possible actions from this state
        for a in range(self.env.action_space.n):
            sa = (s, a)  # Create a state-action pair

            # Check if this state-action pair has been explored
            if sa in self.Qsa and sa in self.Nsa and self.Nsa[sa] > 0:
                # Qsa already stores a running average return, so compare it directly
                avg_value = self.Qsa[sa]
                if avg_value > best_avg_value:
                    best_avg_value = avg_value
                    best_action = a

        # Ensure a valid action is selected, even if no action has been explored
        if best_action is None and self.possible_actions:
            best_action = np.random.choice(self.possible_actions)

        return best_action

    def act(self, observation, env):
        """Decide on an action using the MCTS search, reinitializing the tree structure.

        Args:
            observation (Union[int, np.ndarray]): The current state or observation of the environment.
            env (gym.Env): The environment to plan in.

        Returns:
            int: The selected action.
        """
        observation, _ = nsg.utils.type_mismatch_checker(observation=observation, reward=None)
        if isinstance(observation, np.ndarray):
            observation = tuple(observation)

        # Reinitialize the instance by calling __init__
        self.__init__(env, observation, self.d, self.m, self.c, self.gamma)

        # Perform MCTS search to determine the best action
        best_action, _ = self.search()
        return best_action

if __name__ == "__main__":
    env = gym.make("CartPole-v1", max_episode_steps=500)
    scheduler = nsg.schedulers.ContinuousScheduler()
    update_fn = nsg.update_functions.NoUpdate(scheduler=scheduler)
    env = nsg.wrappers.NSClassicControlWrapper(env, tunable_params={"masspole": update_fn})

    ########### EXAMPLE USAGE ################
    # env = gym.make("FrozenLake-v1",max_episode_steps=100)
    # scheduler = nsg.schedulers.ContinuousScheduler()
    # update_fn= nsg.update_functions.DistributionNoUpdate(scheduler=scheduler)
    # env = nsg.wrappers.NSFrozenLakeWrapper(env,tunable_params={"P":update_fn})

    decision_times = []
    for i in range(1):
        obs, info = env.reset()
        done = False
        truncated = False
        step = 0
        max_steps = 500
        mcts_agent = MCTS(env, obs, 15, 50, 1.44, 0.999)
        reward_list = []
        while not done and not truncated and step < max_steps:
            a = mcts_agent.act(obs, env)
            obs, reward, done, truncated, info = env.step(a)
            reward_list.append(reward.reward)
            # decision_times.append(time.time()-start)
            step += 1
            if step % 10 == 0:
                print("Step ", step)
        print("Reward ", np.sum(reward_list))
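
    # --- Illustrative sketch (not in the original example): a single planning
    # step via search() instead of act(). It mirrors what act() does internally:
    # convert the observation to a hashable key, rebuild the tree on the current
    # state, run m simulations, then read off the root Q-value estimates. The
    # hyperparameters reuse the values from the loop above.
    obs, info = env.reset()
    obs_key, _ = nsg.utils.type_mismatch_checker(observation=obs, reward=None)
    if isinstance(obs_key, np.ndarray):
        obs_key = tuple(obs_key)  # states are used as dictionary keys, so they must be hashable
    planner = MCTS(env, obs_key, 15, 50, 1.44, 0.999)
    best_a, q_values = planner.search()
    print("One-shot search -> action:", best_a, "root Q estimates:", q_values)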