Source code for ns_gym.update_functions.single_param

import ns_gym.base as base
from typing import Union, Any, Type
import numpy as np

"""
These are a collection of parameter update functions that can be used in the ns_bench framework.
Each class takes a single scalar parameter and returns a new parameter.
The update functions only "fire" when the scheduler returns True.

To maintain a consistent interface, the update functions can be implemented as a derived class of base.UpdateFn.

Strictly speaking, the only requirement is that the update function's __call__() method takes a parameter
and a time step, and returns the new parameter together with a boolean indicating whether the update function fired.

"""

##### Update Functions for single parameters (scalars) #####


class DeterministicTrend(base.UpdateFn):
    r"""Update the parameter with a deterministic trend.

    Overview:
        .. math::

            Y_t = Y_{t-1} + slope * t

        where :math:`Y_t` is the parameter value at time step :math:`t` and
        slope is the slope of the trend.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        slope (float): The slope of the trend.
    """

    def __init__(self, scheduler: Type[base.Scheduler], slope: float) -> None:
        super().__init__(scheduler)
        self.slope = slope

    def _update(self, param: float, t: float) -> float:
        """Return ``param`` shifted by ``slope * t``.

        Args:
            param (float): Current parameter value.
            t (float): Current time step.

        Returns:
            float: The updated parameter value only; no "fired" flag is
            produced by this method.
        """
        return param + self.slope * t
