import ns_gym.base as base
import ns_gym.utils as utils
from typing import Any, Type, Union, Optional
import numpy as np
"""
These classes update the probability distributions represented as a lists.
"""
[docs]
class RandomCategorical(base.UpdateDistributionFn):
"""Update the distirbution as a random categorical distribution.
Args:
scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
seed (Optional[int], optional): Seed for the random number generator. Defaults to None.
Note:
This update function would return a new random categorical distribution.
The new categorical distribution is sampled from a Dirichlet distribution with all parameters equal to 1.
"""
def __init__(self, scheduler: base.Scheduler, seed: Optional[int] = None) -> None:
super().__init__(scheduler)
self.rng = np.random.default_rng(seed=seed)
def _update(self, param, t: int) -> Any:
"""Update the parameter by returning a new uniform random categorical distribution.
Args:
param (list): parameter to be updated
t (int): current time step
Returns:
Any: updated parameter
"""
return list(self.rng.dirichlet(np.ones(len(param))))
[docs]
class DistributionIncrementUpdate(base.UpdateDistributionFn):
"""Increment the the parameter by k.
Args:
scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
k (float): The amount which the parameter is updated.
Note:
This update function is useful for testing the robustness of the agent to changes in the environment.
If the parameter is a probability, k would update the probability of going in the intended direction.
Otherwise, k would be added to the parameter's value.
"""
def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
super().__init__(scheduler)
self.k = k
def __call__(self, param: list[float], t: int) -> Any:
return super().__call__(param, t)
def _update(self, param: list[float], t: int, **kwargs) -> Any:
"""Update the parameter by incrementing the intended direction by k.
"""
param[0] = min(1, param[0] + self.k)
for i in range(1, len(param)):
param[i] = (1 - param[0]) / (len(param) - 1)
return param
[docs]
class DistributionDecrementUpdate(base.UpdateDistributionFn):
"""Decrement the probability of going in the intended direction by some k.
Overview:
This function is used to decrement the probability distribution by some k. The probability distribution is represented as a list of probabilities. The intended direction is the first element in the probability distribution.
Args:
scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
k (float): The amount which the parameter is updated.
"""
def __init__(self, scheduler: base.Scheduler, k: float) -> None:
super().__init__(scheduler)
self.k = k
def __call__(self, param: list[float], t: int) -> Any:
return super().__call__(param, t)
def _update(self, param: list[float], t: int, **kwargs) -> Any:
"""Update the parameter by decrementing the intended direction by k.
Returns:
param (list): Updated probability distribution
"""
param[0] = max(0, param[0] - self.k)
for i in range(1, len(param)):
param[i] = (1 - param[0]) / (len(param) - 1)
return param
[docs]
class DistributionStepWiseUpdate(base.UpdateDistributionFn):
"""Update the parameter to values to a set of predefined values at specific time steps.
Args:
scheduler (base.Scheduler): scheduler that determines when the update function fires.
update_values (list): A list of values that the parameter is updated to at specific time steps.
"""
def __init__(self, scheduler: base.Scheduler, update_values: list) -> None:
super().__init__(scheduler)
self.update_values = update_values
def __call__(self, param: Any, t: int) -> Any:
return super().__call__(param, t)
def _update(self, param, t: int) -> Any:
"""
Args:
param (list): current parameter value
t (int): current time step
Returns:
list: updated parameter value
"""
try:
param = self.update_values.pop(0)
except AssertionError:
"No more parameters to update"
finally:
return param
[docs]
class LCBoundedDistrubutionUpdate(base.UpdateDistributionFn):
"""Decrement the parameters so that the change is Lipshitz continuous.
Overview:
This function would call the decrement update function and check if the change is Lipshitz continuous.
If not it would recall the decrement update function until the change is Lipshitz continuous.
The Lipshitz continuous constraint between to probability distributions is defined as:
.. math::
W_1(p_t(.|s,a),p_{t'}(.|s,a)) <= L * |t - t'|
Where :math:`W_1` is the Wasserstein distance between two probability distributions.
Args:
update_fn (Type[base.UpdateDistributionFn]): The update function that updates the parameter.
L (float): The Lipshitz constant.
Note:
This update function is an implementation of transition fucntion in Lecarpentier and Rechelson et al. 2019
"""
def __init__(self, scheduler, L: float, update_fn=None) -> None:
super().__init__(scheduler)
self.L = L
if update_fn is None:
self.update_fn = RandomCategorical(scheduler)
else:
assert issubclass(update_fn, base.UpdateDistributionFn), (
"update_fn must be a subclass of base.UpdateDistributionFn"
)
self.update_fn = update_fn(scheduler)
def _update(self, param: Any, t: int) -> Any:
max_trys = 1e5
count = 0
cur_dist = param
updated_dist = self.update_fn.update(param, t)
wass_dist = utils.wasserstein_distance(cur_dist, updated_dist)
delta_time = abs(t - self.prev_time)
d = self.L * delta_time
while wass_dist > d and count < max_trys:
updated_dist = self.update_fn.update(param, t)
wass_dist = utils.wasserstein_distance(cur_dist, updated_dist)
count += 1
if count >= max_trys:
raise ValueError("Could not find a Lipshitz continuous update")
else:
return updated_dist
[docs]
class BudgetBoundedIncrement(base.UpdateDistributionFn):
"""Increment the parameters so that the total amount of change is bounded by some budget.
Overview:
This function contrains the total amount of change in the parameter by some max budget.
This formulation is outlined in Cheung et al. 2020.
Args:
scheduler (base.Scheduler): scheduler that determines when the update function fires.
k (float): The amount which the parameter is updated.
B (Union[int,float]): The maximum total amount of change allowed in the parameter.
"""
def __init__(
self, scheduler: base.Scheduler, k: float, B: Union[int, float]
) -> None:
super().__init__(scheduler, k)
self.B = B
self.total_change = 0
def __call__(self, param: Any, t: int) -> Any:
curr_dist = param
updated_param, change = super().__call__(param, t)
amount_change = utils.wasserstein_distance(curr_dist, updated_param)
if self.total_change + amount_change <= self.B:
self.total_change += amount_change
return updated_param, change
else:
return curr_dist, False
[docs]
class DistributionNoUpdate(base.UpdateDistributionFn):
"""Does not update the parameter but return correct ns_bench interface.
Overview:
This function does not update the parameter.
"""
def __init__(self, scheduler: base.Scheduler) -> None:
super().__init__(scheduler)
def __call__(self, param: Any, t: int) -> Any:
return super().__call__(param, t)
def _update(self, param: Any, t: int) -> Any:
return param
[docs]
class TargetReversion(base.UpdateDistributionFn):
r"""Mean-revert toward a target distribution (OU analog for distributions).
Overview:
.. math::
p_t = p_{t-1} + \theta \, (\text{target} - p_{t-1})
which simplifies to :math:`(1 - \theta) p_{t-1} + \theta \, \text{target}`,
a convex combination that stays on the probability simplex for
:math:`\theta \in [0, 1]`.
Args:
scheduler (base.Scheduler): scheduler that determines when the update function fires.
target (list[float]): The target distribution to revert toward.
theta (float): Reversion speed. 0 = no change, 1 = jump to target.
"""
def __init__(
self, scheduler: base.Scheduler, target: list, theta: float
) -> None:
super().__init__(scheduler)
self.target = target
self.theta = theta
def _update(self, param: list[float], t: int) -> Any:
return [
p + self.theta * (tgt - p)
for p, tgt in zip(param, self.target)
]
[docs]
class DistributionLinearInterpolation(base.UpdateDistributionFn):
r"""Linearly interpolate between two distributions over ``T`` steps.
Overview:
.. math::
p_t = \text{start} + (\text{end} - \text{start}) \cdot \min\!\left(\frac{t}{T},\; 1\right)
The output replaces the current distribution entirely.
After ``t >= T`` the distribution is clamped at ``end_dist``.
Args:
scheduler (base.Scheduler): scheduler that determines when the update function fires.
start_dist (list[float]): Distribution at ``t = 0``.
end_dist (list[float]): Distribution at ``t = T``.
T (int): Number of steps over which to interpolate.
"""
def __init__(
self,
scheduler: base.Scheduler,
start_dist: list,
end_dist: list,
T: int,
) -> None:
super().__init__(scheduler)
self.start_dist = start_dist
self.end_dist = end_dist
self.T = T
def _update(self, param: list[float], t: int) -> Any:
frac = min(t / self.T, 1.0)
return [
s + (e - s) * frac
for s, e in zip(self.start_dist, self.end_dist)
]
[docs]
class DistributionCyclicUpdate(base.UpdateDistributionFn):
r"""Cycle through a list of distributions, wrapping around when exhausted.
Overview:
Each time the update fires, the distribution is set to the next entry
in ``dist_list``. After reaching the end, the index wraps back to 0.
Args:
scheduler (base.Scheduler): scheduler that determines when the update function fires.
dist_list (list[list[float]]): Distributions to cycle through.
"""
def __init__(
self, scheduler: base.Scheduler, dist_list: list
) -> None:
super().__init__(scheduler)
self.dist_list = dist_list
self._index = 0
def _update(self, param: list[float], t: int) -> Any:
val = self.dist_list[self._index]
self._index = (self._index + 1) % len(self.dist_list)
return val
if __name__ == "__main__":
import inspect
# Run this file to automatically generate the __all__ variable. Copy and past the output bellow.
public_api = [
name
for name, obj in globals().items()
if not name.startswith("_")
and (inspect.isfunction(obj) or inspect.isclass(obj))
and obj.__module__ == __name__
]
print("__all__ = [")
for name in sorted(public_api):
print(f' "{name}",')
print("]")
__all__ = [
"BudgetBoundedIncrement",
"DistributionCyclicUpdate",
"DistributionDecrementUpdate",
"DistributionIncrementUpdate",
"DistributionLinearInterpolation",
"DistributionNoUpdate",
"DistributionStepWiseUpdate",
"LCBoundedDistrubutionUpdate",
"RandomCategorical",
"TargetReversion",
"UniformDrift",
]