from typing import Any, Union
import gymnasium as gym
from copy import deepcopy
import warnings
import ns_gym.base as base
class ConstraintViolationWarning(Warning):
    """Warning category for a proposed parameter value that breaks a physical constraint.

    Issued by the classic-control wrapper's constraint checker when an update
    function proposes an invalid value; the affected parameter is left unchanged.
    """
class NSClassicControlWrapper(base.NSWrapper):
    """A non-stationary wrapper for Gymnasium's Classic Control environments.

    On every (non-simulation) step, each tunable parameter of the wrapped
    environment is passed through its update function. A proposed value is
    applied only if it satisfies the environment's physical constraints (see
    ``_constraint_checker``); derived parameters are then recomputed (see
    ``_dependency_resolver``) before the underlying environment is stepped.

    Args:
        env (gym.Env): Base gym environment.
        tunable_params (dict[str,base.UpdateFn]): Dictionary of parameter names and their associated update functions.
        change_notification (bool, optional): Flag to indicate whether to notify the agent of changes in the environment. Defaults to False.
        delta_change_notification (bool, optional): Flag to indicate whether to notify the agent of changes in the transition function. Defaults to False.
        in_sim_change (bool, optional): Flag to allow environmental changes to occur in the 'planning' environment. Defaults to False.
    """

    # Lower-bound constraints on tunable parameters:
    #   env class name -> parameter name -> (strict, warning message)
    # strict=True rejects values <= 0; strict=False rejects values < 0.
    _LOWER_BOUNDS = {
        "CartPoleEnv": {
            "length": (True, "Length of the pole cannot be negative, length not updated."),
            "masscart": (True, "Mass of the cart must be greater than zero, cart mass not updated"),
            "masspole": (True, "Mass of the pole must be greater than zero, pole mass not updated"),
            "gravity": (False, "Gravity cannot be negative, gravity not updated"),
        },
        "AcrobotEnv": {
            "LINK_LENGTH_1": (True, "Length of link must be greater than zero, parameter not updated"),
            "LINK_LENGTH_2": (True, "Length of link must be greater than zero, parameter not updated"),
            "LINK_MASS_1": (True, "Mass of link 1 must be greater than zero, parameter not updated"),
            "LINK_MASS_2": (True, "Mass of link 2 must be greater than zero, parameter not updated"),
            "LINK_COM_POS_1": (True, "Center of mass of link 1 must be greater than zero, parameter not updated"),
            "LINK_COM_POS_2": (True, "Center of mass of link 2 must be greater than zero, parameter not updated"),
        },
        "MountainCarEnv": {
            "gravity": (True, "Gravity must be greater than zero, parameter not updated"),
            "force": (True, "Force must be greater than zero, parameter not updated"),
        },
        "Continuous_MountainCarEnv": {
            "power": (True, "Power must be greater than zero, parameter not updated"),
        },
        "PendulumEnv": {
            "m": (True, "Mass of the pendulum must be greater than zero, parameter not updated"),
            "l": (True, "Length of the pendulum must be greater than zero, parameter not updated"),
            "g": (False, "Gravity must be greater than or equal to zero, parameter not updated"),
            "dt": (True, "Time step must be greater than zero, parameter not updated"),
        },
    }

    # Acrobot (link length, centre-of-mass position, link label) triples: the
    # centre of mass of a link may not lie beyond the end of that link.
    _ACROBOT_LINKS = (
        ("LINK_LENGTH_1", "LINK_COM_POS_1", "1"),
        ("LINK_LENGTH_2", "LINK_COM_POS_2", "2"),
    )

    def __init__(
        self,
        env,
        tunable_params,
        change_notification: bool = False,
        delta_change_notification: bool = False,
        in_sim_change: bool = False,
        **kwargs: Any,
    ):
        env_name = env.unwrapped.__class__.__name__
        assert env_name in base.TUNABLE_PARAMS, (
            f"{env_name} is not a supported environment"
        )
        super().__init__(
            env=env,
            tunable_params=tunable_params,
            change_notification=change_notification,
            delta_change_notification=delta_change_notification,
            in_sim_change=in_sim_change,
            **kwargs,
        )
        self.t = 0
        self.delta_t = 1
        # Snapshot of every tunable parameter's value at wrap time; restored on reset().
        self.initial_params = {}
        supported = base.TUNABLE_PARAMS[env_name]
        for key in tunable_params:
            assert key in supported, (
                f"{key} is not a tunable parameter for {env_name}"
            )
            self.initial_params[key] = deepcopy(getattr(self.unwrapped, key))

    def step(self, action: Union[float, int]):
        """Step through environment and update environmental parameters.

        A simulation copy with ``in_sim_change`` disabled steps with frozen
        parameters. Otherwise every tunable parameter is updated, subject to
        ``_constraint_checker``; rejected updates are reported as no change
        (flag 0, delta 0.0).

        Args:
            action (Union[float,int]): Action to take in environment
        Returns:
            tuple[dict[str, Any], base.Reward, bool, bool, dict[str, Any]]: NS-Gym Observation dictionary, reward, done flag, truncated flag, info dictionary
        """
        if self.is_sim_env and not self.in_sim_change:
            obs, reward, terminated, truncated, info = super().step(
                action, env_change=None, delta_change=None
            )
        else:
            env_change = {}
            delta_change = {}
            new_vals = {}
            for p, fn in self.tunable_params.items():
                new_val, change_flag, delta = fn(getattr(self.unwrapped, p), self.t)
                new_vals[p] = new_val
                env_change[p] = change_flag
                delta_change[p] = delta
            for k, violated in self._constraint_checker(new_vals).items():
                if violated:
                    # Parameter keeps its old value, so report it as unchanged.
                    delta_change[k] = 0.0
                    env_change[k] = 0
                else:
                    setattr(self.unwrapped, k, new_vals[k])
            self._dependency_resolver()
            obs, reward, terminated, truncated, info = super().step(
                action, env_change=env_change, delta_change=delta_change
            )
        # Both branches step the same deterministic classic-control dynamics.
        info["prob"] = 1.0
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the environment and restore every tunable parameter to its initial value.

        Derived parameters are recomputed afterwards so they match the restored values.
        """
        obs, info = super().reset(seed=seed, options=options)
        for k, v in self.initial_params.items():
            setattr(self.unwrapped, k, deepcopy(v))
        self._dependency_resolver()
        return obs, info

    def close(self):
        return super().close()

    def __str__(self):
        return super().__str__()

    def __repr__(self):
        return super().__repr__()

    def get_planning_env(self):
        """Return a copy of the environment.

        NOTE:
            - If the environment is a simulation environment, the function returns a deepcopy of the simulation environment.
            - If change notification is enabled, the function returns a deepcopy of the current environment because the decision making agent needs to be aware of the changes in the environment.
            - If change notification is disabled, the function returns a deepcopy of the environment with the initial parameters (derived parameters recomputed to match).
        """
        assert self.has_reset, (
            "The environment must be reset before getting the planning environment."
        )
        planning_env = deepcopy(self)
        if self.is_sim_env or self.change_notification:
            return planning_env
        for k, v in self.initial_params.items():
            setattr(planning_env.unwrapped, k, deepcopy(v))
        # Keep derived values (e.g. CartPole's total_mass) consistent with the restored ones.
        planning_env._dependency_resolver()
        return planning_env

    def __deepcopy__(self, memo):
        """Build an independent simulation copy of this wrapper.

        A fresh environment is created from the wrapped env's spec (same id and
        kwargs), wrapped with deep copies of the update functions and the same
        flags, reset, and then synchronised with this wrapper's physical state,
        time step, initial parameters and current parameter values. The copy is
        flagged as a simulation environment (``is_sim_env = True``).
        """
        env_name = self.unwrapped.__class__.__name__
        if env_name in ("MountainCarEnv", "Continuous_MountainCarEnv"):
            warnings.warn(
                f"Deepcopy for {env_name} has a known "
                "issue with state divergence in the test suite. While parameter "
                "updates and notifications function correctly, be cautious if "
                "using `get_planning_env()` for simulation, as the test assertion "
                "`assert not np.array_equal(sim_obs['state'], obs['state'])` fails.",
                UserWarning,
            )
        spec = self.unwrapped.spec
        sim_env = NSClassicControlWrapper(
            gym.make(spec.id, **spec.kwargs),
            deepcopy(self.tunable_params),
            self.change_notification,
            self.delta_change_notification,
            self.in_sim_change,
        )
        # Record the copy so an enclosing deepcopy that reaches this wrapper again reuses it.
        memo[id(self)] = sim_env
        sim_env.reset()
        sim_env.unwrapped.state = deepcopy(self.unwrapped.state)
        sim_env.t = deepcopy(self.t)
        # Future resets of the copy must restore this wrapper's initial values,
        # not the defaults of the freshly created environment.
        sim_env.initial_params = deepcopy(self.initial_params)
        for k in self.tunable_params:
            setattr(sim_env.unwrapped, k, deepcopy(getattr(self.unwrapped, k)))
        sim_env._dependency_resolver()
        sim_env.is_sim_env = True
        return sim_env

    def get_default_params(self):
        """Get dictionary of default parameters and their initial values"""
        return super().get_default_params()

    def _constraint_checker(self, new_vals) -> dict[str, bool]:
        """Check proposed parameter values against the environment's physical constraints.

        Every rejected value triggers a ``ConstraintViolationWarning``; ``step``
        leaves rejected parameters unchanged.

        Args:
            new_vals (dict[str,float]): Proposed new value of each parameter being updated.
        Returns:
            constraint_dict (dict[str,bool]): One entry per key of ``new_vals``. The entry is True
            if the proposed value violates a constraint and must not be applied, and False
            otherwise. Parameters without a registered constraint are always False.
        Note:
            Constraints live in the per-environment ``_LOWER_BOUNDS`` table, plus the
            Acrobot link-geometry check.
            - TODO: relook at constraints to see if they make sense (no division by zero, no negative values, etc.).
            - TODO: should we store the previous values of the parameters?
        """
        env_name = self.unwrapped.__class__.__name__
        bounds = self._LOWER_BOUNDS.get(env_name, {})
        # Every key gets an entry, so step() still applies parameters that have no
        # registered constraint (or belong to an env without a rule table).
        constraint_dict: dict[str, bool] = {p: False for p in new_vals}
        for p, v in new_vals.items():
            rule = bounds.get(p)
            if rule is None:
                continue
            strict, message = rule
            if (v <= 0) if strict else (v < 0):
                warnings.warn(message, ConstraintViolationWarning)
                constraint_dict[p] = True
        if env_name == "AcrobotEnv":
            self._check_acrobot_link_geometry(new_vals, constraint_dict)
        return constraint_dict

    def _check_acrobot_link_geometry(self, new_vals, constraint_dict) -> None:
        """Reject Acrobot updates that would put a link's centre of mass beyond its end.

        The length and the centre-of-mass position of each link are compared
        using the values they would hold after this step. That is the proposed
        value if it is being updated and passed the lower-bound check, and the
        current value otherwise. If the pair would be inconsistent, each
        still-accepted update of that pair is rejected, so an initially
        consistent pair stays consistent. ``constraint_dict`` is updated in place.
        """
        for length_key, com_key, link in self._ACROBOT_LINKS:
            length_accepted = length_key in new_vals and not constraint_dict[length_key]
            com_accepted = com_key in new_vals and not constraint_dict[com_key]
            length_val = (
                new_vals[length_key] if length_accepted else getattr(self.unwrapped, length_key)
            )
            com_val = new_vals[com_key] if com_accepted else getattr(self.unwrapped, com_key)
            if com_val <= length_val:
                continue
            if length_accepted:
                warnings.warn(
                    "Length of link must be greater than the position of its center "
                    f"of mass, link {link} length parameter not updated",
                    ConstraintViolationWarning,
                )
                constraint_dict[length_key] = True
            if com_accepted:
                warnings.warn(
                    f"Center of mass of link {link} must be less than the length of "
                    "the link, parameter not updated",
                    ConstraintViolationWarning,
                )
                constraint_dict[com_key] = True

    def _dependency_resolver(self):
        """Recompute parameters that are derived from tunable ones.

        Only CartPole has derived values here: ``total_mass`` and ``polemass_length``.
        AcrobotEnv, MountainCarEnv, Continuous_MountainCarEnv and PendulumEnv need nothing.
        """
        if self.unwrapped.__class__.__name__ == "CartPoleEnv":
            core = self.unwrapped
            core.total_mass = core.masspole + core.masscart
            core.polemass_length = core.length * core.masspole
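# Demo: wrap CartPole with a random-walk update of `force_mag`, deep-copy it into a
# simulation env, step both with the same actions and print each env's `force_mag`.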
if __name__ == "__main__":
    import ns_gym.base as base
    import ns_gym.update_functions as update_functions
    import ns_gym.schedulers as schedulers
    import copy

    # Random-walk drift of the cart's push force, driven by a continuous scheduler.
    force_update = update_functions.RandomWalk(
        scheduler=schedulers.ContinuousScheduler()
    )
    env = NSClassicControlWrapper(gym.make("CartPole-v1"), {"force_mag": force_update})
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(0)

    # Simulation copy; the prints compare its force_mag with the real env's.
    sim_env = copy.deepcopy(env)
    for _ in range(5):
        action = sim_env.action_space.sample()
        for stepped in (sim_env, env):
            obs, reward, terminated, truncated, info = stepped.step(action)
        for label, wrapped in (("sim_env", sim_env), ("env", env)):
            print(f"{label} {wrapped.unwrapped.force_mag}")