class RandomWalkWithDriftAndTrend(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise and a deterministic trend.

    Overview:
        .. math::

            Y_t = \alpha + Y_{t-1} + \text{slope} * t + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`,
        :math:`\alpha` is the drift term, :math:`\text{slope}` is the slope of
        the trend, and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        slope (float): The slope of the trend.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        slope: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        self.slope = slope
        # Each instance owns its own NumPy Generator, seeded from ``seed``.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: float, t: float) -> float:
        """Return ``param`` plus drift, one white-noise draw and ``slope * t``.

        Args:
            param (float): Current parameter value.
            t (float): Current time step.

        Returns:
            float: The updated parameter value only; no "fired" flag is
            produced by this method.
        """
        white_noise = self.rng.normal(self.mu, self.sigma)
        return self.alpha + param + white_noise + self.slope * t
class RandomWalk(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise.

    Overview:
        A pure random walk: :math:`Y_t = Y_{t-1} + \epsilon_t` where
        :math:`Y_t` is the parameter value at time step :math:`t` and
        :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        mu (Union[float, int], optional): The mean of the white noise. Defaults to 0.
        sigma (Union[float, int], optional): The standard deviation of the white noise. Defaults to 1.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        mu: Union[float, int] = 0,
        sigma: Union[float, int] = 1,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        # Each instance owns its own NumPy Generator, seeded from ``seed``.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return ``param`` plus one white-noise draw.

        Args:
            param (Any): Current parameter value.
            t (Union[int, float]): Current time step (not used in the computation).

        Returns:
            Any: The updated parameter value.
        """
        white_noise = self.rng.normal(self.mu, self.sigma)
        return param + white_noise
class RandomWalkWithDrift(base.UpdateFn):
    r"""Random walk with a constant drift term.

    Overview:
        .. math::

            Y_t = \alpha + Y_{t-1} + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`,
        :math:`\alpha` is the drift term, and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        # Each instance owns its own NumPy Generator, seeded from ``seed``.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` plus the drift and one white-noise draw.

        ``t`` is accepted for interface compatibility and is not used.
        """
        shock = self.rng.normal(self.mu, self.sigma)
        drifted = self.alpha + param
        return drifted + shock
class IncrementUpdate(base.UpdateFn):
    r"""Increment the parameter by k.

    Overview:
        .. math::

            Y_t = Y_{t-1} + k

        where :math:`Y_t` is the parameter value at time step :math:`t` and
        :math:`k` is the amount to increment the parameter by.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param + k``.

        Args:
            param (Any): Current parameter value.
            t (int): Current time step (not used in the computation).

        Returns:
            Any: The incremented parameter value.
        """
        # Build a new value rather than using ``+=`` so that a mutable
        # argument (e.g. a NumPy array) owned by the caller is never
        # modified in place.
        return param + self.k
class DecrementUpdate(base.UpdateFn):
    r"""Decrement the parameter by k.

    Overview:
        .. math::

            Y_t = Y_{t-1} - k

        where :math:`Y_t` is the parameter value at time step :math:`t` and
        :math:`k` is the amount to decrement the parameter by.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param - k``.

        Args:
            param (Any): Current parameter value.
            t (int): Current time step (not used in the computation).

        Returns:
            Any: The decremented parameter value.
        """
        # Build a new value rather than using ``-=`` so that a mutable
        # argument (e.g. a NumPy array) owned by the caller is never
        # modified in place.
        return param - self.k
class StepWiseUpdate(base.UpdateFn):
    r"""Update the parameter at specific time steps.

    Overview:
        This function updates the parameter to the next value in the
        `param_list` when called. If the `param_list` is empty, the parameter
        is not updated.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        param_list (list): A list of parameters to update to. The list is
            stored by reference and consumed from the front, so the caller's
            list shrinks as updates fire.
    """

    def __init__(self, scheduler: Type[base.Scheduler], param_list: list) -> None:
        super().__init__(scheduler)
        self.param_list = param_list

    def _update(self, param: Any, t: int) -> Any:
        """Return the next queued value, or ``param`` unchanged when none remain.

        Args:
            param (Any): Current parameter value.
            t (int): Current time step (not used in the computation).

        Returns:
            Any: The first remaining element of ``param_list`` (which is
            removed from the list), or ``param`` if the list is empty.
        """
        try:
            return self.param_list.pop(0)
        except IndexError:
            # No more parameters to update: keep the current value.
            # Only the "list exhausted" case is handled here; any other
            # error is allowed to propagate instead of being hidden.
            return param
class NoUpdate(base.UpdateFn):
    r"""Identity update: hand the parameter back untouched.

    Overview:
        Nothing is changed when this update function is called; it only
        provides the same interface as the other update functions, which
        makes it handy for testing and debugging.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
    """

    def __init__(self, scheduler: Type[base.Scheduler]) -> None:
        super().__init__(scheduler)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` as received; ``t`` is ignored."""
        return param
class OscillatingUpdate(base.UpdateFn):
    r"""Add a sine-wave term to the parameter.

    Overview:
        .. math::

            Y_t = Y_{t-1} + \delta * sin(t)

        where :math:`Y_t` is the parameter value at time step :math:`t` and
        :math:`\delta` is the amplitude of the sine wave.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        delta (float): The amplitude of the sine wave.
    """

    def __init__(self, scheduler: Type[base.Scheduler], delta: float) -> None:
        super().__init__(scheduler)
        self.delta = delta

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` shifted by ``delta * sin(t)``."""
        wave = self.delta * np.sin(t)
        shifted = param + wave
        return shifted
class ExponentialDecay(base.UpdateFn):
    r"""Scale the parameter by an exponential decay factor.

    Overview:
        .. math::

            Y_t = Y_{t-1} * exp(-\lambda * t)

        where :math:`Y_{t-1}` is the parameter value passed in, :math:`t` is
        the time step, and :math:`\lambda` is the rate of decay. When the
        value passed in is the initial value :math:`Y_0`, this is
        :math:`Y_0 * exp(-\lambda * t)`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        decay_rate (float): The rate of decay, i.e. :math:`\lambda`.
    """

    def __init__(self, scheduler: Type[base.Scheduler], decay_rate: float) -> None:
        super().__init__(scheduler)
        self.decay_rate = decay_rate

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` multiplied by ``exp(-decay_rate * t)``."""
        decay_factor = np.exp(-self.decay_rate * t)
        return param * decay_factor
class GeometricProgression(base.UpdateFn):
    r"""Apply a geometric progression to the parameter.

    Overview:
        .. math::

            Y_t = Y_{t-1} * r

        where :math:`Y_t` is the parameter value at time step :math:`t` and
        :math:`r` is the common ratio. Starting from an initial value
        :math:`Y_0`, the value after :math:`n` updates is :math:`Y_0 * r^n`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        r (float): The common ratio applied on every update.
    """

    def __init__(self, scheduler: Type[base.Scheduler], r: float) -> None:
        super().__init__(scheduler)
        self.r = r

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` multiplied by the common ratio ``r``.

        Args:
            param (Any): Current parameter value.
            t (int): Current time step (not used in the computation).

        Returns:
            Any: The updated parameter value.
        """
        return param * self.r
class OrnsteinUhlenbeck(base.UpdateFn):
    r"""Mean-reverting update following an Ornstein-Uhlenbeck style step.

    Overview:
        .. math::

            Y_t = Y_{t-1} + \theta (\mu - Y_{t-1}) + \sigma \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`,
        :math:`\theta` controls the speed of mean reversion, :math:`\mu` is
        the long-run mean, and :math:`\sigma` scales the noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        theta (float): Speed of mean reversion.
        mu (float): Long-run equilibrium value.
        sigma (float): Volatility / noise scale. Defaults to 0.0.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        theta: float,
        mu: float,
        sigma: float = 0.0,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: int) -> Any:
        """Pull ``param`` towards ``mu`` and add zero-mean Gaussian noise.

        ``t`` is accepted for interface compatibility and is not used.
        """
        if self.sigma > 0:
            shock = self.rng.normal(0, self.sigma)
        else:
            # With a non-positive sigma no sample is drawn, so the
            # generator state is left untouched.
            shock = 0.0
        reversion = self.theta * (self.mu - param)
        return param + reversion + shock
class SigmoidTransition(base.UpdateFn):
    r"""Logistic (S-curve) transition between two values.

    Overview:
        .. math::

            Y_t = a + \frac{b - a}{1 + \exp(-k (t - t_0))}

        The result depends only on :math:`t`; the incoming parameter value is
        discarded and replaced. The curve moves smoothly from :math:`a` to
        :math:`b` and is centred at :math:`t_0`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        a (float): Starting value (as :math:`t \to -\infty`).
        b (float): Ending value (as :math:`t \to +\infty`).
        k (float): Steepness of the transition. Larger = sharper.
        t0 (float): Midpoint time of the transition.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        a: float,
        b: float,
        k: float,
        t0: float,
    ) -> None:
        super().__init__(scheduler)
        self.a = a
        self.b = b
        self.k = k
        self.t0 = t0

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return the sigmoid value at time ``t``; ``param`` is not used."""
        exponent = -self.k * (t - self.t0)
        weight = 1.0 / (1.0 + np.exp(exponent))
        span = self.b - self.a
        return self.a + span * weight
class CyclicUpdate(base.UpdateFn):
    r"""Step through a list of values in order, restarting at the end.

    Overview:
        Every update sets the parameter to the next entry of ``value_list``.
        Once the last entry has been used, the position wraps back to 0.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        value_list (list): Values to cycle through.
    """

    def __init__(self, scheduler: Type[base.Scheduler], value_list: list) -> None:
        super().__init__(scheduler)
        self.value_list = value_list
        self._index = 0  # position of the value returned by the next update

    def _update(self, param: Any, t: int) -> Any:
        """Return the value at the current position and advance it.

        Neither ``param`` nor ``t`` is used.
        """
        position = self._index
        selected = self.value_list[position]
        # Advance with wrap-around so the position always stays in range.
        self._index = (position + 1) % len(self.value_list)
        return selected
class BoundedRandomWalk(base.UpdateFn):
    r"""Random walk whose result is clipped to ``[lo, hi]``.

    Overview:
        .. math::

            Y_t = \text{clip}(Y_{t-1} + \epsilon_t,\; lo,\; hi)

        where :math:`\epsilon_t \sim \mathcal{N}(\mu, \sigma^2)`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        mu (float): Mean of the noise.
        sigma (float): Standard deviation of the noise.
        lo (float): Lower bound.
        hi (float): Upper bound.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        mu: float,
        sigma: float,
        lo: float,
        hi: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.lo = lo
        self.hi = hi
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: int) -> Any:
        """Add one noise draw to ``param``, clip it, and return a ``float``.

        ``t`` is accepted for interface compatibility and is not used.
        """
        step = self.rng.normal(self.mu, self.sigma)
        clipped = np.clip(param + step, self.lo, self.hi)
        return float(clipped)
