Source code for ns_gym.update_functions.distribution

import ns_gym.base as base
import ns_gym.utils as utils
from typing import Any, Type, Union, Optional
import numpy as np

"""
These classes update probability distributions represented as lists.
"""


class RandomCategorical(base.UpdateDistributionFn):
    """Replace the distribution with a freshly sampled random categorical distribution.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        seed (Optional[int], optional): Seed for the random number generator. Defaults to None.

    Note:
        Each update draws a new categorical distribution from a Dirichlet
        distribution whose concentration parameters are all equal to 1. The
        sample has as many outcomes as the current distribution. The current
        probability values themselves are ignored.
    """

    def __init__(self, scheduler: base.Scheduler, seed: Optional[int] = None) -> None:
        super().__init__(scheduler)
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param, t: int) -> Any:
        """Sample a replacement distribution with the same length as ``param``.

        Args:
            param (list): current distribution; only its length is used.
            t (int): current time step (unused).

        Returns:
            list: the newly sampled probabilities.
        """
        num_outcomes = len(param)
        concentration = np.ones(num_outcomes)
        sample = self.rng.dirichlet(concentration)
        return list(sample)
class DistributionIncrementUpdate(base.UpdateDistributionFn):
    """Increment the probability of the intended direction by ``k``.

    The intended direction is the first element of the distribution. It is
    increased by ``k`` and clamped to ``[0, 1]``. The remaining probability
    mass ``1 - p[0]`` is then split evenly across the other outcomes.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        k (float): The amount added to the probability of the intended direction.

    Note:
        This update function is useful for testing the robustness of the agent
        to changes in the environment. The input list is not modified; a new
        list is returned.
    """

    def __init__(self, scheduler: base.Scheduler, k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def __call__(self, param: list[float], t: int) -> Any:
        return super().__call__(param, t)

    def _update(self, param: list[float], t: int, **kwargs) -> Any:
        """Increment the intended direction by ``k`` and rebalance the rest.

        Args:
            param (list[float]): current probability distribution.
            t (int): current time step (unused).

        Returns:
            list[float]: a new distribution. An empty input yields an empty list.
        """
        # Work on a copy: callers (e.g. LC/budget-bounded wrappers) keep a
        # reference to the original to measure how far the update moved it.
        updated = list(param)
        if not updated:
            return updated
        # Clamp on both sides so a negative ``k`` cannot yield p < 0.
        updated[0] = min(1, max(0, updated[0] + self.k))
        n_rest = len(updated) - 1
        if n_rest:
            share = (1 - updated[0]) / n_rest
            updated[1:] = [share] * n_rest
        return updated
class DistributionDecrementUpdate(base.UpdateDistributionFn):
    """Decrement the probability of going in the intended direction by ``k``.

    Overview:
        The probability distribution is represented as a list of probabilities,
        and the intended direction is its first element. That element is
        decreased by ``k`` and clamped to ``[0, 1]``. The remaining probability
        mass ``1 - p[0]`` is then split evenly across the other outcomes.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        k (float): The amount subtracted from the probability of the intended direction.

    Note:
        The input list is not modified; a new list is returned.
    """

    def __init__(self, scheduler: base.Scheduler, k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def __call__(self, param: list[float], t: int) -> Any:
        return super().__call__(param, t)

    def _update(self, param: list[float], t: int, **kwargs) -> Any:
        """Decrement the intended direction by ``k`` and rebalance the rest.

        Args:
            param (list[float]): current probability distribution.
            t (int): current time step (unused).

        Returns:
            list[float]: Updated probability distribution (a new list). An
            empty input yields an empty list.
        """
        # Copy so the caller's list (possibly held as a "before" snapshot by a
        # wrapper update function) is left untouched.
        updated = list(param)
        if not updated:
            return updated
        # Clamp on both sides so a negative ``k`` cannot yield p > 1.
        updated[0] = max(0, min(1, updated[0] - self.k))
        n_rest = len(updated) - 1
        if n_rest:
            share = (1 - updated[0]) / n_rest
            updated[1:] = [share] * n_rest
        return updated
class DistributionStepWiseUpdate(base.UpdateDistributionFn):
    """Replace the parameter with the next value from a predefined sequence.

    Each time the update fires, the parameter is set to the next entry of
    ``update_values``. Once the sequence is exhausted, the parameter is
    returned unchanged.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        update_values (list): The values the parameter is updated to, in order.
    """

    def __init__(self, scheduler: base.Scheduler, update_values: list) -> None:
        super().__init__(scheduler)
        # Copy so consuming values does not mutate the caller's list (which
        # could be shared with another environment or reused later).
        self.update_values = list(update_values)

    def __call__(self, param: Any, t: int) -> Any:
        return super().__call__(param, t)

    def _update(self, param, t: int) -> Any:
        """Pop and return the next scheduled value.

        Args:
            param (list): current parameter value
            t (int): current time step

        Returns:
            list: the next scheduled value, or ``param`` unchanged when no
            values remain.
        """
        if not self.update_values:
            # No more parameters to update.
            return param
        return self.update_values.pop(0)
class LCBoundedDistrubutionUpdate(base.UpdateDistributionFn):
    """Update the distribution so that the change is Lipschitz continuous in time.

    Overview:
        This function repeatedly calls an inner update function
        (``RandomCategorical`` by default) until the proposed distribution
        satisfies the Lipschitz constraint. If no acceptable proposal is found
        after a fixed number of retries, a ``ValueError`` is raised.

        The Lipschitz continuity constraint between two probability
        distributions is defined as:

        .. math::
            W_1(p_t(.|s,a),p_{t'}(.|s,a)) <= L * |t - t'|

        where :math:`W_1` is the Wasserstein distance between two probability
        distributions.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        L (float): The Lipschitz constant.
        update_fn (Type[base.UpdateDistributionFn], optional): Class of the inner
            update function. It is instantiated with ``scheduler`` (plus
            ``update_fn_kwargs``). Defaults to ``RandomCategorical``.
        update_fn_kwargs (Optional[dict], optional): Extra keyword arguments for
            the inner update function's constructor (e.g. ``k`` or ``seed``).
            Defaults to None.

    Note:
        This update function is an implementation of the transition function in
        Lecarpentier and Rachelson, 2019.
    """

    def __init__(
        self,
        scheduler,
        L: float,
        update_fn=None,
        update_fn_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__(scheduler)
        self.L = L
        extra_kwargs = update_fn_kwargs or {}
        if update_fn is None:
            self.update_fn = RandomCategorical(scheduler, **extra_kwargs)
        else:
            assert issubclass(update_fn, base.UpdateDistributionFn), (
                "update_fn must be a subclass of base.UpdateDistributionFn"
            )
            self.update_fn = update_fn(scheduler, **extra_kwargs)

    def _update(self, param: Any, t: int) -> Any:
        """Sample inner updates until one satisfies the Lipschitz bound.

        Raises:
            ValueError: if no proposal within the retry budget satisfies the bound.
        """
        max_trys = 1e5
        # Snapshot the current distribution and hand the inner function a fresh
        # copy each time. If the inner update modified its input in place, the
        # distance below would otherwise be trivially zero and retries would
        # compound.
        cur_dist = list(param)
        updated_dist = self.update_fn.update(list(cur_dist), t)
        wass_dist = utils.wasserstein_distance(cur_dist, updated_dist)
        delta_time = abs(t - self.prev_time)
        d = self.L * delta_time
        count = 0
        while wass_dist > d and count < max_trys:
            updated_dist = self.update_fn.update(list(cur_dist), t)
            wass_dist = utils.wasserstein_distance(cur_dist, updated_dist)
            count += 1
        # Decide on the bound itself, so a last-retry success is not rejected.
        if wass_dist > d:
            raise ValueError("Could not find a Lipshitz continuous update")
        return updated_dist
class BudgetBoundedIncrement(DistributionIncrementUpdate):
    """Increment the parameters so that the total amount of change is bounded by a budget.

    Overview:
        This function constrains the total amount of change in the parameter
        by a maximum budget ``B``. Each fired update is an increment of ``k``,
        as in ``DistributionIncrementUpdate``. The change is measured with
        ``utils.wasserstein_distance``. An update that would push the
        cumulative change above ``B`` is rejected: the previous distribution is
        returned with a ``False`` change flag. This formulation is outlined in
        Cheung et al. 2020.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
        B (Union[int,float]): The maximum total amount of change allowed in the parameter.
    """

    def __init__(
        self, scheduler: base.Scheduler, k: float, B: Union[int, float]
    ) -> None:
        # Inheriting from DistributionIncrementUpdate provides both the
        # (scheduler, k) constructor and the ``_update`` increment logic.
        super().__init__(scheduler, k)
        self.B = B
        self.total_change = 0

    def __call__(self, param: Any, t: int) -> Any:
        # Snapshot before updating so the distance (and the rejection fallback)
        # is unaffected by any in-place modification of ``param``.
        curr_dist = list(param)
        updated_param, change = super().__call__(param, t)
        amount_change = utils.wasserstein_distance(curr_dist, updated_param)
        if self.total_change + amount_change > self.B:
            # Budget would be exceeded: reject this update.
            return curr_dist, False
        self.total_change += amount_change
        return updated_param, change
class DistributionNoUpdate(base.UpdateDistributionFn):
    """Identity update: the parameter is returned unchanged.

    Overview:
        Useful as a placeholder that keeps the standard update-function
        interface (the scheduler-driven ``__call__``) without altering the
        distribution.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
    """

    def __init__(self, scheduler: base.Scheduler) -> None:
        super().__init__(scheduler)

    def __call__(self, param: Any, t: int) -> Any:
        outcome = super().__call__(param, t)
        return outcome

    def _update(self, param: Any, t: int) -> Any:
        unchanged = param
        return unchanged
class UniformDrift(base.UpdateDistributionFn):
    r"""Move the distribution toward uniform by a fixed fraction each update.

    Overview:
        .. math::
            p_t = (1 - \alpha) \, p_{t-1} + \alpha \, \mathbf{u}

        where :math:`\mathbf{u}` is the uniform distribution over the same
        outcomes and :math:`\alpha \in [0, 1]` is the drift rate. For such
        :math:`\alpha` the result is a convex combination of two valid
        distributions, hence itself a valid distribution.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        rate (float): Mixing rate toward uniform. 0 = no change, 1 = fully uniform.
    """

    def __init__(self, scheduler: base.Scheduler, rate: float) -> None:
        super().__init__(scheduler)
        self.rate = rate

    def _update(self, param: list[float], t: int) -> Any:
        keep_weight = 1 - self.rate
        uniform_share = self.rate * (1.0 / len(param))
        return [keep_weight * prob + uniform_share for prob in param]
class TargetReversion(base.UpdateDistributionFn):
    r"""Mean-revert toward a target distribution (OU analog for distributions).

    Overview:
        .. math::
            p_t = p_{t-1} + \theta \, (\text{target} - p_{t-1})

        which simplifies to :math:`(1 - \theta) p_{t-1} + \theta \, \text{target}`,
        a convex combination that stays on the probability simplex for
        :math:`\theta \in [0, 1]`.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        target (list[float]): The target distribution to revert toward.
        theta (float): Reversion speed. 0 = no change, 1 = jump to target.

    Raises:
        ValueError: (on update) if the current distribution and ``target`` have
            different lengths.
    """

    def __init__(
        self, scheduler: base.Scheduler, target: list, theta: float
    ) -> None:
        super().__init__(scheduler)
        # Copy so later external edits to the caller's list don't shift the target.
        self.target = list(target)
        self.theta = theta

    def _update(self, param: list[float], t: int) -> Any:
        # ``zip`` would silently truncate on a length mismatch and return a
        # shorter, invalid distribution; fail loudly instead.
        if len(param) != len(self.target):
            raise ValueError(
                f"Distribution has {len(param)} entries but target has "
                f"{len(self.target)}."
            )
        theta = self.theta
        return [p + theta * (tgt - p) for p, tgt in zip(param, self.target)]
class DistributionLinearInterpolation(base.UpdateDistributionFn):
    r"""Linearly interpolate between two distributions over ``T`` steps.

    Overview:
        .. math::
            p_t = \text{start} + (\text{end} - \text{start})
                  \cdot \operatorname{clip}\!\left(\frac{t}{T},\; 0,\; 1\right)

        The output replaces the current distribution entirely. Before ``t = 0``
        the result is clamped at ``start_dist``, and from ``t >= T`` onward at
        ``end_dist``. A non-positive ``T`` means an immediate switch to
        ``end_dist``.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        start_dist (list[float]): Distribution at ``t = 0``.
        end_dist (list[float]): Distribution at ``t = T``.
        T (int): Number of steps over which to interpolate.

    Raises:
        ValueError: if ``start_dist`` and ``end_dist`` have different lengths.
    """

    def __init__(
        self,
        scheduler: base.Scheduler,
        start_dist: list,
        end_dist: list,
        T: int,
    ) -> None:
        super().__init__(scheduler)
        # A length mismatch would otherwise be silently truncated by ``zip``.
        if len(start_dist) != len(end_dist):
            raise ValueError(
                f"start_dist has {len(start_dist)} entries but end_dist has "
                f"{len(end_dist)}."
            )
        self.start_dist = start_dist
        self.end_dist = end_dist
        self.T = T

    def _update(self, param: list[float], t: int) -> Any:
        if self.T <= 0:
            # Avoid division by zero: a zero-length ramp is an instant switch.
            frac = 1.0
        else:
            # Clamp to [0, 1] so negative times don't extrapolate past start.
            frac = min(max(t / self.T, 0.0), 1.0)
        return [s + (e - s) * frac for s, e in zip(self.start_dist, self.end_dist)]
class DistributionCyclicUpdate(base.UpdateDistributionFn):
    r"""Cycle through a list of distributions, wrapping around when exhausted.

    Overview:
        Each time the update fires, the distribution is set to (a copy of) the
        next entry in ``dist_list``. After reaching the end, the index wraps
        back to 0.

    Args:
        scheduler (base.Scheduler): scheduler that determines when the update function fires.
        dist_list (list[list[float]]): Distributions to cycle through.

    Raises:
        ValueError: if ``dist_list`` is empty.
    """

    def __init__(
        self, scheduler: base.Scheduler, dist_list: list
    ) -> None:
        super().__init__(scheduler)
        # Fail at construction rather than with an IndexError on first update.
        if not dist_list:
            raise ValueError("dist_list must contain at least one distribution")
        self.dist_list = dist_list
        self._index = 0

    def _update(self, param: list[float], t: int) -> Any:
        # Return a copy: handing out the stored list would let any downstream
        # in-place modification corrupt this entry for later cycles.
        val = list(self.dist_list[self._index])
        self._index = (self._index + 1) % len(self.dist_list)
        return val
if __name__ == "__main__":
    import inspect

    # Run this file to regenerate the __all__ variable below; copy and paste
    # the printed output over it.
    exported = sorted(
        name
        for name, obj in globals().items()
        if not name.startswith("_")
        and (inspect.isfunction(obj) or inspect.isclass(obj))
        and obj.__module__ == __name__
    )
    generated = ["__all__ = ["]
    generated.extend(f' "{name}",' for name in exported)
    generated.append("]")
    print("\n".join(generated))

__all__ = [
    "BudgetBoundedIncrement",
    "DistributionCyclicUpdate",
    "DistributionDecrementUpdate",
    "DistributionIncrementUpdate",
    "DistributionLinearInterpolation",
    "DistributionNoUpdate",
    "DistributionStepWiseUpdate",
    "LCBoundedDistrubutionUpdate",
    "RandomCategorical",
    "TargetReversion",
    "UniformDrift",
]