import gymnasium as gym
import mujoco
import numpy as np
import warnings
from copy import deepcopy
from typing import Any
import ns_gym.base as base
import ns_gym.update_functions as update_functions
import ns_gym.schedulers as schedulers
from typing import Callable
class ConstraintViolationWarning(Warning):
    """Warning category used when a requested value breaks an application constraint."""
class MujocoWrapper(base.NSWrapper):
    """Non-stationary wrapper that perturbs MuJoCo model parameters over time.

    Each entry of ``tunable_params`` maps a parameter name understood by
    ``param_look_up`` for the wrapped environment class to an update function.
    On every step the update function is called as ``fn(current_value, t)``
    and must return ``(new_value, change_flag, delta)``. Proposed values that
    violate basic physical constraints are rejected with a
    ``ConstraintViolationWarning``. ``reset`` restores the values captured when
    the wrapper was constructed.

    Args:
        env: The MuJoCo environment to wrap.
        tunable_params: Mapping from parameter name to update function.
        change_notification: Forwarded to ``base.NSWrapper``.
        delta_change_notification: Forwarded to ``base.NSWrapper``.
        in_sim_change: Forwarded to ``base.NSWrapper``. When ``self.is_sim_env``
            is set and this is False, ``step`` leaves all parameters untouched.
        **kwargs: Extra keyword arguments forwarded to ``base.NSWrapper``.
    """

    def __init__(
        self,
        env: base.Env,
        tunable_params: dict,
        change_notification: bool = False,
        delta_change_notification: bool = False,
        in_sim_change: bool = False,
        **kwargs: Any,
    ):
        super().__init__(
            env=env,
            tunable_params=tunable_params,
            change_notification=change_notification,
            delta_change_notification=delta_change_notification,
            in_sim_change=in_sim_change,
            **kwargs,
        )
        self.t = 0  # time index handed to the update functions
        self.delta_t = 1  # amount added to self.t after every step
        env_name = self.unwrapped.__class__.__name__
        # param_look_up returns a one-element list holding (getter, setter).
        self._accessors = {
            key: param_look_up(env_name, key)[0] for key in tunable_params
        }
        # Snapshot starting values so reset() can restore them. deepcopy is
        # needed because array-valued getters (e.g. gravity) return the array
        # stored in the model, which later setters modify in place.
        self.initial_params = {
            key: deepcopy(self._get_param_value(key)) for key in tunable_params
        }

    def _get_param_value(self, key: str) -> Any:
        """Return the current value of ``key`` via its registered getter."""
        getter, _ = self._accessors[key]
        return getter(self.unwrapped)

    def _set_param_value(self, key: str, value: Any):
        """Write ``value`` for ``key`` via its registered setter."""
        _, setter = self._accessors[key]
        setter(self.unwrapped, value)

    def _dependency_resolver(self):
        """Run ``mujoco.mj_forward`` so ``data`` reflects the modified model.

        ``mj_forward`` recomputes the simulation quantities stored in
        ``data`` for the current state using the current model parameters.
        NOTE(review): it does not refresh constants stored in the model itself
        (e.g. ``body_subtreemass``), which MuJoCo derives via ``mj_setConst``.
        """
        mujoco.mj_forward(self.unwrapped.model, self.unwrapped.data)

    def _constraint_checker(self, new_vals: dict[str, Any]) -> dict[str, bool]:
        """Flag proposed parameter values that violate basic physical constraints.

        Rules are selected by substrings of the parameter name:
        masses must exceed 1e-6, every element of a size must exceed 1e-6,
        and damping and friction must be non-negative. A
        ``ConstraintViolationWarning`` is emitted for each rejected value.

        Args:
            new_vals: Mapping from parameter name to proposed value.

        Returns:
            Mapping from each key of ``new_vals`` to True if the value was
            rejected, False if it may be applied.
        """
        constraint_dict = {key: False for key in new_vals}
        for p, v in new_vals.items():
            if "mass" in p and v <= 1e-6:
                warnings.warn(
                    f"Mass for '{p}' must be positive, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
            elif "size" in p and np.any(np.array(v) <= 1e-6):
                warnings.warn(
                    f"Size for '{p}' must have positive elements, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
            elif "damping" in p and v < 0:
                warnings.warn(
                    f"Damping for '{p}' cannot be negative, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
            elif "friction" in p and v < 0:
                warnings.warn(
                    f"Friction for '{p}' cannot be negative, not updated.",
                    ConstraintViolationWarning,
                )
                constraint_dict[p] = True
        return constraint_dict

    @staticmethod
    def _compose_gravity(cur_val: Any, new_val: Any) -> np.ndarray:
        """Build the full gravity vector from an update function's output.

        Args:
            cur_val: Current gravity vector; its x/y components are reused
                when ``new_val`` is a scalar.
            new_val: Either a scalar (new z-component) or a length-3 vector
                (complete new gravity).

        Returns:
            A new float array of shape (3,).

        Raises:
            ValueError: If ``new_val`` is neither a scalar nor a length-3 vector.
        """
        arr = np.asarray(new_val, dtype=float)
        if arr.size == 1:
            # Scalar result replaces only the z-component.
            return np.array([cur_val[0], cur_val[1], arr.item()])
        if arr.shape == (3,):
            # Full vector supplied by the update function.
            return arr.copy()
        raise ValueError(
            "Gravity update must return a scalar z-component or a length-3 "
            f"vector, got shape {arr.shape}."
        )

    def step(self, action: Any) -> tuple[Any, Any, bool, bool, dict[str, Any]]:
        """Apply scheduled parameter changes, then step the wrapped environment.

        For gravity, the update function receives only the z-component. It may
        return either a new z-component or a full 3-vector. Values rejected by
        ``_constraint_checker`` are not applied, and their change/delta flags
        are reported as False.
        """
        if self.is_sim_env and not self.in_sim_change:
            obs, reward, terminated, truncated, info = super().step(
                action, env_change=None, delta_change=None
            )
        else:
            env_change = {}
            delta_change = {}
            new_vals = {}
            for p, fn in self.tunable_params.items():
                cur_val = self._get_param_value(p)
                if p == "gravity":
                    # Special handling: update functions operate on gravity's z-component.
                    new_val, change_flag, delta = fn(cur_val[-1], self.t)
                    new_vals[p] = self._compose_gravity(cur_val, new_val)
                else:
                    new_val, change_flag, delta = fn(cur_val, self.t)
                    new_vals[p] = new_val
                delta_change[p] = delta
                env_change[p] = change_flag
            for key, violated in self._constraint_checker(new_vals).items():
                if violated:
                    delta_change[key] = False
                    env_change[key] = False
                else:
                    self._set_param_value(key, new_vals[key])
            self._dependency_resolver()
            obs, reward, terminated, truncated, info = super().step(
                action, env_change=env_change, delta_change=delta_change
            )
        self.t += self.delta_t
        return obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Reset the environment and restore the initial model parameters.

        The wrapped reset runs first. Afterwards each tunable parameter is set
        back to the value captured at construction, ``mj_forward`` is rerun,
        and the time index is zeroed.
        """
        obs, info = super().reset(seed=seed, options=options)
        for key, value in self.initial_params.items():
            self._set_param_value(key, deepcopy(value))
        self._dependency_resolver()
        self.t = 0
        return obs, info
# A (getter, setter) pair: getter(env) -> value, setter(env, value) -> None.
_Accessor = tuple[Callable[[Any], Any], Callable[[Any, Any], None]]


def _gravity_accessor() -> _Accessor:
    """Return a (getter, setter) pair for ``model.opt.gravity``.

    The getter returns the model's gravity array itself. The setter overwrites
    that array in place with ``np.copyto``.
    """

    def getter(env: Any) -> Any:
        return env.model.opt.gravity

    def setter(env: Any, val: Any) -> None:
        np.copyto(env.model.opt.gravity, val)

    return getter, setter


def _body_mass_accessor(body: str) -> _Accessor:
    """Return a (getter, setter) pair for the mass of the named ``body``."""

    def getter(env: Any) -> Any:
        return env.model.body_mass[env.model.body(body).id]

    def setter(env: Any, val: Any) -> None:
        env.model.body_mass[env.model.body(body).id] = val

    return getter, setter


def _geom_friction_accessor(geom: str) -> _Accessor:
    """Return a (getter, setter) pair for column 0 of the named ``geom``'s friction."""

    def getter(env: Any) -> Any:
        return env.model.geom_friction[env.model.geom(geom).id, 0]

    def setter(env: Any, val: Any) -> None:
        env.model.geom_friction[env.model.geom(geom).id, 0] = val

    return getter, setter


def _dof_damping_accessor(joint: str) -> _Accessor:
    """Return a (getter, setter) pair for the damping of a single-DOF ``joint``.

    ``model.dof_damping`` is indexed by DOF, not by joint, so the joint id is
    translated through ``model.jnt_dofadr``. Joint id and DOF index only
    coincide when no multi-DOF joint (e.g. a free root joint) precedes ``joint``.
    """

    def _dof_index(env: Any) -> int:
        return env.model.jnt_dofadr[env.model.joint(joint).id]

    def getter(env: Any) -> Any:
        return env.model.dof_damping[_dof_index(env)]

    def setter(env: Any, val: Any) -> None:
        env.model.dof_damping[_dof_index(env)] = val

    return getter, setter


# Friendly parameter name -> accessor pair, per MuJoCo environment class name.
_MUJOCO_PARAM_ACCESSORS: dict[str, dict[str, _Accessor]] = {
    "AntEnv": {
        "gravity": _gravity_accessor(),
        "torso_mass": _body_mass_accessor("torso"),
        "floor_friction": _geom_friction_accessor("floor"),
    },
    "HalfCheetahEnv": {
        "gravity": _gravity_accessor(),
        "torso_mass": _body_mass_accessor("torso"),
        "bthigh_mass": _body_mass_accessor("bthigh"),
        "bshin_mass": _body_mass_accessor("bshin"),
        "bfoot_mass": _body_mass_accessor("bfoot"),
        "fthigh_mass": _body_mass_accessor("fthigh"),
        "fshin_mass": _body_mass_accessor("fshin"),
        # The public key says "ffeet"; the model body/joint is named "ffoot".
        "ffeet_mass": _body_mass_accessor("ffoot"),
        "floor_friction": _geom_friction_accessor("floor"),
        "bthigh_damping": _dof_damping_accessor("bthigh"),
        "bshin_damping": _dof_damping_accessor("bshin"),
        "bfoot_damping": _dof_damping_accessor("bfoot"),
        "fthigh_damping": _dof_damping_accessor("fthigh"),
        "fshin_damping": _dof_damping_accessor("fshin"),
        "ffeet_damping": _dof_damping_accessor("ffoot"),
    },
    "HopperEnv": {
        "gravity": _gravity_accessor(),
        "torso_mass": _body_mass_accessor("torso"),
        "thigh_mass": _body_mass_accessor("thigh"),
        "leg_mass": _body_mass_accessor("leg"),
        "foot_mass": _body_mass_accessor("foot"),
        "floor_friction": _geom_friction_accessor("floor"),
        "thigh_joint_damping": _dof_damping_accessor("thigh_joint"),
        "leg_joint_damping": _dof_damping_accessor("leg_joint"),
        "foot_joint_damping": _dof_damping_accessor("foot_joint"),
    },
    "HumanoidEnv": {
        "gravity": _gravity_accessor(),
        "torso_mass": _body_mass_accessor("torso"),
        "lwaist_mass": _body_mass_accessor("lwaist"),
        "pelvis_mass": _body_mass_accessor("pelvis"),
        "right_thigh_mass": _body_mass_accessor("right_thigh"),
        "left_thigh_mass": _body_mass_accessor("left_thigh"),
        "right_shin_mass": _body_mass_accessor("right_shin"),
        "left_shin_mass": _body_mass_accessor("left_shin"),
        "right_foot_mass": _body_mass_accessor("right_foot"),
        "left_foot_mass": _body_mass_accessor("left_foot"),
        "right_upper_arm_mass": _body_mass_accessor("right_upper_arm"),
        "left_upper_arm_mass": _body_mass_accessor("left_upper_arm"),
        "right_lower_arm_mass": _body_mass_accessor("right_lower_arm"),
        "left_lower_arm_mass": _body_mass_accessor("left_lower_arm"),
        "floor_friction": _geom_friction_accessor("floor"),
        "right_knee_damping": _dof_damping_accessor("right_knee"),
        "left_knee_damping": _dof_damping_accessor("left_knee"),
        "right_elbow_damping": _dof_damping_accessor("right_elbow"),
        "left_elbow_damping": _dof_damping_accessor("left_elbow"),
    },
    "InvertedPendulumEnv": {
        "gravity": _gravity_accessor(),
        "pole_mass": _body_mass_accessor("pole"),
        "cart_mass": _body_mass_accessor("cart"),
        "rail_friction": _geom_friction_accessor("rail"),
    },
    "InvertedDoublePendulumEnv": {
        "gravity": _gravity_accessor(),
        "cart_mass": _body_mass_accessor("cart"),
        "pole1_mass": _body_mass_accessor("pole"),
        "pole2_mass": _body_mass_accessor("pole2"),
        # Exposed as "floor_friction" but backed by the "rail" geom.
        "floor_friction": _geom_friction_accessor("rail"),
        "slider_damping": _dof_damping_accessor("slider"),
        "hinge1_damping": _dof_damping_accessor("hinge"),
        "hinge2_damping": _dof_damping_accessor("hinge2"),
    },
    # NOTE: gravity and ground friction accessors are currently disabled for Reacher.
    "ReacherEnv": {
        "body0_mass": _body_mass_accessor("body0"),
        "body1_mass": _body_mass_accessor("body1"),
        "joint0_damping": _dof_damping_accessor("joint0"),
        "joint1_damping": _dof_damping_accessor("joint1"),
    },
    # NOTE: gravity and floor friction accessors are currently disabled for Swimmer.
    "SwimmerEnv": {
        "body_mid_mass": _body_mass_accessor("mid"),
    },
    "PusherEnv": {
        "gravity": _gravity_accessor(),
        "r_shoulder_pan_link_mass": _body_mass_accessor("r_shoulder_pan_link"),
        "r_shoulder_lift_link_mass": _body_mass_accessor("r_shoulder_lift_link"),
        "r_upper_arm_link_mass": _body_mass_accessor("r_upper_arm_link"),
        "r_forearm_link_mass": _body_mass_accessor("r_forearm_link"),
        "r_shoulder_pan_joint_damping": _dof_damping_accessor("r_shoulder_pan_joint"),
        "r_shoulder_lift_joint_damping": _dof_damping_accessor("r_shoulder_lift_joint"),
        "r_elbow_flex_joint_damping": _dof_damping_accessor("r_elbow_flex_joint"),
    },
}
# HumanoidStandupEnv uses the same XML and attributes as HumanoidEnv.
_MUJOCO_PARAM_ACCESSORS["HumanoidStandupEnv"] = _MUJOCO_PARAM_ACCESSORS["HumanoidEnv"]


def param_look_up(
    env_name: str, tunable_param: str
) -> list[tuple[Callable, Callable]]:
    """Look up the getter and setter for a tunable MuJoCo parameter.

    Maps friendly parameter names to MuJoCo model attributes and indices.

    Args:
        env_name: Name of the MuJoCo environment class (e.g. ``"AntEnv"``).
        tunable_param: Friendly name of the parameter (e.g. ``"torso_mass"``).

    Returns:
        A one-element list containing a ``(getter, setter)`` tuple:
        ``getter(env)`` returns the parameter value, and ``setter(env, value)``
        writes it into ``env.model``. The list wrapper is kept for callers
        that index the result with ``[0]``.

    Raises:
        ValueError: If the environment or parameter is not supported.
    """
    if env_name not in _MUJOCO_PARAM_ACCESSORS:
        raise ValueError(f"Environment '{env_name}' not recognized or supported.")
    env_mapping = _MUJOCO_PARAM_ACCESSORS[env_name]
    if tunable_param not in env_mapping:
        raise ValueError(
            f"Parameter '{tunable_param}' not recognized for environment '{env_name}'."
        )
    return [env_mapping[tunable_param]]
if __name__ == "__main__":
    env = gym.make("Ant-v5", render_mode="human", max_episode_steps=1000)
    scheduler = schedulers.ContinuousScheduler(start=10, end=1000)
    # MujocoWrapper hands the update function only the z-component of gravity
    # and rebuilds [gx, gy, new_z] from its return value, so the values here
    # are scalars. StepWiseUpdate presumably steps through them at the
    # scheduler's change points (see ns_gym.update_functions).
    update_fn = update_functions.StepWiseUpdate(scheduler, [-9.8, -1000.0])
    tunable_params = {"gravity": update_fn}
    ns_env = MujocoWrapper(env, tunable_params, change_notification=True)
    try:
        obs, info = ns_env.reset()
        print(f"Initial gravity: {ns_env._get_param_value('gravity')}")
        for i in range(100):
            action = ns_env.action_space.sample()
            obs, rew, done, truncated, info = ns_env.step(action)
            # Print the gravity every 2 steps to see it change.
            if (i + 1) % 2 == 0:
                print(
                    f"Gravity at step {i + 1}: {np.round(ns_env._get_param_value('gravity'), 2)}"
                )
            if done or truncated:
                obs, info = ns_env.reset()
                print("\n--- ENV RESET ---")
                print(f"Gravity after reset: {ns_env._get_param_value('gravity')}\n")
    finally:
        # Release the viewer/simulation even if the demo loop fails.
        ns_env.close()