Source code for ns_gym.benchmark_algorithms.PAMCTS

"""Policy-Augmented MCTS (PA-MCTS).

Combines a learned action-value head (DDQN, SB3 DQN, or any callable
that returns per-action Q-values for a given (state, env)) with online
MCTS via a convex combination after both heads are normalized to
``[0, 1]``.

Two ways to wire up the Q-head:

1. **Legacy** -- pass ``DDQN_model`` (a torch ``nn.Module`` instance)
   and ``DDQN_model_path`` (a state-dict ``.pth``). PAMCTS builds an
   internal :class:`~ns_gym.benchmark_algorithms.DDQN.DDQN.DQNAgent`
   and feeds it the raw state.

2. **Recommended** -- pass ``q_value_fn``, a callable
   ``q_value_fn(state, env) -> np.ndarray`` that returns the per-action
   Q-values. This decouples PAMCTS from the ns_gym DQN format and lets
   you plug in any value head: a Stable-Baselines3 DQN via
   :class:`~ns_gym.base.StableBaselineWrapper`, a hand-trained PyTorch
   net, even a precomputed table.

Exactly one of those two paths must be provided.
"""

import numpy as np

import ns_gym as nsg
from ns_gym import base


[docs] class PAMCTS(base.Agent): """Policy-Augmented MCTS. Args: alpha (float): Convex combination weight on the learned policy's Q-values. ``alpha = 0`` -> pure MCTS, ``alpha = 1`` -> pure learned head. mcts_iter (int): Total MCTS rollouts (m). mcts_search_depth (int): MCTS rollout depth (d). mcts_discount_factor (float): MCTS discount factor (gamma). mcts_exploration_constant (float): UCB1 exploration constant (c). state_space_size (int): Discrete state-space size, only used by the legacy DDQN wiring path. action_space_size (int): Number of discrete actions, only used by the legacy DDQN wiring path. DDQN_model (torch.nn.Module, optional): Legacy. Architecture instance whose weights will be loaded from ``DDQN_model_path``. DDQN_model_path (str, optional): Legacy. Path to a state-dict ``.pth`` file matching ``DDQN_model``. q_value_fn (Callable[[state, env], np.ndarray], optional): Recommended. A function that returns per-action Q-values for a given ``(state, env)`` pair. Bypasses the legacy DDQN wiring entirely. seed (int): Random seed forwarded to the legacy :class:`DQNAgent`. Unused when ``q_value_fn`` is supplied. Examples: >>> # Legacy: ns_gym DQN architecture + .pth state dict >>> from ns_gym.benchmark_algorithms.DDQN.DDQN import DQN as DDQNNet >>> arch = DDQNNet(state_size=16, action_size=4, ... num_layers=3, num_hidden_units=64, seed=0) >>> agent = PAMCTS(alpha=0.75, mcts_iter=30, mcts_search_depth=20, ... mcts_discount_factor=0.95, ... mcts_exploration_constant=1.4, ... state_space_size=16, action_space_size=4, ... DDQN_model=arch, ... DDQN_model_path="weights.pth") >>> # Recommended: Stable-Baselines3 DQN via StableBaselineWrapper >>> from stable_baselines3 import DQN >>> from ns_gym.base import StableBaselineWrapper >>> sb3 = DQN.load("contextual_ddqn.zip") >>> def obs_fn(state, env): ... return np.concatenate([ ... np.eye(16, dtype=np.float32)[int(state)], ... np.asarray(env.transition_prob, dtype=np.float32), ... ]) >>> wrap = StableBaselineWrapper(sb3, obs_fn=obs_fn) >>> agent = PAMCTS(alpha=0.75, mcts_iter=30, mcts_search_depth=20, ... mcts_discount_factor=0.95, ... mcts_exploration_constant=1.4, ... state_space_size=16, action_space_size=4, ... q_value_fn=wrap.q_values) """ def __init__(self, alpha, mcts_iter, mcts_search_depth, mcts_discount_factor, mcts_exploration_constant, state_space_size, action_space_size, DDQN_model=None, DDQN_model_path=None, q_value_fn=None, seed=0) -> None: self.alpha = alpha self.mcts_iter = mcts_iter self.mcts_search_depth = mcts_search_depth self.mcts_discount_factor = mcts_discount_factor self.mcts_exploration_constant = mcts_exploration_constant self.state_space_size = state_space_size self.action_space_size = action_space_size # Wire the Q-head exactly once. if q_value_fn is not None: if DDQN_model is not None or DDQN_model_path is not None: raise ValueError( "PAMCTS: pass *either* q_value_fn or the legacy " "(DDQN_model, DDQN_model_path) pair -- not both." ) if not callable(q_value_fn): raise TypeError( "PAMCTS: q_value_fn must be a callable " "f(state, env) -> np.ndarray of per-action Q-values." ) self._q_value_fn = q_value_fn self.DDQN_agent = None else: if DDQN_model is None or DDQN_model_path is None: raise ValueError( "PAMCTS: provide either q_value_fn (recommended) or " "both DDQN_model and DDQN_model_path (legacy)." ) self.DDQN_model = DDQN_model self.DDQN_model_path = DDQN_model_path self.DDQN_agent = nsg.benchmark_algorithms.DDQN.DQNAgent( self.state_space_size, self.action_space_size, seed=seed, model=self.DDQN_model, model_path=self.DDQN_model_path, ) self._q_value_fn = None def _learned_q(self, state, env): """Per-action Q-values from whichever wiring was selected.""" if self._q_value_fn is not None: return np.asarray(self._q_value_fn(state, env), dtype=np.float32) # Legacy DDQN_agent.search returns (action, all_q_values) _, q_values = self.DDQN_agent.search(state) return np.asarray(q_values, dtype=np.float32).ravel()
[docs] def search(self, state, env, normalize=True): self.mcts_agent = nsg.benchmark_algorithms.MCTS( env, state, d=self.mcts_search_depth, m=self.mcts_iter, c=self.mcts_exploration_constant, gamma=self.mcts_discount_factor, ) _mcts_action, mcts_action_values = self.mcts_agent.search() ddqn_action_values = self._learned_q(state, env) mcts_action_values = np.asarray(mcts_action_values, dtype=np.float32) if normalize: eps = 1e-8 ddqn_action_values = ( (ddqn_action_values - ddqn_action_values.min()) / (ddqn_action_values.max() - ddqn_action_values.min() + eps) ) mcts_action_values = ( (mcts_action_values - mcts_action_values.min()) / (mcts_action_values.max() - mcts_action_values.min() + eps) ) hybrid = self._get_pa_uct_score( self.alpha, ddqn_action_values, mcts_action_values, ) return int(np.argmax(hybrid)), hybrid
[docs] def act(self, state, env, normalize=True): action, _ = self.search(state, env, normalize) return action
def _get_pa_uct_score(self, alpha, policy_value, mcts_return): return (alpha * policy_value) + ((1.0 - alpha) * mcts_return)
if __name__ == "__main__": pass