class PolynomialTrend(base.UpdateFn):
    r"""Add a polynomial-in-time term to the parameter.

    Overview:
        .. math::

            Y_t = Y_{t-1} + \sum_{i=1}^{n} a_i \, t^{i}

        where ``coeffs = [a_1, a_2, ..., a_n]``. With a single coefficient
        ``[slope]`` the added term is ``slope * t``, the same expression
        used by :class:`DeterministicTrend`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        coeffs (list[float]): Polynomial coefficients ``[a_1, a_2, ...]``
            for powers ``t, t^2, ...``.
    """

    def __init__(self, scheduler: Type[base.Scheduler], coeffs: list) -> None:
        super().__init__(scheduler)
        self.coeffs = coeffs

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return ``param`` plus the polynomial evaluated at ``t``."""
        # Powers start at 1: there is no constant (t^0) term.
        increment = sum(
            coeff * t ** power
            for power, coeff in enumerate(self.coeffs, start=1)
        )
        return param + increment
class LinearInterpolation(base.UpdateFn):
    r"""Move linearly from ``start_val`` to ``end_val`` over ``T`` steps.

    Overview:
        .. math::

            Y_t = start\_val + (end\_val - start\_val) \cdot \min\!\left(\frac{t}{T},\; 1\right)

        The incoming parameter value is discarded and replaced. Once
        ``t >= T`` the result stays at ``end_val``.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        start_val (float): Value at ``t = 0``.
        end_val (float): Value at ``t = T``.
        T (int): Number of steps over which to interpolate. Used as a
            divisor, so it must be non-zero.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        start_val: float,
        end_val: float,
        T: int,
    ) -> None:
        super().__init__(scheduler)
        self.start_val = start_val
        self.end_val = end_val
        self.T = T

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return the interpolated value at time ``t``; ``param`` is not used."""
        ratio = t / self.T
        # Cap the progress at 1 so the result never overshoots ``end_val``.
        progress = 1.0 if ratio > 1.0 else ratio
        span = self.end_val - self.start_val
        return self.start_val + span * progress
if __name__ == "__main__": import inspect # Run this file to automatically generate the __all__ variable. Copy and past the output bellow. public_api = [ name for name, obj in globals().items() if not name.startswith("_") and (inspect.isfunction(obj) or inspect.isclass(obj)) and obj.__module__ == __name__ ] print("__all__ = [") for name in sorted(public_api): print(f' "{name}",') print("]") __all__ = [ "BoundedRandomWalk", "CyclicUpdate", "DecrementUpdate", "DeterministicTrend", "ExponentialDecay", "GeometricProgression", "IncrementUpdate", "LinearInterpolation", "NoUpdate", "OrnsteinUhlenbeck", "OscillatingUpdate", "PolynomialTrend", "RandomWalk", "RandomWalkWithDrift", "RandomWalkWithDriftAndTrend", "SigmoidTransition", "StepWiseUpdate", ]