import gymnasium as gym
import mujoco
import numpy as np
import warnings
from copy import deepcopy
from typing import Any, Callable
import ns_gym.base as base
import ns_gym.update_functions as update_functions
import ns_gym.schedulers as schedulers
class ConstraintViolationWarning(Warning):
    """Warning category raised when a proposed value breaks a constraint.

    In this module it is emitted through ``warnings.warn`` whenever a
    parameter update is rejected instead of being written to the model.
    """
[docs]
class MujocoWrapper(base.NSWrapper):
    """Wrapper that changes MuJoCo model parameters while an episode runs.

    Every key of ``tunable_params`` is a friendly parameter name that
    ``param_look_up`` can resolve for the wrapped environment class. Every
    value is an update function that is called once per step as
    ``fn(current_value, t)`` and must return
    ``(new_value, change_flag, delta)``.

    Attributes such as ``t``, ``is_sim_env``, ``has_reset``,
    ``persistent_params`` and ``scalar_reward`` are read here but never
    assigned in this class; presumably they are provided by
    ``base.NSWrapper`` (verify against the base class).
    """

    def __init__(
        self,
        env: base.Env,
        tunable_params: dict,
        change_notification: bool = False,
        delta_change_notification: bool = False,
        in_sim_change: bool = False,
        **kwargs: Any,
    ):
        """Wrap ``env`` and snapshot the initial value of every tunable parameter.

        Args:
            env: Gymnasium MuJoCo environment to wrap.
            tunable_params: Mapping of friendly parameter name to update function.
            change_notification: Forwarded unchanged to the base wrapper.
            delta_change_notification: Forwarded unchanged to the base wrapper;
                also selects what ``get_planning_env`` returns.
            in_sim_change: When true, parameters keep changing inside planning
                copies (see ``step``).
            **kwargs: Extra keyword arguments forwarded to the base wrapper.

        Raises:
            ValueError: Propagated from ``param_look_up`` when the environment
                class or a parameter name is not supported.
        """
        super().__init__(
            env=env,
            tunable_params=tunable_params,
            change_notification=change_notification,
            delta_change_notification=delta_change_notification,
            in_sim_change=in_sim_change,
            **kwargs,
        )
        env_class_name = self.unwrapped.__class__.__name__
        # param_look_up returns a one-element list holding the (getter, setter) pair.
        self._accessors = {
            key: param_look_up(env_class_name, key)[0] for key in tunable_params
        }
        # Deep copies, because some getters (e.g. gravity) return a live view
        # of the model array that later updates would silently mutate.
        self.initial_params = {
            key: deepcopy(self._get_param_value(key)) for key in tunable_params
        }

    def _get_param_value(self, key: str) -> Any:
        """Return the current value of ``key`` via its getter function."""
        getter, _ = self._accessors[key]
        return getter(self.unwrapped)

    def _set_param_value(self, key: str, value: Any):
        """Write ``value`` for ``key`` into the model via its setter function."""
        _, setter = self._accessors[key]
        setter(self.unwrapped, value)

    def _dependency_resolver(self):
        """Re-compute derived properties of the MuJoCo model after changes.

        mj_setConst recomputes derived model constants (e.g. body_subtreemass)
        that depend on mass parameters but are not updated by mj_forward alone.
        However, mj_setConst resets qpos to qpos0, so the full integration
        state is saved before the call and restored afterwards.
        """
        model = self.unwrapped.model
        data = self.unwrapped.data

        # Save integration state (qpos, qvel, act, time, etc.).
        spec = mujoco.mjtState.mjSTATE_INTEGRATION
        state = np.empty(mujoco.mj_stateSize(model, spec))
        mujoco.mj_getState(model, data, state, spec)

        # Recompute derived model constants (body_subtreemass, etc.).
        mujoco.mj_setConst(model, data)

        # Restore integration state and recompute forward dynamics.
        mujoco.mj_setState(model, data, state, spec)
        mujoco.mj_forward(model, data)

    def _constraint_checker(self, new_vals: dict) -> dict[str, bool]:
        """Check proposed parameter values against simple physical constraints.

        The check is chosen by substring of the parameter name: ``"mass"``
        and ``"size"`` values must exceed ``1e-6``; ``"damping"`` values must
        be non-negative. Names matching none of these are never rejected.

        Args:
            new_vals: Mapping of parameter name to proposed value.

        Returns:
            Mapping of parameter name to ``True`` when the proposed value
            violates its constraint (a ``ConstraintViolationWarning`` is also
            emitted), ``False`` otherwise.
        """
        constraint_dict = dict.fromkeys(new_vals, False)
        for p, v in new_vals.items():
            if "mass" in p and v <= 1e-6:
                warnings.warn(
                    f"Mass for '{p}' must be positive, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
            elif "size" in p and np.any(np.array(v) <= 1e-6):
                warnings.warn(
                    f"Size for '{p}' must have positive elements, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
            elif "damping" in p and v < 0:
                warnings.warn(
                    f"Damping for '{p}' cannot be negative, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
        return constraint_dict

    def step(self, action: Any) -> tuple[Any, Any, bool, bool, dict[str, Any]]:
        """Apply scheduled physics changes, then step the environment.

        A planning copy (``is_sim_env``) with ``in_sim_change`` disabled leaves
        the model untouched and reports zero change for every parameter.
        Otherwise each update function is evaluated, values that pass
        ``_constraint_checker`` are written to the model, rejected values are
        reported as "no change", and derived model quantities are refreshed
        before the base wrapper's ``step`` is called.

        Args:
            action: Action forwarded to the base wrapper's ``step``.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple produced
            by the base wrapper's ``step``.
        """
        if self.is_sim_env and not self.in_sim_change:
            env_change = {p: 0 for p in self.tunable_params}
            delta_change = {p: 0.0 for p in self.tunable_params}
        else:
            env_change = {}
            delta_change = {}
            new_vals = {}
            for p, fn in self.tunable_params.items():
                cur_val = self._get_param_value(p)
                if p == "gravity":
                    # Special handling for gravity: only the z-component is
                    # driven by the update function; x and y are carried over.
                    new_z, change_flag, delta = fn(cur_val[-1], self.t)
                    new_val = np.array([cur_val[0], cur_val[1], new_z])
                else:
                    new_val, change_flag, delta = fn(cur_val, self.t)
                delta_change[p] = delta
                env_change[p] = change_flag
                new_vals[p] = new_val

            for k, violated in self._constraint_checker(new_vals).items():
                if violated:
                    # The value was not written, so report "no change". The
                    # delta is a numeric zero, matching the no-change branch
                    # above (it was previously the boolean False).
                    delta_change[k] = 0.0
                    env_change[k] = False
                else:
                    self._set_param_value(k, new_vals[k])
            self._dependency_resolver()

        obs, reward, terminated, truncated, info = super().step(
            action, env_change=env_change, delta_change=delta_change
        )
        return obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Reset the environment and, unless parameters persist, restore them.

        Args:
            seed: Forwarded to the base wrapper's ``reset``.
            options: Forwarded to the base wrapper's ``reset``.

        Returns:
            The ``(obs, info)`` pair produced by the base wrapper's ``reset``.
            Note that it is computed before the initial parameters are restored.
        """
        obs, info = super().reset(seed=seed, options=options)
        if not self.persistent_params:
            for k, v in self.initial_params.items():
                self._set_param_value(k, deepcopy(v))
            self._dependency_resolver()
        return obs, info

    def get_planning_env(self):
        """Return a copy of the environment for planning.

        NOTE:
            - If this is already a simulation environment, or
              ``delta_change_notification`` is enabled, a deepcopy carrying the
              current parameter values is returned.
            - Otherwise the deepcopy has its tunable parameters reset to their
              initial values, so the planner does not see the changes.

        NOTE(review): an earlier version of this docstring keyed the behavior
        on "change notification", but the code tests
        ``delta_change_notification``; confirm which flag is intended.

        Raises:
            AssertionError: If the environment has not been reset yet.
        """
        assert self.has_reset, (
            "The environment must be reset before getting the planning environment."
        )
        planning_env = deepcopy(self)
        if self.is_sim_env or self.delta_change_notification:
            return planning_env

        for k, v in self.initial_params.items():
            planning_env._set_param_value(k, deepcopy(v))
        planning_env._dependency_resolver()
        return planning_env

    def __deepcopy__(self, memo):
        """Build an independent planning copy backed by a fresh MuJoCo env.

        Instead of copying the native MuJoCo objects, a new environment is
        created with ``gym.make``, wrapped with the same settings, and then
        given this environment's integration state, step counter and current
        tunable-parameter values. The copy is flagged as a simulation
        environment.

        Args:
            memo: Deepcopy memo dictionary (unused here).
        """
        base_class_name = self.unwrapped.__class__.__name__
        id_map = {
            "AntEnv": "Ant-v5",
            "HalfCheetahEnv": "HalfCheetah-v5",
            "HopperEnv": "Hopper-v5",
            "HumanoidEnv": "Humanoid-v5",
            "HumanoidStandupEnv": "HumanoidStandup-v5",
            "InvertedPendulumEnv": "InvertedPendulum-v5",
            "InvertedDoublePendulumEnv": "InvertedDoublePendulum-v5",
            "ReacherEnv": "Reacher-v5",
            "SwimmerEnv": "Swimmer-v5",
            "PusherEnv": "Pusher-v5",
        }
        # Only touch ``spec`` when the class is unknown: an environment that
        # was constructed directly has ``spec is None``, and evaluating
        # ``spec.id`` eagerly would crash even for the classes listed above.
        if base_class_name in id_map:
            base_id = id_map[base_class_name]
        else:
            base_id = self.unwrapped.spec.id

        sim_env = gym.make(base_id)
        sim_env = MujocoWrapper(
            sim_env,
            deepcopy(self.tunable_params),
            self.change_notification,
            self.delta_change_notification,
            self.in_sim_change,
            scalar_reward=self.scalar_reward,
            persistent_params=self.persistent_params,
        )
        sim_env.reset()

        # Copy full MuJoCo integration state (time, qpos, qvel, act,
        # qacc_warmstart) using the built-in state API for exact continuation.
        spec = mujoco.mjtState.mjSTATE_INTEGRATION
        state = np.empty(mujoco.mj_stateSize(self.unwrapped.model, spec))
        mujoco.mj_getState(self.unwrapped.model, self.unwrapped.data, state, spec)
        mujoco.mj_setState(
            sim_env.unwrapped.model, sim_env.unwrapped.data, state, spec
        )

        # Copy wrapper state.
        sim_env.t = deepcopy(self.t)

        # Copy current tunable parameter values into the new model.
        for k in self.tunable_params:
            sim_env._set_param_value(k, deepcopy(self._get_param_value(k)))

        # Recompute all derived MuJoCo quantities from the copied state.
        sim_env._dependency_resolver()

        sim_env.is_sim_env = True
        sim_env._reseed_planning_env_rngs()
        return sim_env
# Accessor kinds understood by ``_make_accessors``.
_GRAVITY = ("gravity", None)

# env class name -> friendly parameter name -> (accessor kind, MuJoCo element name).
# Keys are part of the public contract (they are what users put in
# ``tunable_params``), so irregular ones such as "ffeet_mass" -> body "ffoot"
# and "pole1_mass" -> body "pole" are kept exactly as they were.
_PARAM_SPECS = {
    "AntEnv": {
        "gravity": _GRAVITY,
        "torso_mass": ("body_mass", "torso"),
        "floor_friction": ("geom_friction", "floor"),
    },
    "HalfCheetahEnv": {
        "gravity": _GRAVITY,
        "torso_mass": ("body_mass", "torso"),
        "bthigh_mass": ("body_mass", "bthigh"),
        "bshin_mass": ("body_mass", "bshin"),
        "bfoot_mass": ("body_mass", "bfoot"),
        "fthigh_mass": ("body_mass", "fthigh"),
        "fshin_mass": ("body_mass", "fshin"),
        "ffeet_mass": ("body_mass", "ffoot"),
        "floor_friction": ("geom_friction", "floor"),
        "bthigh_damping": ("joint_damping", "bthigh"),
        "bshin_damping": ("joint_damping", "bshin"),
        "bfoot_damping": ("joint_damping", "bfoot"),
        "fthigh_damping": ("joint_damping", "fthigh"),
        "fshin_damping": ("joint_damping", "fshin"),
        "ffeet_damping": ("joint_damping", "ffoot"),
    },
    "HopperEnv": {
        "gravity": _GRAVITY,
        "torso_mass": ("body_mass", "torso"),
        "thigh_mass": ("body_mass", "thigh"),
        "leg_mass": ("body_mass", "leg"),
        "foot_mass": ("body_mass", "foot"),
        "floor_friction": ("geom_friction", "floor"),
        "thigh_joint_damping": ("joint_damping", "thigh_joint"),
        "leg_joint_damping": ("joint_damping", "leg_joint"),
        "foot_joint_damping": ("joint_damping", "foot_joint"),
    },
    "HumanoidEnv": {
        "gravity": _GRAVITY,
        "torso_mass": ("body_mass", "torso"),
        "lwaist_mass": ("body_mass", "lwaist"),
        "pelvis_mass": ("body_mass", "pelvis"),
        "right_thigh_mass": ("body_mass", "right_thigh"),
        "left_thigh_mass": ("body_mass", "left_thigh"),
        "right_shin_mass": ("body_mass", "right_shin"),
        "left_shin_mass": ("body_mass", "left_shin"),
        "right_foot_mass": ("body_mass", "right_foot"),
        "left_foot_mass": ("body_mass", "left_foot"),
        "right_upper_arm_mass": ("body_mass", "right_upper_arm"),
        "left_upper_arm_mass": ("body_mass", "left_upper_arm"),
        "right_lower_arm_mass": ("body_mass", "right_lower_arm"),
        "left_lower_arm_mass": ("body_mass", "left_lower_arm"),
        "floor_friction": ("geom_friction", "floor"),
        "right_knee_damping": ("joint_damping", "right_knee"),
        "left_knee_damping": ("joint_damping", "left_knee"),
        "right_elbow_damping": ("joint_damping", "right_elbow"),
        "left_elbow_damping": ("joint_damping", "left_elbow"),
    },
    "InvertedPendulumEnv": {
        "gravity": _GRAVITY,
        "pole_mass": ("body_mass", "pole"),
        "cart_mass": ("body_mass", "cart"),
        "rail_friction": ("geom_friction", "rail"),
    },
    "InvertedDoublePendulumEnv": {
        "gravity": _GRAVITY,
        "cart_mass": ("body_mass", "cart"),
        "pole1_mass": ("body_mass", "pole"),
        "pole2_mass": ("body_mass", "pole2"),
        # Historical key name: it addresses the "rail" geom, not a floor geom.
        "floor_friction": ("geom_friction", "rail"),
        "slider_damping": ("joint_damping", "slider"),
        "hinge1_damping": ("joint_damping", "hinge"),
        "hinge2_damping": ("joint_damping", "hinge2"),
    },
    # Gravity and ground friction are deliberately not exposed for Reacher.
    "ReacherEnv": {
        "body0_mass": ("body_mass", "body0"),
        "body1_mass": ("body_mass", "body1"),
        "joint0_damping": ("joint_damping", "joint0"),
        "joint1_damping": ("joint_damping", "joint1"),
    },
    # Gravity and floor friction are deliberately not exposed for Swimmer.
    "SwimmerEnv": {
        "body_mid_mass": ("body_mass", "mid"),
    },
    "PusherEnv": {
        "gravity": _GRAVITY,
        "r_shoulder_pan_link_mass": ("body_mass", "r_shoulder_pan_link"),
        "r_shoulder_lift_link_mass": ("body_mass", "r_shoulder_lift_link"),
        "r_upper_arm_link_mass": ("body_mass", "r_upper_arm_link"),
        "r_forearm_link_mass": ("body_mass", "r_forearm_link"),
        "r_shoulder_pan_joint_damping": ("joint_damping", "r_shoulder_pan_joint"),
        "r_shoulder_lift_joint_damping": ("joint_damping", "r_shoulder_lift_joint"),
        "r_elbow_flex_joint_damping": ("joint_damping", "r_elbow_flex_joint"),
    },
}
# HumanoidStandupEnv uses the same XML and attributes as HumanoidEnv.
_PARAM_SPECS["HumanoidStandupEnv"] = _PARAM_SPECS["HumanoidEnv"]


def _make_accessors(kind, name):
    """Build the ``(getter, setter)`` pair for one accessor kind and element name."""
    if kind == "gravity":

        def getter(env):
            # Returns the live gravity array of the model, not a copy.
            return env.model.opt.gravity

        def setter(env, val):
            np.copyto(env.model.opt.gravity, val)

    elif kind == "body_mass":

        def getter(env):
            return env.model.body_mass[env.model.body(name).id]

        def setter(env, val):
            env.model.body_mass[env.model.body(name).id] = val

    elif kind == "geom_friction":

        # Only column 0 of the geom's friction row is exposed.
        def getter(env):
            return env.model.geom_friction[env.model.geom(name).id, 0]

        def setter(env, val):
            env.model.geom_friction[env.model.geom(name).id, 0] = val

    elif kind == "joint_damping":

        def dof_index(env):
            # dof_damping has one entry per degree of freedom, not per joint.
            # A joint id only equals its DOF index while every earlier joint
            # has exactly one DOF; a free root joint (e.g. Humanoid) shifts
            # all later joints. jnt_dofadr gives the joint's first DOF.
            return env.model.jnt_dofadr[env.model.joint(name).id]

        def getter(env):
            return env.model.dof_damping[dof_index(env)]

        def setter(env, val):
            env.model.dof_damping[dof_index(env)] = val

    else:
        raise ValueError(f"Unknown accessor kind '{kind}'.")

    return getter, setter


def param_look_up(
    env_name: str, tunable_param: str
) -> list[tuple[Callable, Callable]]:
    """Return getter and setter functions for a tunable MuJoCo parameter.

    Maps friendly parameter names to MuJoCo model attributes and indices.

    Args:
        env_name (str): Name of the MuJoCo environment class (e.g., "AntEnv").
        tunable_param (str): Friendly name of the tunable parameter
            (e.g., "torso_mass").

    Returns:
        list[tuple[Callable, Callable]]: A one-element list (callers take
        element ``[0]``) whose item is a ``(getter, setter)`` pair:
            - ``getter(env)`` returns the current parameter value.
            - ``setter(env, value)`` writes the value into ``env.model``.

    Raises:
        ValueError: If the environment class or the parameter name is not
            supported.
    """
    if env_name not in _PARAM_SPECS:
        raise ValueError(f"Environment '{env_name}' not recognized or supported.")

    env_mapping = _PARAM_SPECS[env_name]
    if tunable_param not in env_mapping:
        raise ValueError(
            f"Parameter '{tunable_param}' not recognized for environment '{env_name}'."
        )

    kind, mujoco_name = env_mapping[tunable_param]
    return [_make_accessors(kind, mujoco_name)]
def _demo():
    """Run a short rendered Ant episode whose gravity changes at step 10."""
    env = gym.make("Ant-v5", render_mode="human", max_episode_steps=1000)

    # The scheduler is configured to be active from step 10 to step 1000.
    scheduler = schedulers.ContinuousScheduler(start=10, end=1000)

    # MujocoWrapper.step passes only the z-component of gravity to the update
    # function and places its result back into the z slot, so the values
    # supplied here are scalars rather than 3-vectors. Note that -1000 makes
    # gravity far stronger, not weaker.
    # NOTE(review): assumes StepWiseUpdate hands out these values one by one
    # when the scheduler fires; confirm against ns_gym.update_functions.
    update_fn = update_functions.StepWiseUpdate(scheduler, [-9.8, -1000.0])

    ns_env = MujocoWrapper(env, {"gravity": update_fn}, change_notification=True)
    try:
        obs, info = ns_env.reset()
        print(f"Initial gravity: {ns_env._get_param_value('gravity')}")

        for i in range(100):
            action = ns_env.action_space.sample()
            obs, rew, terminated, truncated, info = ns_env.step(action)

            # Print the gravity every 2 steps to see it change.
            if (i + 1) % 2 == 0:
                print(
                    f"Gravity at step {i + 1}: {np.round(ns_env._get_param_value('gravity'), 2)}"
                )

            if terminated or truncated:
                obs, info = ns_env.reset()
                print("\n--- ENV RESET ---")
                print(f"Gravity after reset: {ns_env._get_param_value('gravity')}\n")
    finally:
        # Always release the render window, even if a step raises.
        ns_env.close()


if __name__ == "__main__":
    _demo()