# Source code for ns_gym.update_functions.single_param

import ns_gym.base as base
from typing import Union, Any, Type
import numpy as np

"""
This is a collection of parameter update functions that can be used in the ns_bench framework.
Each class takes in a single (scalar) parameter and returns a new parameter.
The update functions only "fire" when the scheduler returns True.

To maintain a consistent interface, the update functions can be implemented as a derived class of base.UpdateFn.

But really the only main requirement is that the update function takes in a parameter and a time step in the
__call__() method and returns a new parameter and a boolean value indicating whether the update function fired or not.

"""

##### Update Functions for single parameters (scalars) #####


class DeterministicTrend(base.UpdateFn):
    r"""Update the parameter with a deterministic trend.

    Overview:

        .. math::

            Y_t = Y_{t-1} + slope * t

        where :math:`Y_t` is the parameter value at time step :math:`t` and slope is the slope of the trend.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        slope (float): The slope of the trend.
    """

    def __init__(self, scheduler: Type[base.Scheduler], slope: float) -> None:
        super().__init__(scheduler)
        self.slope = slope

    def _update(self, param: float, t: float) -> float:
        """Return ``param`` shifted by ``slope * t``.

        Args:
            param (float): The current parameter value.
            t (float): The current time step.

        Returns:
            float: The updated parameter value only; no "fired" flag is
            produced by this method.
        """
        return param + self.slope * t
class RandomWalkWithDriftAndTrend(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise and a deterministic trend.

    Overview:

        .. math::

            Y_t = \alpha + Y_{t-1} + \text{slope} * t + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`, :math:`\alpha` is the drift term,
        :math:`\text{slope}` is the slope of the trend, and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        slope (float): The slope of the trend.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        slope: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        self.slope = slope
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: float, t: float) -> float:
        """Return ``alpha + param + noise + slope * t``.

        Args:
            param (float): The current parameter value.
            t (float): The current time step.

        Returns:
            float: The updated parameter value.
        """
        # Draw a length-1 sample and unwrap it so the noise is a scalar.
        # Adding the raw shape-(1,) array would turn a scalar parameter
        # into a one-element ndarray.
        white_noise = self.rng.normal(self.mu, self.sigma, 1)[0]
        return self.alpha + param + white_noise + self.slope * t
class RandomWalk(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise.

    Overview:
        A pure random walk : :math:`Y_t = Y_{t-1} + \epsilon_t` where :math:`Y_t` is the parameter value
        at time step :math:`t` and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        mu (Union[float,int], optional): The mean of the white noise. Defaults to 0.
        sigma (Union[float,int], optional): The standard deviation of the white noise. Defaults to 1.
        seed (Union[int,None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        mu: Union[float, int] = 0,
        sigma: Union[float, int] = 1,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return ``param`` plus one draw of white noise.

        Args:
            param (Any): The current parameter value.
            t (Union[int, float]): The current time step (not used here).

        Returns:
            Any: The first element of ``param + noise``, where ``noise`` is a
            length-1 sample from ``N(mu, sigma)``.
        """
        white_noise = self.rng.normal(self.mu, self.sigma, 1)
        updated_param = param + white_noise
        # The sum is an array because the noise has shape (1,); unwrap it.
        return updated_param[0]
class RandomWalkWithDrift(base.UpdateFn):
    r"""A parameter update function that updates the parameter with white noise and a drift term.

    Overview:

        .. math::

            Y_t = \alpha + Y_{t-1} + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`, :math:`\alpha` is the drift term,
        and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``alpha + param + noise``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used here).

        Returns:
            Any: The updated parameter value.
        """
        # Unwrap the length-1 sample so a scalar parameter stays a scalar
        # instead of becoming a one-element ndarray.
        white_noise = self.rng.normal(self.mu, self.sigma, 1)[0]
        updated_param = self.alpha + param + white_noise
        return updated_param
class IncrementUpdate(base.UpdateFn):
    r"""Increment the parameter by k.

    Overview:

        .. math::

            Y_t = Y_{t-1} + k

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`k` is the amount to
        increment the parameter by.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param + k``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used here).

        Returns:
            Any: The incremented parameter value.
        """
        # Build a new value rather than using ``+=``, which would modify a
        # mutable argument (e.g. an ndarray) in place for the caller.
        return param + self.k
class DecrementUpdate(base.UpdateFn):
    r"""Decrement the parameter by k (e.g. the probability of going in the intended direction).

    Overview:

        .. math::

            Y_t = Y_{t-1} - k

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`k` is the amount to
        decrement the parameter by.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param - k``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used here).

        Returns:
            Any: The decremented parameter value.
        """
        # Build a new value rather than using ``-=``, which would modify a
        # mutable argument (e.g. an ndarray) in place for the caller.
        return param - self.k
class StepWiseUpdate(base.UpdateFn):
    r"""Update the parameter at specific time steps.

    Overview:
        This function updates the parameter to the next value in the `param_list` when called.
        If the `param_list` is empty, the parameter is not updated.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        param_list (list): A list of parameters to update to.
    """

    def __init__(self, scheduler: Type[base.Scheduler], param_list: list) -> None:
        super().__init__(scheduler)
        # The list is stored by reference and consumed from the front by
        # ``_update``; the caller's list shrinks as values are used.
        self.param_list = param_list

    def _update(self, param: Any, t: int) -> Any:
        """Return the next queued value, or ``param`` when none remain.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used here).

        Returns:
            Any: The first element removed from ``param_list``, or ``param``
            unchanged if ``param_list`` is empty.
        """
        try:
            return self.param_list.pop(0)
        except IndexError:
            # No more parameters to update: keep the current value.
            return param
class NoUpdate(base.UpdateFn):
    r"""Leave the parameter untouched while still exposing the update-function interface.

    Overview:
        Calling this update function never changes the parameter. It is useful for testing and debugging.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
    """

    def __init__(self, scheduler: Type[base.Scheduler]) -> None:
        super().__init__(scheduler)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` exactly as received; ``t`` is ignored."""
        return param
