import ns_gym.base as base
from typing import Union, Any, Type
import numpy as np
"""
These are a collection of parameter update functions that can be used in the ns_bench framework.
These classes take in a single parameter (scalars) and return a new parameter. 
The update functions only "fire" when the scheduler returns True.
To maintain a consistent interface, the update functions can be implemented as a derived class of base.UpdateFn.
But really the only main requirement is that the update function takes in a parameter and a time step in the
__call__() method and returns a new parameter and a boolean value indicating whether the update function fired or not.
"""
##### Update Functions for single parameters (scalars) #####
class DeterministicTrend(base.UpdateFn):
    r"""Update the parameter with a deterministic trend.

    Overview:
        .. math::
            Y_t = Y_{t-1} + slope * t

        where :math:`Y_t` is the parameter value at time step :math:`t` and slope is the slope of the trend.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        slope (float): The slope of the trend.
    """

    def __init__(self, scheduler: Type[base.Scheduler], slope: float) -> None:
        super().__init__(scheduler)
        self.slope = slope

    def _update(self, param: float, t: float) -> float:
        """Return ``param`` shifted by ``slope * t``.

        Args:
            param (float): The current parameter value.
            t (float): The current time step.

        Returns:
            float: The updated parameter value. Only the value is returned
            here, not a ``(value, fired)`` pair.
        """
        return param + self.slope * t
class RandomWalkWithDriftAndTrend(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise and a deterministic trend.

    Overview:
        .. math::
            Y_t = \alpha + Y_{t-1} + \text{slope} * t + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`, :math:`\alpha` is the drift term, :math:`\text{slope}` is the slope of the trend, and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        slope (float): The slope of the trend.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        slope: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        self.slope = slope
        # Private generator so that a given seed reproduces the same noise sequence.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: float, t: float) -> float:
        """Return ``param`` plus drift, trend and one draw of Gaussian noise.

        Args:
            param (float): The current parameter value.
            t (float): The current time step.

        Returns:
            float: The updated parameter value. A scalar ``param`` yields a
            scalar result.
        """
        # The draw is still requested with size 1 so the generator is consumed
        # exactly as before; taking element 0 stops the shape-(1,) noise array
        # from promoting a scalar parameter to a one-element ndarray.
        white_noise = self.rng.normal(self.mu, self.sigma, 1)[0]
        return self.alpha + param + white_noise + self.slope * t
class RandomWalk(base.UpdateFn):
    r"""Parameter update function that updates the parameter with white noise.

    Overview:
        A pure random walk : :math:`Y_t = Y_{t-1} + \epsilon_t` where :math:`Y_t` is the parameter value at time step :math:`t`
        and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        mu (Union[float,int], optional): The mean of the white noise. Defaults to 0.
        sigma (Union[float,int], optional): The standard deviation of the white noise. Defaults to 1.
        seed (Union[int,None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        mu: Union[float, int] = 0,
        sigma: Union[float, int] = 1,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        # Private generator so that a given seed reproduces the same noise sequence.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: Union[int, float]) -> Any:
        """Return ``param`` perturbed by one draw of Gaussian noise.

        Args:
            param (Any): The current parameter value.
            t (Union[int, float]): The current time step (not used by this update).

        Returns:
            Any: Element 0 of ``param + noise``, where ``noise`` is a
            one-element array; for a scalar ``param`` this is a scalar.
        """
        white_noise = self.rng.normal(self.mu, self.sigma, 1)
        updated_param = param + white_noise
        # Adding the size-1 noise array produces an array; unwrap its first element.
        return updated_param[0]
class RandomWalkWithDrift(base.UpdateFn):
    r"""A parameter update function that updates the parameter with white noise and a drift term.

    Overview:
        .. math::
            Y_t = \alpha + Y_{t-1} + \epsilon_t

        where :math:`Y_t` is the parameter value at time step :math:`t`, :math:`\alpha` is the drift term, and :math:`\epsilon` is white noise.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        alpha (float): The drift term.
        mu (float): The mean of the white noise.
        sigma (float): The standard deviation of the white noise.
        seed (Union[int, None], optional): Seed for the random number generator. Defaults to None.
    """

    def __init__(
        self,
        scheduler: Type[base.Scheduler],
        alpha: float,
        mu: float,
        sigma: float,
        seed: Union[int, None] = None,
    ) -> None:
        super().__init__(scheduler)
        self.mu = mu
        self.sigma = sigma
        self.alpha = alpha
        # Private generator so that a given seed reproduces the same noise sequence.
        self.rng = np.random.default_rng(seed=seed)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` plus the drift term and one draw of Gaussian noise.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used by this update).

        Returns:
            Any: The updated parameter value. A scalar ``param`` yields a
            scalar result.
        """
        # The draw is still requested with size 1 so the generator is consumed
        # exactly as before; taking element 0 stops the shape-(1,) noise array
        # from promoting a scalar parameter to a one-element ndarray.
        white_noise = self.rng.normal(self.mu, self.sigma, 1)[0]
        return self.alpha + param + white_noise
class IncrementUpdate(base.UpdateFn):
    r"""Increment the parameter by k.

    Overview:
        .. math::
            Y_t = Y_{t-1} + k

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`k` is the amount to increment the parameter by.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param + k``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used by this update).

        Returns:
            Any: A new value equal to ``param`` increased by ``k``.
        """
        # Build a new value rather than using ``+=``: the augmented form
        # modifies mutable arguments such as NumPy arrays in place, which would
        # silently change the object the caller still holds.
        return param + self.k
class DecrementUpdate(base.UpdateFn):
    r"""Decrement the parameter by k.

    Overview:
        .. math::
            Y_t = Y_{t-1} - k

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`k` is the amount to decrement the parameter by.
        One use is decrementing the probability of going in the intended direction by some k.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        k (float): The amount which the parameter is updated.
    """

    def __init__(self, scheduler: Type[base.Scheduler], k: float) -> None:
        super().__init__(scheduler)
        self.k = k

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param - k``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used by this update).

        Returns:
            Any: A new value equal to ``param`` decreased by ``k``.
        """
        # Build a new value rather than using ``-=``: the augmented form
        # modifies mutable arguments such as NumPy arrays in place, which would
        # silently change the object the caller still holds.
        return param - self.k