class OscillatingUpdate(base.UpdateFn):
    r"""Shift the parameter by a sine-wave term.

    Overview:

        .. math::

            Y_t = Y_{t-1} + \delta * sin(t)

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`\delta` is the
        amplitude of the sine wave.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        delta (float): The amplitude of the sine wave.
    """

    def __init__(self, scheduler: Type[base.Scheduler], delta: float) -> None:
        super().__init__(scheduler)
        self.delta = delta

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param + delta * sin(t)``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step, passed directly to ``np.sin``.

        Returns:
            Any: The shifted parameter value.
        """
        return param + self.delta * np.sin(t)
class ExponentialDecay(base.UpdateFn):
    r"""Exponential decay of the parameter.

    Overview:

        .. math::

            Y_t = Y * exp(-\lambda * t)

        where :math:`Y` is the parameter value passed to the update, :math:`t` is the time step,
        and :math:`\lambda` is the rate of decay.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        decay_rate (float): The rate of decay. i.e. :math:`\lambda`
    """

    def __init__(self, scheduler: Type[base.Scheduler], decay_rate: float) -> None:
        super().__init__(scheduler)
        self.decay_rate = decay_rate

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` scaled by ``exp(-decay_rate * t)``.

        Args:
            param (Any): The parameter value to scale.
            t (int): The current time step.

        Returns:
            Any: The scaled parameter value.
        """
        # NOTE(review): the factor is applied to whatever value is passed in.
        # If the caller feeds back the previous output rather than the
        # initial value Y_0, the decay compounds across calls - confirm
        # against the caller.
        decay_factor = np.exp(-self.decay_rate * t)
        return param * decay_factor
class GeometricProgression(base.UpdateFn):
    r"""Apply a geometric progression to the parameter.

    Overview:

        .. math::

            Y_t = Y_{t-1} * r

        where :math:`Y_t` is the parameter value after the update, :math:`Y_{t-1}` is the value passed in,
        and :math:`r` is the common ratio. If every result is fed back into the next update, the value
        after :math:`n` updates is :math:`Y_0 * r^n`, where :math:`Y_0` is the initial parameter value.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        r (float): The common ratio the parameter is multiplied by on each update.
    """

    def __init__(self, scheduler: Type[base.Scheduler], r: float) -> None:
        super().__init__(scheduler)
        self.r = r

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param * r``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used here).

        Returns:
            Any: The parameter multiplied by the common ratio.
        """
        return param * self.r
if __name__ == "__main__":
    import inspect

    # Run this file to automatically generate the __all__ variable.
    # Copy and paste the output below.
    public_api = [
        name
        for name, obj in globals().items()
        if not name.startswith("_")
        and (inspect.isfunction(obj) or inspect.isclass(obj))
        # Keep only objects defined in this module, not imported ones.
        and obj.__module__ == __name__
    ]

    print("__all__ = [")
    for name in sorted(public_api):
        print(f' "{name}",')
    print("]")

# Public API of this module. "DecrementUpdate" is defined in this module and
# is picked up by the generator above, so it is listed with the others.
__all__ = [
    "DecrementUpdate",
    "DeterministicTrend",
    "ExponentialDecay",
    "GeometricProgression",
    "IncrementUpdate",
    "NoUpdate",
    "OscillatingUpdate",
    "RandomWalk",
    "RandomWalkWithDrift",
    "RandomWalkWithDriftAndTrend",
    "StepWiseUpdate",
]