class StepWiseUpdate(base.UpdateFn):
    r"""Update the parameter at specific time steps.

    Overview:
        This function updates the parameter to the next value in the `param_list` when called. If the `param_list` is empty, the parameter is not updated.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        param_list (list): A list of parameters to update to, consumed from the front.
    """

    def __init__(self, scheduler: Type[base.Scheduler], param_list: list) -> None:
        super().__init__(scheduler)
        # Keep a private copy: values are popped as they are used, and draining
        # the caller's own list would break any later reuse of it. ``None`` is
        # treated as "no values", which matches the previous no-op behaviour.
        self.param_list = [] if param_list is None else list(param_list)

    def _update(self, param: Any, t: int) -> Any:
        """Return the next queued value, or ``param`` unchanged when none are left.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used by this update).

        Returns:
            Any: The first remaining entry of ``param_list`` (which is removed
            from the list), or ``param`` itself once the list is exhausted.
        """
        # An explicit emptiness check replaces the old ``try``/``finally: return``
        # construct, which caught the wrong exception type and only worked
        # because ``return`` inside ``finally`` discards any in-flight exception.
        if not self.param_list:
            return param
        return self.param_list.pop(0)
class NoUpdate(base.UpdateFn):
    r"""Identity update: hand the parameter back untouched while keeping the standard interface.

    Overview:
        Calling this update leaves the parameter exactly as it was. It is handy for testing and debugging.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
    """

    def __init__(self, scheduler: Type[base.Scheduler]) -> None:
        super().__init__(scheduler)

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` as received; the time step ``t`` is ignored."""
        return param
class OscillatingUpdate(base.UpdateFn):
    r"""Add a sinusoidal term to the parameter.

    Overview:
        .. math::
            Y_t = Y_{t-1} + \delta * sin(t)

        with :math:`Y_t` the parameter value at time step :math:`t` and :math:`\delta` the amplitude of the sine wave.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        delta (float): The amplitude of the sine wave.
    """

    def __init__(self, scheduler: Type[base.Scheduler], delta: float) -> None:
        super().__init__(scheduler)
        self.delta = delta

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` offset by ``delta * sin(t)``."""
        return param + self.delta * np.sin(t)
class ExponentialDecay(base.UpdateFn):
    r"""Exponential decay of the parameter.

    Overview:
        .. math::
            Y_t = Y_{t-1} * exp(-\lambda * t)

        where :math:`Y_t` is the parameter value at time step :math:`t`, :math:`Y_{t-1}` is the parameter value passed in, and :math:`\lambda` is the rate of decay.

        Note that the decay factor is applied to the *current* value, so repeated updates compound.
        The closed form :math:`Y_t = Y_0 * exp(-\lambda * t)` only holds for a single update applied to the initial value :math:`Y_0`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        decay_rate (float): The rate of decay. i.e. :math:`\lambda`
    """

    def __init__(self, scheduler: Type[base.Scheduler], decay_rate: float) -> None:
        super().__init__(scheduler)
        self.decay_rate = decay_rate

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` scaled by ``exp(-decay_rate * t)``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step.

        Returns:
            Any: The scaled parameter value.
        """
        decay_factor = np.exp(-self.decay_rate * t)
        return param * decay_factor
class GeometricProgression(base.UpdateFn):
    r"""Apply a geometric progression to the parameter.

    Overview:
        .. math::
            Y_t = Y_{t-1} * r

        where :math:`Y_t` is the parameter value at time step :math:`t` and :math:`r` is the common ratio.
        After :math:`n` updates starting from an initial value :math:`Y_0` this gives :math:`Y_0 * r^n`.

    Args:
        scheduler (Type[base.Scheduler]): scheduler that determines when the update function fires.
        r (float): The common ratio the parameter is multiplied by on each update.
    """

    def __init__(self, scheduler: Type[base.Scheduler], r: float) -> None:
        super().__init__(scheduler)
        self.r = r

    def _update(self, param: Any, t: int) -> Any:
        """Return ``param`` multiplied by the common ratio ``r``.

        Args:
            param (Any): The current parameter value.
            t (int): The current time step (not used by this update).

        Returns:
            Any: The scaled parameter value.
        """
        return param * self.r
if __name__ == "__main__":
    import inspect

    # Running this file as a script prints a ready-made ``__all__`` listing.
    # Copy and paste the output below.
    # A symbol is listed when its name is public and it is a class or function
    # whose ``__module__`` is this module (imported objects are skipped).
    exported = sorted(
        symbol
        for symbol, value in globals().items()
        if not symbol.startswith("_")
        and (inspect.isclass(value) or inspect.isfunction(value))
        and value.__module__ == __name__
    )
    print("__all__ = [")
    for symbol in exported:
        print(f'    "{symbol}",')
    print("]")
# Public API of this module, in alphabetical order. Every update-function class
# defined in this file is listed, including ``DecrementUpdate``.
__all__ = [
    "DecrementUpdate",
    "DeterministicTrend",
    "ExponentialDecay",
    "GeometricProgression",
    "IncrementUpdate",
    "NoUpdate",
    "OscillatingUpdate",
    "RandomWalk",
    "RandomWalkWithDrift",
    "RandomWalkWithDriftAndTrend",
    "StepWiseUpdate",
]